From b218b7a3f2b6b2d1493d30e84cf92bb88a4af324 Mon Sep 17 00:00:00 2001 From: kewh5868 Date: Mon, 30 Mar 2026 11:31:05 -0600 Subject: [PATCH 1/3] Add ClusterDynamics and ClusterDynamics ML workflows --- src/saxshell/clusterdynamics/__init__.py | 31 + src/saxshell/clusterdynamics/__main__.py | 6 + src/saxshell/clusterdynamics/cli.py | 39 + src/saxshell/clusterdynamics/dataset.py | 496 ++ src/saxshell/clusterdynamics/report.py | 1097 +++++ src/saxshell/clusterdynamics/ui/__init__.py | 5 + src/saxshell/clusterdynamics/ui/__main__.py | 6 + .../clusterdynamics/ui/main_window.py | 1849 +++++++ src/saxshell/clusterdynamics/ui/plot_panel.py | 311 ++ src/saxshell/clusterdynamics/workflow.py | 1233 +++++ src/saxshell/clusterdynamicsml/__init__.py | 33 + src/saxshell/clusterdynamicsml/__main__.py | 4 + src/saxshell/clusterdynamicsml/cli.py | 5 + src/saxshell/clusterdynamicsml/dataset.py | 1273 +++++ src/saxshell/clusterdynamicsml/ui/__init__.py | 6 + .../clusterdynamicsml/ui/main_window.py | 2602 ++++++++++ .../clusterdynamicsml/ui/plot_panel.py | 1128 +++++ src/saxshell/clusterdynamicsml/workflow.py | 4353 +++++++++++++++++ tests/test_clusterdynamics.py | 468 ++ tests/test_clusterdynamicsml.py | 1974 ++++++++ 20 files changed, 16919 insertions(+) create mode 100644 src/saxshell/clusterdynamics/__init__.py create mode 100644 src/saxshell/clusterdynamics/__main__.py create mode 100644 src/saxshell/clusterdynamics/cli.py create mode 100644 src/saxshell/clusterdynamics/dataset.py create mode 100644 src/saxshell/clusterdynamics/report.py create mode 100644 src/saxshell/clusterdynamics/ui/__init__.py create mode 100644 src/saxshell/clusterdynamics/ui/__main__.py create mode 100644 src/saxshell/clusterdynamics/ui/main_window.py create mode 100644 src/saxshell/clusterdynamics/ui/plot_panel.py create mode 100644 src/saxshell/clusterdynamics/workflow.py create mode 100644 src/saxshell/clusterdynamicsml/__init__.py create mode 100644 src/saxshell/clusterdynamicsml/__main__.py create mode 100644 src/saxshell/clusterdynamicsml/cli.py create mode 100644 src/saxshell/clusterdynamicsml/dataset.py create mode 100644 src/saxshell/clusterdynamicsml/ui/__init__.py create mode 100644 src/saxshell/clusterdynamicsml/ui/main_window.py create mode 100644 src/saxshell/clusterdynamicsml/ui/plot_panel.py create mode 100644 src/saxshell/clusterdynamicsml/workflow.py create mode 100644 tests/test_clusterdynamics.py create mode 100644 tests/test_clusterdynamicsml.py diff --git a/src/saxshell/clusterdynamics/__init__.py b/src/saxshell/clusterdynamics/__init__.py new file mode 100644 index 0000000..1ca2638 --- /dev/null +++ b/src/saxshell/clusterdynamics/__init__.py @@ -0,0 +1,31 @@ +"""Time-binned cluster-distribution analysis tools.""" + +from .dataset import ( + LoadedClusterDynamicsDataset, + SavedClusterDynamicsDataset, + export_cluster_dynamics_colormap_csv, + export_cluster_dynamics_lifetime_csv, + load_cluster_dynamics_dataset, + save_cluster_dynamics_dataset, +) +from .workflow import ( + ClusterDynamicsResult, + ClusterDynamicsSelectionPreview, + ClusterDynamicsWorkflow, + ClusterLifetimeSummary, + ClusterSizeLifetimeSummary, +) + +__all__ = [ + "ClusterDynamicsResult", + "ClusterDynamicsSelectionPreview", + "ClusterDynamicsWorkflow", + "ClusterLifetimeSummary", + "ClusterSizeLifetimeSummary", + "LoadedClusterDynamicsDataset", + "SavedClusterDynamicsDataset", + "export_cluster_dynamics_colormap_csv", + "export_cluster_dynamics_lifetime_csv", + "load_cluster_dynamics_dataset", + "save_cluster_dynamics_dataset", +] diff --git a/src/saxshell/clusterdynamics/__main__.py b/src/saxshell/clusterdynamics/__main__.py new file mode 100644 index 0000000..855590e --- /dev/null +++ b/src/saxshell/clusterdynamics/__main__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from .cli import main + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/saxshell/clusterdynamics/cli.py b/src/saxshell/clusterdynamics/cli.py new file mode 100644 index 0000000..b60a0a9 --- /dev/null +++ b/src/saxshell/clusterdynamics/cli.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import argparse + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser( + prog="clusterdynamics", + description=( + "Analyze time-binned cluster distributions from extracted PDB " + "or XYZ frame folders, or launch the Qt UI. Running without " + "additional arguments launches the UI." + ), + ) + parser.add_argument( + "frames_dir", + nargs="?", + help="Optional extracted frames directory to prefill in the UI.", + ) + parser.add_argument( + "--energy-file", + help="Optional CP2K .ener file to prefill in the UI.", + ) + parser.add_argument( + "--project-dir", + help="Optional SAXSShell project directory to prefill in the UI.", + ) + args = parser.parse_args(argv) + + from .ui.main_window import launch_clusterdynamics_ui + + return launch_clusterdynamics_ui( + getattr(args, "frames_dir", None), + energy_file=getattr(args, "energy_file", None), + project_dir=getattr(args, "project_dir", None), + ) + + +__all__ = ["main"] diff --git a/src/saxshell/clusterdynamics/dataset.py b/src/saxshell/clusterdynamics/dataset.py new file mode 100644 index 0000000..140b582 --- /dev/null +++ b/src/saxshell/clusterdynamics/dataset.py @@ -0,0 +1,496 @@ +from __future__ import annotations + +import csv +import json +from dataclasses import asdict, dataclass +from pathlib import Path + +import numpy as np + +from saxshell.cluster import FrameClusterResult +from saxshell.mdtrajectory.frame.cp2k_ener import CP2KEnergyData + +from .workflow import ( + ClusterDynamicsResult, + ClusterDynamicsSelectionPreview, + ClusterLifetimeSummary, + ClusterSizeLifetimeSummary, + DisplayMode, + TimeUnit, + _resolve_colormap_timestep_settings, +) + +DATASET_VERSION = 1 + + +@dataclass(slots=True) +class SavedClusterDynamicsDataset: + dataset_file: Path + written_files: tuple[Path, ...] + + +@dataclass(slots=True) +class LoadedClusterDynamicsDataset: + dataset_file: Path + result: ClusterDynamicsResult + analysis_settings: dict[str, object] + + +def save_cluster_dynamics_dataset( + result: ClusterDynamicsResult, + output_file: str | Path, + *, + analysis_settings: dict[str, object] | None = None, +) -> SavedClusterDynamicsDataset: + dataset_file = Path(output_file).expanduser().resolve() + if dataset_file.suffix.lower() != ".json": + dataset_file = dataset_file.with_suffix(".json") + dataset_file.parent.mkdir(parents=True, exist_ok=True) + + payload = { + "version": DATASET_VERSION, + "analysis_settings": analysis_settings or {}, + "preview_summary": dict(result.preview.summary), + "preview": result.preview.to_dict(), + "selected_frame_indices": list(result.preview.selected_frame_indices), + "selected_frame_names": list(result.preview.selected_frame_names), + "selected_source_frame_indices": list( + result.preview.selected_source_frame_indices + ), + "frame_results": [ + { + "frame_index": int(frame_result.frame_index), + "time_fs": ( + None + if frame_result.time_fs is None + else float(frame_result.time_fs) + ), + } + for frame_result in result.frame_results + ], + "cluster_labels": list(result.cluster_labels), + "cluster_sizes": { + str(label): int(size) + for label, size in result.cluster_sizes.items() + }, + "bin_edges_fs": result.bin_edges_fs.tolist(), + "frames_per_bin": result.frames_per_bin.tolist(), + "total_clusters_per_bin": result.total_clusters_per_bin.tolist(), + "raw_count_matrix": result.raw_count_matrix.tolist(), + "fraction_matrix": result.fraction_matrix.tolist(), + "mean_count_matrix": result.mean_count_matrix.tolist(), + "frame_count_matrix": result.frame_count_matrix.tolist(), + "total_clusters_per_frame": result.total_clusters_per_frame.tolist(), + "lifetime_by_label": [ + asdict(entry) for entry in result.lifetime_by_label + ], + "lifetime_by_size": [ + asdict(entry) for entry in result.lifetime_by_size + ], + "energy_data": _serialize_energy_data(result.energy_data), + } + dataset_file.write_text( + json.dumps(payload, indent=2) + "\n", + encoding="utf-8", + ) + + written_files = [dataset_file] + written_files.append( + export_cluster_dynamics_colormap_csv( + result, + dataset_file.with_name( + f"{dataset_file.stem}_cluster_distribution.csv" + ), + ) + ) + written_files.append( + export_cluster_dynamics_lifetime_csv( + result, + dataset_file.with_name(f"{dataset_file.stem}_lifetime.csv"), + ) + ) + if result.energy_data is not None: + written_files.append( + _write_energy_csv(result.energy_data, dataset_file) + ) + + return SavedClusterDynamicsDataset( + dataset_file=dataset_file, + written_files=tuple(written_files), + ) + + +def load_cluster_dynamics_dataset( + dataset_file: str | Path, +) -> LoadedClusterDynamicsDataset: + resolved_file = Path(dataset_file).expanduser().resolve() + payload = json.loads(resolved_file.read_text(encoding="utf-8")) + if int(payload.get("version", 0)) != DATASET_VERSION: + raise ValueError( + "This cluster-dynamics dataset uses an unsupported format version." + ) + + preview_payload = dict(payload.get("preview", {})) + preview_summary = dict(payload.get("preview_summary", {})) + frame_results_payload = payload.get("frame_results", []) + frame_results = tuple( + FrameClusterResult( + frame_index=int(entry["frame_index"]), + time_fs=( + None + if entry.get("time_fs") is None + else float(entry["time_fs"]) + ), + clusters=[], + ) + for entry in frame_results_payload + if isinstance(entry, dict) and "frame_index" in entry + ) + + frame_timestep_fs = float(preview_payload.get("frame_timestep_fs", 0.5)) + ( + frames_per_colormap_timestep, + colormap_timestep_fs, + ) = _resolve_colormap_timestep_settings( + frame_timestep_fs=frame_timestep_fs, + frames_per_colormap_timestep=preview_payload.get( + "frames_per_colormap_timestep" + ), + colormap_timestep_fs=_optional_float( + preview_payload.get("colormap_timestep_fs") + ), + legacy_bin_size_fs=_optional_float(preview_payload.get("bin_size_fs")), + require_integral_ratio=False, + ) + + preview = ClusterDynamicsSelectionPreview( + summary=preview_summary, + frame_format=str(preview_payload.get("frame_format", "xyz")), + resolved_box_dimensions=_coerce_box_dimensions( + preview_payload.get("resolved_box_dimensions") + ), + use_pbc=bool(preview_payload.get("use_pbc", False)), + first_frame_time_fs=float( + preview_payload.get("first_frame_time_fs", 0.0) + ), + frame_timestep_fs=frame_timestep_fs, + frames_per_colormap_timestep=frames_per_colormap_timestep, + colormap_timestep_fs=colormap_timestep_fs, + analysis_start_fs=float(preview_payload.get("analysis_start_fs", 0.0)), + analysis_stop_fs=float(preview_payload.get("analysis_stop_fs", 0.0)), + first_selected_time_fs=_optional_float( + preview_payload.get("first_selected_time_fs") + ), + last_selected_time_fs=_optional_float( + preview_payload.get("last_selected_time_fs") + ), + selected_frame_indices=tuple( + int(value) for value in payload.get("selected_frame_indices", []) + ), + selected_frame_names=tuple( + str(value) for value in payload.get("selected_frame_names", []) + ), + selected_source_frame_indices=tuple( + None if value is None else int(value) + for value in payload.get("selected_source_frame_indices", []) + ), + energy_file=( + None + if preview_payload.get("energy_file") is None + else Path(str(preview_payload.get("energy_file"))) + ), + folder_start_time_fs=_optional_float( + preview_payload.get("folder_start_time_fs") + ), + folder_start_time_source=_optional_str( + preview_payload.get("folder_start_time_source") + ), + time_source_label=str( + preview_payload.get("time_source_label", "Saved dataset") + ), + time_warnings=tuple( + str(value) for value in preview_payload.get("time_warnings", []) + ), + ) + + result = ClusterDynamicsResult( + preview=preview, + frame_results=frame_results, + bin_edges_fs=np.asarray(payload.get("bin_edges_fs", []), dtype=float), + frames_per_bin=np.asarray( + payload.get("frames_per_bin", []), + dtype=float, + ), + total_clusters_per_bin=np.asarray( + payload.get("total_clusters_per_bin", []), + dtype=float, + ), + cluster_labels=tuple( + str(value) for value in payload.get("cluster_labels", []) + ), + cluster_sizes={ + str(label): int(size) + for label, size in dict(payload.get("cluster_sizes", {})).items() + }, + raw_count_matrix=np.asarray( + payload.get("raw_count_matrix", []), + dtype=float, + ), + fraction_matrix=np.asarray( + payload.get("fraction_matrix", []), + dtype=float, + ), + mean_count_matrix=np.asarray( + payload.get("mean_count_matrix", []), + dtype=float, + ), + frame_count_matrix=np.asarray( + payload.get("frame_count_matrix", []), + dtype=float, + ), + total_clusters_per_frame=np.asarray( + payload.get("total_clusters_per_frame", []), + dtype=float, + ), + lifetime_by_label=tuple( + ClusterLifetimeSummary(**_normalize_summary_payload(entry)) + for entry in payload.get("lifetime_by_label", []) + if isinstance(entry, dict) + ), + lifetime_by_size=tuple( + ClusterSizeLifetimeSummary(**_normalize_summary_payload(entry)) + for entry in payload.get("lifetime_by_size", []) + if isinstance(entry, dict) + ), + energy_data=_deserialize_energy_data( + payload.get("energy_data"), + fallback_path=resolved_file, + ), + ) + return LoadedClusterDynamicsDataset( + dataset_file=resolved_file, + result=result, + analysis_settings=dict(payload.get("analysis_settings", {})), + ) + + +def export_cluster_dynamics_colormap_csv( + result: ClusterDynamicsResult, + output_file: str | Path, + *, + display_mode: DisplayMode = "fraction", + time_unit: TimeUnit = "fs", +) -> Path: + output_path = _resolve_csv_output_path(output_file) + time_edges = result.time_edges(time_unit) + time_centers = result.bin_centers(time_unit) + displayed_matrix = result.matrix(display_mode) + with output_path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.writer(handle) + writer.writerow( + [ + "label", + "cluster_size", + "bin_index", + "bin_start_fs", + "bin_stop_fs", + "bin_center_fs", + "bin_start_time", + "bin_stop_time", + "bin_center_time", + "time_unit", + "display_mode", + "colormap_value", + "raw_count", + "fraction", + "mean_count_per_frame", + "frames_in_bin", + "total_clusters_in_bin", + ] + ) + for row_index, label in enumerate(result.cluster_labels): + cluster_size = int(result.cluster_sizes.get(label, 0)) + for bin_index in range(result.bin_count): + writer.writerow( + [ + label, + cluster_size, + bin_index, + float(result.bin_edges_fs[bin_index]), + float(result.bin_edges_fs[bin_index + 1]), + float(result.bin_centers_fs[bin_index]), + float(time_edges[bin_index]), + float(time_edges[bin_index + 1]), + float(time_centers[bin_index]), + time_unit, + display_mode, + float(displayed_matrix[row_index, bin_index]), + float(result.raw_count_matrix[row_index, bin_index]), + float(result.fraction_matrix[row_index, bin_index]), + float(result.mean_count_matrix[row_index, bin_index]), + float(result.frames_per_bin[bin_index]), + float(result.total_clusters_per_bin[bin_index]), + ] + ) + return output_path + + +def export_cluster_dynamics_lifetime_csv( + result: ClusterDynamicsResult, + output_file: str | Path, +) -> Path: + output_path = _resolve_csv_output_path(output_file) + with output_path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.writer(handle) + writer.writerow( + [ + "label", + "cluster_size", + "mean_lifetime_fs", + "std_lifetime_fs", + "completed_lifetime_count", + "window_truncated_lifetime_count", + "association_rate_per_ps", + "dissociation_rate_per_ps", + "occupancy_fraction", + "mean_count_per_frame", + "total_observations", + "occupied_frames", + "association_events", + "dissociation_events", + ] + ) + for entry in result.lifetime_by_label: + writer.writerow( + [ + entry.label, + int(entry.cluster_size), + _csv_float(entry.mean_lifetime_fs), + _csv_float(entry.std_lifetime_fs), + int(entry.completed_lifetime_count), + int(entry.window_truncated_lifetime_count), + float(entry.association_rate_per_ps), + float(entry.dissociation_rate_per_ps), + float(entry.occupancy_fraction), + float(entry.mean_count_per_frame), + int(entry.total_observations), + int(entry.occupied_frames), + int(entry.association_events), + int(entry.dissociation_events), + ] + ) + return output_path + + +def _write_energy_csv(energy_data: CP2KEnergyData, dataset_file: Path) -> Path: + output_path = dataset_file.with_name(f"{dataset_file.stem}_energy.csv") + with output_path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.writer(handle) + writer.writerow( + ["step", "time_fs", "kinetic", "temperature", "potential"] + ) + for index in range(len(energy_data.time_fs)): + writer.writerow( + [ + float(energy_data.step[index]), + float(energy_data.time_fs[index]), + float(energy_data.kinetic[index]), + float(energy_data.temperature[index]), + float(energy_data.potential[index]), + ] + ) + return output_path + + +def _serialize_energy_data( + energy_data: CP2KEnergyData | None, +) -> dict[str, object] | None: + if energy_data is None: + return None + return { + "filepath": str(energy_data.filepath), + "step": energy_data.step.tolist(), + "time_fs": energy_data.time_fs.tolist(), + "kinetic": energy_data.kinetic.tolist(), + "temperature": energy_data.temperature.tolist(), + "potential": energy_data.potential.tolist(), + } + + +def _deserialize_energy_data( + payload: object, + *, + fallback_path: Path, +) -> CP2KEnergyData | None: + if not isinstance(payload, dict): + return None + filepath_value = payload.get("filepath") + filepath = ( + fallback_path if filepath_value is None else Path(str(filepath_value)) + ) + return CP2KEnergyData( + filepath=filepath, + step=np.asarray(payload.get("step", []), dtype=float), + time_fs=np.asarray(payload.get("time_fs", []), dtype=float), + kinetic=np.asarray(payload.get("kinetic", []), dtype=float), + temperature=np.asarray(payload.get("temperature", []), dtype=float), + potential=np.asarray(payload.get("potential", []), dtype=float), + ) + + +def _normalize_summary_payload( + payload: dict[str, object] +) -> dict[str, object]: + normalized = dict(payload) + if ( + "censored_lifetime_count" in normalized + and "window_truncated_lifetime_count" not in normalized + ): + normalized["window_truncated_lifetime_count"] = normalized.pop( + "censored_lifetime_count" + ) + return normalized + + +def _coerce_box_dimensions( + value: object, +) -> tuple[float, float, float] | None: + if value is None: + return None + components = tuple(float(component) for component in value) + if len(components) != 3: + raise ValueError("Saved box dimensions must contain three values.") + return components + + +def _optional_float(value: object) -> float | None: + return None if value is None else float(value) + + +def _optional_str(value: object) -> str | None: + if value is None: + return None + text = str(value).strip() + return text or None + + +def _csv_float(value: float | None) -> str: + return "" if value is None else f"{float(value):.12g}" + + +def _resolve_csv_output_path(output_file: str | Path) -> Path: + output_path = Path(output_file).expanduser().resolve() + if output_path.suffix.lower() != ".csv": + output_path = output_path.with_suffix(".csv") + output_path.parent.mkdir(parents=True, exist_ok=True) + return output_path + + +__all__ = [ + "LoadedClusterDynamicsDataset", + "SavedClusterDynamicsDataset", + "export_cluster_dynamics_colormap_csv", + "export_cluster_dynamics_lifetime_csv", + "load_cluster_dynamics_dataset", + "save_cluster_dynamics_dataset", +] diff --git a/src/saxshell/clusterdynamics/report.py b/src/saxshell/clusterdynamics/report.py new file mode 100644 index 0000000..83311a3 --- /dev/null +++ b/src/saxshell/clusterdynamics/report.py @@ -0,0 +1,1097 @@ +from __future__ import annotations + +import textwrap +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING + +from matplotlib.figure import Figure +from matplotlib.text import Text + +from saxshell.clusterdynamics import ClusterDynamicsResult +from saxshell.saxs.project_manager import ( + PowerPointExportSettings, + build_project_paths, +) + +if TYPE_CHECKING: + from saxshell.clusterdynamicsml.workflow import ClusterDynamicsMLResult + +ClusterReportProgressCallback = Callable[[int, int, str], None] + +_SLIDE_WIDTH_INCHES = 13.333 +_SLIDE_HEIGHT_INCHES = 7.5 +_SLIDE_LEFT_INCHES = 0.45 +_SLIDE_CONTENT_WIDTH_INCHES = 12.43 +_TITLE_TOP_INCHES = 0.28 +_TITLE_HEIGHT_INCHES = 0.42 +_SUBTITLE_TOP_INCHES = 0.72 +_SUBTITLE_HEIGHT_INCHES = 0.22 +_TEXT_TOP_INCHES = 1.10 +_TEXT_HEIGHT_INCHES = 5.92 +_FIGURE_TOP_INCHES = 1.08 +_FIGURE_HEIGHT_INCHES = 5.96 +_TABLE_TOP_INCHES = 1.08 +_TABLE_HEIGHT_INCHES = 5.98 +_THICK_RULE_HEIGHT_INCHES = 0.04 + + +@dataclass(slots=True) +class ClusterPowerPointExportResult: + report_path: Path + appended_to_existing: bool + added_slide_count: int + + +@dataclass(slots=True) +class _TextSection: + title: str + subtitle: str | None + pages: list[list[str]] + placeholder: str + font_size: float = 13.0 + + +@dataclass(slots=True) +class _FigureSection: + title: str + subtitle: str | None + figure: Figure | None + placeholder: str + + +@dataclass(slots=True) +class _TableSection: + title: str + subtitle: str | None + columns: tuple[str, ...] + rows: list[list[str]] + rows_per_slide: int + column_width_weights: tuple[float, ...] | None = None + alignments: tuple[str, ...] | None = None + note: str | None = None + header_font_size: float = 11.0 + row_font_size: float = 10.0 + + +class _ReportProgressTracker: + def __init__( + self, + total_steps: int, + callback: ClusterReportProgressCallback | None, + *, + opening_message: str, + ) -> None: + self.total_steps = max(int(total_steps), 1) + self._callback = callback + self._processed = 0 + self.emit(opening_message) + + def emit(self, message: str) -> None: + if self._callback is None: + return + self._callback(self._processed, self.total_steps, str(message)) + + def advance(self, message: str) -> None: + self._processed = min(self._processed + 1, self.total_steps) + self.emit(message) + + +def latest_project_powerpoint_report(project_dir: str | Path) -> Path | None: + reports_dir = build_project_paths(project_dir).reports_dir + reports_dir.mkdir(parents=True, exist_ok=True) + candidates = sorted( + reports_dir.glob("*.pptx"), + key=lambda path: (path.stat().st_mtime, path.name.lower()), + ) + if not candidates: + return None + return candidates[-1] + + +def default_powerpoint_report_path( + *, + project_dir: str | Path | None, + fallback_dir: str | Path, + fallback_stem: str, +) -> Path: + if project_dir is not None: + project_dir_path = Path(project_dir).expanduser().resolve() + existing = latest_project_powerpoint_report(project_dir_path) + if existing is not None: + return existing + reports_dir = build_project_paths(project_dir_path).reports_dir + reports_dir.mkdir(parents=True, exist_ok=True) + return reports_dir / f"{project_dir_path.name}_results.pptx" + fallback_dir_path = Path(fallback_dir).expanduser().resolve() + fallback_dir_path.mkdir(parents=True, exist_ok=True) + return fallback_dir_path / f"{fallback_stem}.pptx" + + +def normalize_powerpoint_output_path(path: str | Path) -> Path: + resolved = Path(path).expanduser() + if resolved.suffix.lower() != ".pptx": + resolved = resolved.with_suffix(".pptx") + return resolved.resolve() + + +def export_cluster_dynamics_report_pptx( + *, + result: ClusterDynamicsResult, + selection_summary: str, + result_summary: str, + figure: Figure | None, + output_path: str | Path, + settings: PowerPointExportSettings | None = None, + project_dir: str | Path | None = None, + frames_dir: str | Path | None = None, + progress_callback: ClusterReportProgressCallback | None = None, +) -> ClusterPowerPointExportResult: + export_settings = _normalized_settings(settings) + lifetime_columns = ( + "Label", + "Size", + "Mean lifetime (fs)", + "Std lifetime (fs)", + "Completed", + "Window-truncated", + "Assoc. rate (1/ps)", + "Dissoc. rate (1/ps)", + "Occupancy (%)", + "Mean count/frame", + ) + lifetime_rows = [ + [ + entry.label, + str(entry.cluster_size), + _format_optional_float(entry.mean_lifetime_fs), + _format_optional_float(entry.std_lifetime_fs), + str(entry.completed_lifetime_count), + str(entry.window_truncated_lifetime_count), + f"{entry.association_rate_per_ps:.3f}", + f"{entry.dissociation_rate_per_ps:.3f}", + f"{entry.occupancy_fraction * 100.0:.1f}", + f"{entry.mean_count_per_frame:.3f}", + ] + for entry in result.lifetime_by_label + ] + generated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + cover_lines = [ + f"Generated: {generated_at}", + f"Frames analyzed: {result.analyzed_frames}", + f"Time bins: {result.bin_count}", + "Project directory: " + + ( + "not set" + if project_dir is None + else str(Path(project_dir).resolve()) + ), + "Frames folder: " + + ( + "not set" + if frames_dir is None + else str(Path(frames_dir).resolve()) + ), + ] + text_sections = [ + _TextSection( + title="ClusterDynamics Settings", + subtitle="Selection preview and analysis inputs", + pages=_paginate_text_lines( + selection_summary.splitlines(), + max_lines=18, + wrap_at=60, + ), + placeholder="Selection settings are not available for this result.", + ), + _TextSection( + title="ClusterDynamics Summary", + subtitle="Observed result summary", + pages=_paginate_text_lines( + result_summary.splitlines(), + max_lines=18, + wrap_at=60, + ), + placeholder="Summary information is not available for this result.", + ), + ] + figure_sections = [ + _FigureSection( + title="Cluster Distribution Heatmap", + subtitle="Current clusterdynamics plot settings", + figure=figure, + placeholder="No cluster-distribution plot is available for this result.", + ) + ] + table_sections = [ + _TableSection( + title="Observed Cluster Lifetimes", + subtitle="Lifetime statistics by stoichiometry label", + columns=lifetime_columns, + rows=lifetime_rows, + rows_per_slide=11, + column_width_weights=( + 1.55, + 0.55, + 1.05, + 0.95, + 0.7, + 1.0, + 1.0, + 1.0, + 0.95, + 1.05, + ), + alignments=( + "left", + "center", + "center", + "center", + "center", + "center", + "center", + "center", + "center", + "center", + ), + row_font_size=9.5, + ) + ] + return _export_cluster_report_pptx( + report_title="ClusterDynamics Report", + cover_subtitle="Time-binned cluster-distribution analysis", + cover_lines=cover_lines, + text_sections=text_sections, + figure_sections=figure_sections, + table_sections=table_sections, + output_path=output_path, + settings=export_settings, + progress_callback=progress_callback, + ) + + +def export_cluster_dynamicsai_report_pptx( + *, + result: "ClusterDynamicsMLResult", + selection_summary: str, + result_summary: str, + dynamics_figure: Figure | None, + surrogate_figure: Figure | None, + output_path: str | Path, + settings: PowerPointExportSettings | None = None, + project_dir: str | Path | None = None, + frames_dir: str | Path | None = None, + progress_callback: ClusterReportProgressCallback | None = None, +) -> ClusterPowerPointExportResult: + export_settings = _normalized_settings(settings) + lifetime_columns = ( + "Label", + "Size", + "Mean lifetime (fs)", + "Std lifetime (fs)", + "Completed", + "Window-truncated", + "Assoc. rate (1/ps)", + "Dissoc. rate (1/ps)", + "Occupancy (%)", + "Mean count/frame", + ) + lifetime_rows = [ + [ + entry.label, + str(entry.cluster_size), + _format_optional_float(entry.mean_lifetime_fs), + _format_optional_float(entry.std_lifetime_fs), + str(entry.completed_lifetime_count), + str(entry.window_truncated_lifetime_count), + f"{entry.association_rate_per_ps:.3f}", + f"{entry.dissociation_rate_per_ps:.3f}", + f"{entry.occupancy_fraction * 100.0:.1f}", + f"{entry.mean_count_per_frame:.3f}", + ] + for entry in result.dynamics_result.lifetime_by_label + ] + prediction_columns = ( + "Target nodes", + "Rank", + "Label", + "Share (%)", + "Mean count/frame", + "Mean lifetime (fs)", + "Assoc. rate", + "Dissoc. rate", + "Source", + "Notes", + ) + prediction_rows = [ + [ + str(entry.target_node_count), + str(entry.rank), + entry.label, + f"{entry.predicted_population_share * 100.0:.2f}", + f"{entry.predicted_mean_count_per_frame:.4f}", + f"{entry.predicted_mean_lifetime_fs:.3f}", + f"{entry.predicted_association_rate_per_ps:.3f}", + f"{entry.predicted_dissociation_rate_per_ps:.3f}", + "" if entry.source_label is None else entry.source_label, + entry.notes, + ] + for entry in result.predictions + ] + generated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + cover_lines = [ + f"Generated: {generated_at}", + f"Frames analyzed: {result.dynamics_result.analyzed_frames}", + f"Observed node counts: {result.preview.observed_node_counts or ('n/a',)}", + f"Predicted candidates: {len(result.predictions)}", + "Project directory: " + + ( + "not set" + if project_dir is None + else str(Path(project_dir).resolve()) + ), + "Frames folder: " + + ( + "not set" + if frames_dir is None + else str(Path(frames_dir).resolve()) + ), + ] + text_sections = [ + _TextSection( + title="ClusterDynamicsML Settings", + subtitle="Selection preview and prediction inputs", + pages=_paginate_text_lines( + selection_summary.splitlines(), + max_lines=18, + wrap_at=60, + ), + placeholder="Selection settings are not available for this result.", + ), + _TextSection( + title="ClusterDynamicsML Summary", + subtitle="Observed and predicted result summary", + pages=_paginate_text_lines( + result_summary.splitlines(), + max_lines=18, + wrap_at=60, + ), + placeholder="Summary information is not available for this result.", + ), + ] + figure_sections = [ + _FigureSection( + title="Observed Cluster Distribution", + subtitle="Current clusterdynamics heatmap settings", + figure=dynamics_figure, + placeholder="No observed cluster-distribution plot is available.", + ), + _FigureSection( + title="Surrogate SAXS Comparison", + subtitle="Observed-only and observed + surrogate SAXS traces", + figure=surrogate_figure, + placeholder="No surrogate SAXS plot is available for this result.", + ), + ] + table_sections = [ + _TableSection( + title="Observed Cluster Lifetimes", + subtitle="Lifetime statistics used by the surrogate workflow", + columns=lifetime_columns, + rows=lifetime_rows, + rows_per_slide=11, + column_width_weights=( + 1.55, + 0.55, + 1.05, + 0.95, + 0.7, + 1.0, + 1.0, + 1.0, + 0.95, + 1.05, + ), + alignments=( + "left", + "center", + "center", + "center", + "center", + "center", + "center", + "center", + "center", + "center", + ), + row_font_size=9.5, + ), + _TableSection( + title="Predicted Larger Clusters", + subtitle="Ranked surrogate candidates above the current threshold", + columns=prediction_columns, + rows=prediction_rows, + rows_per_slide=8, + column_width_weights=( + 0.75, + 0.48, + 1.1, + 0.78, + 1.0, + 1.02, + 0.82, + 0.88, + 0.95, + 3.65, + ), + alignments=( + "center", + "center", + "left", + "center", + "center", + "center", + "center", + "center", + "left", + "left", + ), + row_font_size=9.0, + ), + ] + return _export_cluster_report_pptx( + report_title="ClusterDynamicsML Report", + cover_subtitle="Cluster extrapolation and surrogate SAXS analysis", + cover_lines=cover_lines, + text_sections=text_sections, + figure_sections=figure_sections, + table_sections=table_sections, + output_path=output_path, + settings=export_settings, + progress_callback=progress_callback, + ) + + +def _export_cluster_report_pptx( + *, + report_title: str, + cover_subtitle: str, + cover_lines: list[str], + text_sections: Sequence[_TextSection], + figure_sections: Sequence[_FigureSection], + table_sections: Sequence[_TableSection], + output_path: str | Path, + settings: PowerPointExportSettings, + progress_callback: ClusterReportProgressCallback | None, +) -> ClusterPowerPointExportResult: + ( + Presentation, + Inches, + Pt, + RGBColor, + PP_ALIGN, + MSO_VERTICAL_ANCHOR, + MSO_AUTO_SHAPE_TYPE, + ) = _load_pptx_api() + + def rgb_color(value: str): + red, green, blue = tuple( + int(value[index : index + 2], 16) for index in (1, 3, 5) + ) + return RGBColor(red, green, blue) + + def first_run(paragraph): + if paragraph.runs: + return paragraph.runs[0] + return paragraph.add_run() + + def apply_run_style( + run, + *, + font_size: float, + bold: bool = False, + color: str | None = None, + ) -> None: + run.font.name = settings.font_family + run.font.size = Pt(font_size) + run.font.bold = bold + run.font.color.rgb = rgb_color( + settings.text_color if color is None else color + ) + + def set_slide_background(slide) -> None: + slide.background.fill.solid() + slide.background.fill.fore_color.rgb = rgb_color("#FFFFFF") + + def add_title(slide, title: str, subtitle: str | None = None) -> None: + title_box = slide.shapes.add_textbox( + Inches(_SLIDE_LEFT_INCHES), + Inches(_TITLE_TOP_INCHES), + Inches(_SLIDE_CONTENT_WIDTH_INCHES), + Inches(_TITLE_HEIGHT_INCHES), + ) + title_frame = title_box.text_frame + title_frame.word_wrap = False + title_frame.clear() + title_paragraph = title_frame.paragraphs[0] + title_paragraph.text = str(title) + title_paragraph.space_after = Pt(0) + title_paragraph.space_before = Pt(0) + apply_run_style(first_run(title_paragraph), font_size=22, bold=True) + if subtitle: + subtitle_box = slide.shapes.add_textbox( + Inches(_SLIDE_LEFT_INCHES), + Inches(_SUBTITLE_TOP_INCHES), + Inches(_SLIDE_CONTENT_WIDTH_INCHES), + Inches(_SUBTITLE_HEIGHT_INCHES), + ) + subtitle_frame = subtitle_box.text_frame + subtitle_frame.word_wrap = False + subtitle_frame.clear() + subtitle_paragraph = subtitle_frame.paragraphs[0] + subtitle_paragraph.text = str(subtitle) + subtitle_paragraph.space_after = Pt(0) + subtitle_paragraph.space_before = Pt(0) + apply_run_style( + first_run(subtitle_paragraph), + font_size=10, + color="#4B5563", + ) + + def add_text_block( + slide, + *, + left: float, + top: float, + width: float, + height: float, + lines: Sequence[str], + font_size: float, + ) -> None: + textbox = slide.shapes.add_textbox( + Inches(left), + Inches(top), + Inches(width), + Inches(height), + ) + frame = textbox.text_frame + frame.word_wrap = True + frame.clear() + frame.vertical_anchor = MSO_VERTICAL_ANCHOR.TOP + for index, line in enumerate(lines): + paragraph = ( + frame.paragraphs[0] if index == 0 else frame.add_paragraph() + ) + paragraph.text = str(line) + paragraph.space_after = Pt(0) + paragraph.space_before = Pt(0) + apply_run_style(first_run(paragraph), font_size=font_size) + + def add_cover_slide( + title: str, subtitle: str, lines: Sequence[str] + ) -> None: + slide = presentation.slides.add_slide(blank_layout) + set_slide_background(slide) + add_title(slide, title, subtitle) + add_text_block( + slide, + left=_SLIDE_LEFT_INCHES, + top=1.45, + width=_SLIDE_CONTENT_WIDTH_INCHES, + height=4.9, + lines=lines, + font_size=15, + ) + register_slide(title) + + def add_picture( + slide, + image_path: Path, + *, + left: float, + top: float, + width: float, + height: float, + ) -> None: + fitted_left, fitted_top, fitted_width, fitted_height = ( + _fit_image_in_box( + image_path, + left=left, + top=top, + max_width=width, + max_height=height, + ) + ) + slide.shapes.add_picture( + str(image_path), + Inches(fitted_left), + Inches(fitted_top), + width=Inches(fitted_width), + height=Inches(fitted_height), + ) + + def add_table_header_rule( + slide, + *, + left: float, + top: float, + width: float, + ) -> None: + rule = slide.shapes.add_shape( + MSO_AUTO_SHAPE_TYPE.RECTANGLE, + Inches(left), + Inches(top), + Inches(width), + Inches(_THICK_RULE_HEIGHT_INCHES), + ) + rule.fill.solid() + rule.fill.fore_color.rgb = rgb_color(settings.table_rule_color) + rule.line.fill.background() + + def style_table_cell( + cell, + *, + text: str, + font_size: float, + fill_color: str, + bold: bool = False, + align=None, + ) -> None: + cell.text = str(text) + cell.fill.solid() + cell.fill.fore_color.rgb = rgb_color(fill_color) + cell.margin_left = Inches(0.035) + cell.margin_right = Inches(0.035) + cell.margin_top = Inches(0.02) + cell.margin_bottom = Inches(0.02) + cell.vertical_anchor = MSO_VERTICAL_ANCHOR.MIDDLE + frame = cell.text_frame + frame.word_wrap = True + paragraph = frame.paragraphs[0] + paragraph.alignment = PP_ALIGN.LEFT if align is None else align + paragraph.space_after = Pt(0) + paragraph.space_before = Pt(0) + apply_run_style(first_run(paragraph), font_size=font_size, bold=bold) + + def resolve_alignment(value: str | None): + if value == "center": + return PP_ALIGN.CENTER + if value == "right": + return PP_ALIGN.RIGHT + return PP_ALIGN.LEFT + + def add_text_section(section: _TextSection) -> None: + pages = section.pages or [[]] + if not pages: + pages = [[]] + total_pages = len(pages) + for page_index, page in enumerate(pages, start=1): + slide = presentation.slides.add_slide(blank_layout) + set_slide_background(slide) + effective_title = _page_title( + section.title, page_index, total_pages + ) + add_title(slide, effective_title, section.subtitle) + content = page or [section.placeholder] + add_text_block( + slide, + left=_SLIDE_LEFT_INCHES, + top=_TEXT_TOP_INCHES, + width=_SLIDE_CONTENT_WIDTH_INCHES, + height=_TEXT_HEIGHT_INCHES, + lines=content, + font_size=section.font_size, + ) + register_slide(effective_title) + + def add_figure_section( + section: _FigureSection, + temporary_dir: Path, + ) -> None: + slide = presentation.slides.add_slide(blank_layout) + set_slide_background(slide) + add_title(slide, section.title, section.subtitle) + if section.figure is None: + add_text_block( + slide, + left=_SLIDE_LEFT_INCHES, + top=2.4, + width=_SLIDE_CONTENT_WIDTH_INCHES, + height=1.0, + lines=[section.placeholder], + font_size=15, + ) + register_slide(section.title) + return + image_path = temporary_dir / f"{_slugify(section.title)}.png" + _save_figure_image( + section.figure, + image_path, + font_family=settings.font_family, + ) + add_picture( + slide, + image_path, + left=_SLIDE_LEFT_INCHES, + top=_FIGURE_TOP_INCHES, + width=_SLIDE_CONTENT_WIDTH_INCHES, + height=_FIGURE_HEIGHT_INCHES, + ) + register_slide(section.title) + + def add_table_section(section: _TableSection) -> None: + chunks = _table_row_chunks(section.rows, section.rows_per_slide) + if not chunks: + chunks = [[]] + total_pages = len(chunks) + for page_index, chunk in enumerate(chunks, start=1): + slide = presentation.slides.add_slide(blank_layout) + set_slide_background(slide) + effective_title = _page_title( + section.title, page_index, total_pages + ) + add_title(slide, effective_title, section.subtitle) + if not chunk: + add_text_block( + slide, + left=_SLIDE_LEFT_INCHES, + top=2.4, + width=_SLIDE_CONTENT_WIDTH_INCHES, + height=1.0, + lines=["No data are available for this section."], + font_size=15, + ) + register_slide(effective_title) + continue + add_table_header_rule( + slide, + left=_SLIDE_LEFT_INCHES, + top=_TABLE_TOP_INCHES, + width=_SLIDE_CONTENT_WIDTH_INCHES, + ) + table_height = _TABLE_HEIGHT_INCHES - ( + 0.25 if section.note else 0.0 + ) + table_shape = slide.shapes.add_table( + len(chunk) + 1, + len(section.columns), + Inches(_SLIDE_LEFT_INCHES), + Inches(_TABLE_TOP_INCHES), + Inches(_SLIDE_CONTENT_WIDTH_INCHES), + Inches(table_height), + ) + table = table_shape.table + column_widths = _resolve_column_widths( + section.columns, + chunk, + total_width=_SLIDE_CONTENT_WIDTH_INCHES, + column_width_weights=section.column_width_weights, + ) + for column_index, column_width in enumerate(column_widths): + table.columns[column_index].width = Inches(column_width) + row_height = table_height / max(len(chunk) + 1, 1) + for row_index in range(len(chunk) + 1): + table.rows[row_index].height = Inches(row_height) + for column_index, column_name in enumerate(section.columns): + style_table_cell( + table.cell(0, column_index), + text=column_name, + font_size=section.header_font_size, + fill_color=settings.table_header_fill, + bold=True, + align=PP_ALIGN.CENTER, + ) + for row_index, row in enumerate(chunk, start=1): + row_fill = ( + settings.table_even_row_fill + if row_index % 2 == 1 + else settings.table_odd_row_fill + ) + for column_index, value in enumerate(row): + alignment = None + if section.alignments is not None and column_index < len( + section.alignments + ): + alignment = resolve_alignment( + section.alignments[column_index] + ) + style_table_cell( + table.cell(row_index, column_index), + text=str(value), + font_size=section.row_font_size, + fill_color=row_fill, + align=alignment, + ) + if section.note and page_index == total_pages: + add_text_block( + slide, + left=_SLIDE_LEFT_INCHES, + top=_TABLE_TOP_INCHES + table_height + 0.08, + width=_SLIDE_CONTENT_WIDTH_INCHES, + height=0.18, + lines=[section.note], + font_size=9.5, + ) + register_slide(effective_title) + + output_path = normalize_powerpoint_output_path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + appended_to_existing = output_path.is_file() + if appended_to_existing: + presentation = Presentation(str(output_path)) + else: + presentation = Presentation() + presentation.slide_width = Inches(_SLIDE_WIDTH_INCHES) + presentation.slide_height = Inches(_SLIDE_HEIGHT_INCHES) + blank_layout = _best_blank_layout(presentation) + initial_slide_count = len(presentation.slides) + total_slide_count = ( + 1 + + sum(max(len(section.pages), 1) for section in text_sections) + + len(figure_sections) + + sum( + max( + len(_table_row_chunks(section.rows, section.rows_per_slide)), 1 + ) + for section in table_sections + ) + ) + progress = _ReportProgressTracker( + total_slide_count + 2, + progress_callback, + opening_message=f"Generating {report_title}. Please wait...", + ) + progress.advance( + "Opened existing PowerPoint report." + if appended_to_existing + else "Created new PowerPoint report." + ) + slide_index = 0 + + def register_slide(message: str) -> None: + nonlocal slide_index + slide_index += 1 + progress.advance( + f"Built slide {slide_index}/{total_slide_count}: {message}" + ) + + with TemporaryDirectory() as temporary_directory: + temporary_dir = Path(temporary_directory) + add_cover_slide(report_title, cover_subtitle, cover_lines) + for section in text_sections: + add_text_section(section) + for section in figure_sections: + add_figure_section(section, temporary_dir) + for section in table_sections: + add_table_section(section) + presentation.save(str(output_path)) + progress.advance("Saved PowerPoint report.") + return ClusterPowerPointExportResult( + report_path=output_path, + appended_to_existing=appended_to_existing, + added_slide_count=len(presentation.slides) - initial_slide_count, + ) + + +def _best_blank_layout(presentation): + for layout in presentation.slide_layouts: + try: + if len(layout.placeholders) == 0: + return layout + except Exception: + continue + if len(presentation.slide_layouts) > 6: + return presentation.slide_layouts[6] + return presentation.slide_layouts[-1] + + +def _load_pptx_api(): + try: + from pptx import Presentation + from pptx.dml.color import RGBColor + from pptx.enum.shapes import MSO_AUTO_SHAPE_TYPE + from pptx.enum.text import MSO_VERTICAL_ANCHOR, PP_ALIGN + from pptx.util import Inches, Pt + except ImportError as exc: + raise RuntimeError( + "PowerPoint export requires the optional dependency " + "`python-pptx`. Install it and retry." + ) from exc + return ( + Presentation, + Inches, + Pt, + RGBColor, + PP_ALIGN, + MSO_VERTICAL_ANCHOR, + MSO_AUTO_SHAPE_TYPE, + ) + + +def _normalized_settings( + settings: PowerPointExportSettings | None, +) -> PowerPointExportSettings: + if settings is None: + return PowerPointExportSettings() + return PowerPointExportSettings.from_dict(settings.to_dict()) + + +def _paginate_text_lines( + lines: Sequence[str], + *, + max_lines: int, + wrap_at: int, +) -> list[list[str]]: + if max_lines <= 0: + raise ValueError("max_lines must be positive") + wrapped_lines: list[str] = [] + for raw_line in lines: + line = str(raw_line).strip() + if not line: + wrapped_lines.append("") + continue + wrapped_lines.extend( + textwrap.wrap( + line, + width=max(int(wrap_at), 1), + break_long_words=True, + break_on_hyphens=False, + ) + or [line] + ) + if not wrapped_lines: + return [[]] + return [ + wrapped_lines[index : index + max_lines] + for index in range(0, len(wrapped_lines), max_lines) + ] + + +def _table_row_chunks( + rows: Sequence[Sequence[str]], + rows_per_slide: int, +) -> list[list[list[str]]]: + if rows_per_slide <= 0: + raise ValueError("rows_per_slide must be positive") + if not rows: + return [] + return [ + [list(value) for value in rows[index : index + rows_per_slide]] + for index in range(0, len(rows), rows_per_slide) + ] + + +def _page_title(title: str, page_index: int, total_pages: int) -> str: + if total_pages <= 1: + return title + return f"{title} ({page_index}/{total_pages})" + + +def _resolve_column_widths( + columns: Sequence[str], + rows: Sequence[Sequence[str]], + *, + total_width: float, + column_width_weights: Sequence[float] | None = None, +) -> list[float]: + if not columns: + return [] + if column_width_weights is not None: + weights = [max(float(weight), 0.2) for weight in column_width_weights] + else: + weights = [] + for column_index, column_name in enumerate(columns): + max_length = len(str(column_name)) + for row in rows: + if column_index >= len(row): + continue + max_length = max(max_length, len(str(row[column_index]))) + weights.append(max(0.8, min(float(max_length) ** 0.78, 3.2))) + if len(weights) != len(columns): + raise ValueError("column width weights must match the column count") + scale = total_width / sum(weights) + widths = [weight * scale for weight in weights] + widths[-1] += total_width - sum(widths) + return widths + + +def _fit_image_in_box( + image_path: Path, + *, + left: float, + top: float, + max_width: float, + max_height: float, +) -> tuple[float, float, float, float]: + try: + from PIL import Image + except ImportError: + return left, top, max_width, max_height + + with Image.open(image_path) as image: + width_px, height_px = image.size + if width_px <= 0 or height_px <= 0: + return left, top, max_width, max_height + image_ratio = width_px / height_px + box_ratio = max_width / max_height if max_height > 0 else image_ratio + if image_ratio >= box_ratio: + fitted_width = max_width + fitted_height = fitted_width / image_ratio + else: + fitted_height = max_height + fitted_width = fitted_height * image_ratio + fitted_left = left + (max_width - fitted_width) / 2.0 + fitted_top = top + (max_height - fitted_height) / 2.0 + return fitted_left, fitted_top, fitted_width, fitted_height + + +def _save_figure_image( + figure: Figure, + output_path: Path, + *, + font_family: str, +) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + text_artists = list(figure.findobj(match=Text)) + original_font_families = [ + artist.get_fontfamily() for artist in text_artists + ] + try: + for artist in text_artists: + artist.set_fontfamily(font_family) + if figure.canvas is not None: + figure.canvas.draw() + figure.savefig( + output_path, + format="png", + dpi=220, + bbox_inches="tight", + facecolor="white", + edgecolor="white", + ) + finally: + for artist, font_family_value in zip( + text_artists, + original_font_families, + strict=False, + ): + artist.set_fontfamily(font_family_value) + if figure.canvas is not None: + figure.canvas.draw_idle() + + +def _format_optional_float(value: float | None) -> str: + if value is None: + return "n/a" + return f"{value:.6g}" + + +def _slugify(value: str) -> str: + safe = "".join( + character.lower() if character.isalnum() else "_" + for character in value + ) + return safe.strip("_") or "figure" diff --git a/src/saxshell/clusterdynamics/ui/__init__.py b/src/saxshell/clusterdynamics/ui/__init__.py new file mode 100644 index 0000000..2682292 --- /dev/null +++ b/src/saxshell/clusterdynamics/ui/__init__.py @@ -0,0 +1,5 @@ +"""Qt UI for the clusterdynamics application.""" + +from .main_window import ClusterDynamicsMainWindow, launch_clusterdynamics_ui + +__all__ = ["ClusterDynamicsMainWindow", "launch_clusterdynamics_ui"] diff --git a/src/saxshell/clusterdynamics/ui/__main__.py b/src/saxshell/clusterdynamics/ui/__main__.py new file mode 100644 index 0000000..2d69d02 --- /dev/null +++ b/src/saxshell/clusterdynamics/ui/__main__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from .main_window import main + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/saxshell/clusterdynamics/ui/main_window.py b/src/saxshell/clusterdynamics/ui/main_window.py new file mode 100644 index 0000000..a486e6e --- /dev/null +++ b/src/saxshell/clusterdynamics/ui/main_window.py @@ -0,0 +1,1849 @@ +from __future__ import annotations + +import argparse +import sys +from dataclasses import dataclass +from pathlib import Path + +from PySide6.QtCore import QObject, Qt, QThread, Signal, Slot +from PySide6.QtWidgets import ( + QApplication, + QDoubleSpinBox, + QFileDialog, + QFormLayout, + QGroupBox, + QHBoxLayout, + QHeaderView, + QLabel, + QLineEdit, + QMainWindow, + QMessageBox, + QProgressBar, + QPushButton, + QScrollArea, + QSpinBox, + QSplitter, + QTableWidget, + QTableWidgetItem, + QTabWidget, + QTextEdit, + QVBoxLayout, + QWidget, +) + +from saxshell.cluster import ( + ExtractedFrameFolderClusterAnalyzer, + PairCutoffDefinitions, + detect_frame_folder_mode, + format_box_dimensions, + format_search_mode_label, + frame_folder_label, +) +from saxshell.cluster.ui.definitions_panel import ClusterDefinitionsPanel +from saxshell.cluster.ui.trajectory_panel import ClusterTrajectoryPanel +from saxshell.clusterdynamics import ( + ClusterDynamicsResult, + ClusterDynamicsSelectionPreview, + ClusterDynamicsWorkflow, +) +from saxshell.clusterdynamics.dataset import ( + export_cluster_dynamics_colormap_csv, + export_cluster_dynamics_lifetime_csv, + load_cluster_dynamics_dataset, + save_cluster_dynamics_dataset, +) +from saxshell.clusterdynamics.report import ( + default_powerpoint_report_path, + export_cluster_dynamics_report_pptx, +) +from saxshell.clusterdynamics.ui.plot_panel import ClusterDynamicsPlotPanel +from saxshell.clusterdynamics.workflow import ( + _resolve_colormap_timestep_settings, +) +from saxshell.saxs.project_manager import ( + PowerPointExportSettings, + SAXSProjectManager, + build_project_paths, +) +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) +from saxshell.structure import AtomTypeDefinitions + +_OPEN_WINDOWS: list["ClusterDynamicsMainWindow"] = [] + + +@dataclass(slots=True) +class ClusterDynamicsJobConfig: + """Analysis settings assembled from the UI.""" + + frames_dir: Path + energy_file: Path | None + atom_type_definitions: AtomTypeDefinitions + pair_cutoff_definitions: PairCutoffDefinitions + box_dimensions: tuple[float, float, float] | None + use_pbc: bool + default_cutoff: float | None + shell_levels: tuple[int, ...] + shared_shells: bool + include_shell_atoms_in_stoichiometry: bool + search_mode: str + folder_start_time_fs: float | None + first_frame_time_fs: float + frame_timestep_fs: float + frames_per_colormap_timestep: int + analysis_start_fs: float | None + analysis_stop_fs: float | None + + @property + def colormap_timestep_fs(self) -> float: + return float(self.frame_timestep_fs) * float( + self.frames_per_colormap_timestep + ) + + +class ClusterDynamicsWorker(QObject): + """Background worker for time-binned cluster analysis.""" + + progress = Signal(str) + progress_count = Signal(int, int) + finished = Signal(object) + failed = Signal(str) + + def __init__(self, config: ClusterDynamicsJobConfig) -> None: + super().__init__() + self.config = config + + @Slot() + def run(self) -> None: + try: + workflow = ClusterDynamicsWorkflow( + self.config.frames_dir, + atom_type_definitions=self.config.atom_type_definitions, + pair_cutoff_definitions=self.config.pair_cutoff_definitions, + box_dimensions=self.config.box_dimensions, + use_pbc=self.config.use_pbc, + default_cutoff=self.config.default_cutoff, + shell_levels=self.config.shell_levels, + shared_shells=self.config.shared_shells, + include_shell_atoms_in_stoichiometry=( + self.config.include_shell_atoms_in_stoichiometry + ), + search_mode=self.config.search_mode, + folder_start_time_fs=self.config.folder_start_time_fs, + first_frame_time_fs=self.config.first_frame_time_fs, + frame_timestep_fs=self.config.frame_timestep_fs, + frames_per_colormap_timestep=( + self.config.frames_per_colormap_timestep + ), + analysis_start_fs=self.config.analysis_start_fs, + analysis_stop_fs=self.config.analysis_stop_fs, + energy_file=self.config.energy_file, + ) + preview = workflow.preview_selection() + self.progress.emit( + "Preparing time-binned cluster analysis.\n" + f"Frames selected: {preview.selected_frames}\n" + f"Time bins: {preview.bin_count}\n" + f"Frame timestep: {preview.frame_timestep_fs:.3f} fs\n" + "Frames per colormap timestep: " + f"{preview.frames_per_colormap_timestep}\n" + f"Colormap timestep: {preview.colormap_timestep_fs:.3f} fs" + ) + if preview.energy_file is not None: + self.progress.emit( + f"Will also load CP2K energy data from: {preview.energy_file}" + ) + total_frames = max(preview.selected_frames, 1) + self.progress_count.emit(0, total_frames) + log_interval = ( + 1 if total_frames <= 10 else max(total_frames // 8, 25) + ) + + def on_progress( + processed: int, total: int, frame_name: str + ) -> None: + self.progress_count.emit(processed, total) + should_log = ( + processed == 1 + or processed >= total + or processed % log_interval == 0 + ) + if should_log: + self.progress.emit( + f"Analyzed {processed} of {total} frame(s). " + f"Last frame: {frame_name}." + ) + + result = workflow.analyze(progress_callback=on_progress) + self.finished.emit(result) + except Exception as exc: + self.failed.emit(str(exc)) + + +class ClusterDynamicsTimePanel(QGroupBox): + """Time-axis and colormap-binning controls.""" + + settings_changed = Signal() + + def __init__(self) -> None: + super().__init__("Time Axis") + self._build_ui() + + def _build_ui(self) -> None: + layout = QFormLayout(self) + + self.folder_start_time_spin = self._make_optional_time_spin( + "Auto-populated from mdtrajectory export metadata or a folder " + "name such as splitxyz_f847fs. This value is shown as the " + "folder/start cutoff metadata and is used as a fallback origin " + "when frame filenames do not expose the original source-frame " + "indices." + ) + self.folder_start_time_spin.valueChanged.connect( + lambda _value: self.settings_changed.emit() + ) + layout.addRow("Folder/start time (fs)", self.folder_start_time_spin) + + self.first_frame_time_spin = QDoubleSpinBox() + self.first_frame_time_spin.setDecimals(3) + self.first_frame_time_spin.setRange(0.0, 10**12) + self.first_frame_time_spin.setSingleStep(1.0) + self.first_frame_time_spin.setToolTip( + "Fallback absolute simulation time assigned to the first " + "extracted frame when the folder does not include mdtrajectory " + "metadata and the frame filenames do not encode source-frame " + "indices." + ) + self.first_frame_time_spin.valueChanged.connect( + lambda _value: self.settings_changed.emit() + ) + layout.addRow("Fallback start time (fs)", self.first_frame_time_spin) + + self.frame_timestep_spin = QDoubleSpinBox() + self.frame_timestep_spin.setDecimals(3) + self.frame_timestep_spin.setRange(0.001, 10**9) + self.frame_timestep_spin.setValue(0.5) + self.frame_timestep_spin.setSingleStep(0.5) + self.frame_timestep_spin.setToolTip( + "Simulation timestep represented by one source trajectory frame. " + "When extracted frame filenames preserve their original indices, " + "the resolved time axis is frame_index x timestep." + ) + self.frame_timestep_spin.valueChanged.connect( + self._on_colormap_settings_changed + ) + layout.addRow("Frame timestep (fs)", self.frame_timestep_spin) + + self.frames_per_colormap_timestep_spin = QSpinBox() + self.frames_per_colormap_timestep_spin.setRange(1, 10**9) + self.frames_per_colormap_timestep_spin.setValue(1) + self.frames_per_colormap_timestep_spin.setToolTip( + "Number of sampled frames combined into each heatmap timestep." + ) + self.frames_per_colormap_timestep_spin.valueChanged.connect( + self._on_colormap_settings_changed + ) + layout.addRow( + "Frames / colormap timestep", + self.frames_per_colormap_timestep_spin, + ) + + self.colormap_timestep_value = QLineEdit() + self.colormap_timestep_value.setReadOnly(True) + self.colormap_timestep_value.setFocusPolicy(Qt.FocusPolicy.NoFocus) + self.colormap_timestep_value.setAlignment(Qt.AlignmentFlag.AlignRight) + self.colormap_timestep_value.setToolTip( + "Derived heatmap timestep used for the colormap bins." + ) + layout.addRow( + "Colormap timestep used (fs)", + self.colormap_timestep_value, + ) + self._update_colormap_timestep_display() + + self.analysis_start_spin = self._make_optional_time_spin( + "Leave at Auto to start from the first extracted frame time." + ) + self.analysis_start_spin.valueChanged.connect( + lambda _value: self.settings_changed.emit() + ) + layout.addRow("Analysis start (fs)", self.analysis_start_spin) + + self.analysis_stop_spin = self._make_optional_time_spin( + "Leave at Auto to use the full selected frame range." + ) + self.analysis_stop_spin.valueChanged.connect( + lambda _value: self.settings_changed.emit() + ) + layout.addRow("Analysis stop (fs)", self.analysis_stop_spin) + + def _make_optional_time_spin(self, tooltip: str) -> QDoubleSpinBox: + spin = QDoubleSpinBox() + spin.setDecimals(3) + spin.setRange(-1.0, 10**12) + spin.setSingleStep(10.0) + spin.setSpecialValueText("Auto") + spin.setValue(-1.0) + spin.setToolTip(tooltip) + return spin + + def _update_colormap_timestep_display(self) -> None: + self.colormap_timestep_value.setText( + f"{self.colormap_timestep_fs():.3f}" + ) + + def _on_colormap_settings_changed(self, _value: float | int) -> None: + self._update_colormap_timestep_display() + self.settings_changed.emit() + + def first_frame_time_fs(self) -> float: + return float(self.first_frame_time_spin.value()) + + def set_first_frame_time_fs( + self, + value: float, + *, + emit_signal: bool = True, + ) -> None: + self.first_frame_time_spin.blockSignals(True) + self.first_frame_time_spin.setValue(float(value)) + self.first_frame_time_spin.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + + def folder_start_time_fs(self) -> float | None: + value = float(self.folder_start_time_spin.value()) + return ( + None if value <= self.folder_start_time_spin.minimum() else value + ) + + def set_folder_start_time_fs( + self, + value: float | None, + *, + emit_signal: bool = True, + ) -> None: + self.folder_start_time_spin.blockSignals(True) + self.folder_start_time_spin.setValue( + self.folder_start_time_spin.minimum() + if value is None + else float(value) + ) + self.folder_start_time_spin.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + + def frame_timestep_fs(self) -> float: + return float(self.frame_timestep_spin.value()) + + def set_frame_timestep_fs( + self, + value: float, + *, + emit_signal: bool = True, + ) -> None: + self.frame_timestep_spin.blockSignals(True) + self.frame_timestep_spin.setValue(float(value)) + self.frame_timestep_spin.blockSignals(False) + self._update_colormap_timestep_display() + if emit_signal: + self.settings_changed.emit() + + def frames_per_colormap_timestep(self) -> int: + return int(self.frames_per_colormap_timestep_spin.value()) + + def set_frames_per_colormap_timestep( + self, + value: int, + *, + emit_signal: bool = True, + ) -> None: + self.frames_per_colormap_timestep_spin.blockSignals(True) + self.frames_per_colormap_timestep_spin.setValue(max(int(value), 1)) + self.frames_per_colormap_timestep_spin.blockSignals(False) + self._update_colormap_timestep_display() + if emit_signal: + self.settings_changed.emit() + + def colormap_timestep_fs(self) -> float: + return float(self.frame_timestep_spin.value()) * float( + self.frames_per_colormap_timestep_spin.value() + ) + + def analysis_start_fs(self) -> float | None: + value = float(self.analysis_start_spin.value()) + return None if value <= self.analysis_start_spin.minimum() else value + + def set_analysis_start_fs( + self, + value: float | None, + *, + emit_signal: bool = True, + ) -> None: + self.analysis_start_spin.blockSignals(True) + self.analysis_start_spin.setValue( + self.analysis_start_spin.minimum() + if value is None + else float(value) + ) + self.analysis_start_spin.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + + def analysis_stop_fs(self) -> float | None: + value = float(self.analysis_stop_spin.value()) + return None if value <= self.analysis_stop_spin.minimum() else value + + def set_analysis_stop_fs( + self, + value: float | None, + *, + emit_signal: bool = True, + ) -> None: + self.analysis_stop_spin.blockSignals(True) + self.analysis_stop_spin.setValue( + self.analysis_stop_spin.minimum() + if value is None + else float(value) + ) + self.analysis_stop_spin.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + + +class ClusterDynamicsRunPanel(QGroupBox): + """Panel for preview, optional energy input, and analysis logs.""" + + analyze_requested = Signal() + settings_changed = Signal() + + def __init__(self) -> None: + super().__init__("Run Analysis") + self._build_ui() + + def _build_ui(self) -> None: + layout = QVBoxLayout(self) + form = QFormLayout() + + self.energy_path_edit = QLineEdit() + self.energy_path_edit.setToolTip( + "Optional CP2K .ener file used for the lower time-series subplot." + ) + self.energy_path_edit.textChanged.connect( + lambda _text: self.settings_changed.emit() + ) + form.addRow( + "CP2K .ener file", self._make_file_row(self.energy_path_edit) + ) + + layout.addLayout(form) + + layout.addWidget(QLabel("Selection Preview")) + self.selection_box = QTextEdit() + self.selection_box.setReadOnly(True) + self.selection_box.setMinimumHeight(150) + layout.addWidget(self.selection_box) + + self.analyze_button = QPushButton("Analyze Time-Binned Clusters") + self.analyze_button.clicked.connect( + lambda _checked=False: self.analyze_requested.emit() + ) + layout.addWidget(self.analyze_button) + + self.progress_label = QLabel("Progress: idle") + layout.addWidget(self.progress_label) + + self.progress_bar = QProgressBar() + self.progress_bar.setRange(0, 1) + self.progress_bar.setValue(0) + self.progress_bar.setFormat("%v / %m frames") + layout.addWidget(self.progress_bar) + + layout.addWidget(QLabel("Run Log")) + self.log_box = QTextEdit() + self.log_box.setReadOnly(True) + self.log_box.setMinimumHeight(160) + layout.addWidget(self.log_box) + + def _make_file_row(self, line_edit: QLineEdit) -> QWidget: + widget = QWidget() + row = QHBoxLayout(widget) + row.setContentsMargins(0, 0, 0, 0) + button = QPushButton("Browse") + button.clicked.connect( + lambda _checked=False: self._choose_file(line_edit) + ) + row.addWidget(line_edit) + row.addWidget(button) + return widget + + def _choose_file(self, line_edit: QLineEdit) -> None: + path, _selected_filter = QFileDialog.getOpenFileName( + self, + "Select CP2K .ener file", + "", + "Energy Files (*.ener);;All Files (*)", + ) + if path: + line_edit.setText(path) + + def energy_file(self) -> Path | None: + text = self.energy_path_edit.text().strip() + return Path(text) if text else None + + def set_selection_summary(self, text: str) -> None: + self.selection_box.setPlainText(text) + + def set_log(self, text: str) -> None: + self.log_box.setPlainText(text) + + def append_log(self, text: str) -> None: + message = text.strip() + if not message: + return + existing = self.log_box.toPlainText().strip() + if existing: + self.log_box.append(message) + else: + self.log_box.setPlainText(message) + + def reset_progress(self) -> None: + self.progress_label.setText("Progress: idle") + self.progress_bar.setRange(0, 1) + self.progress_bar.setValue(0) + self.progress_bar.setFormat("%v / %m frames") + + def update_progress(self, processed: int, total: int) -> None: + total = max(int(total), 1) + processed = max(0, min(int(processed), total)) + self.progress_label.setText( + f"Progress: {processed} processed, {max(total - processed, 0)} remaining" + ) + self.progress_bar.setRange(0, total) + self.progress_bar.setValue(processed) + self.progress_bar.setFormat("%v / %m frames") + + +class ClusterDynamicsDatasetPanel(QGroupBox): + """Panel for saving and reopening previously computed datasets.""" + + save_dataset_requested = Signal() + load_dataset_requested = Signal() + save_colormap_requested = Signal() + save_lifetime_requested = Signal() + save_powerpoint_requested = Signal() + settings_changed = Signal() + + def __init__(self) -> None: + super().__init__("Saved Results") + self._build_ui() + + def _build_ui(self) -> None: + layout = QVBoxLayout(self) + + helper = QLabel( + "Save the current analysis result as a reloadable dataset, or " + "open a previously saved dataset. These actions reuse saved " + "results and do not rerun the frame analysis. You can also " + "export the plotted colormap data and lifetime table directly " + "as CSV files. When a project is set, related tools can also " + "cache result bundles in the project's exported-results folder " + "for later reuse." + ) + helper.setWordWrap(True) + layout.addWidget(helper) + + form = QFormLayout() + self.project_dir_edit = QLineEdit() + self.project_dir_edit.setToolTip( + "Optional active SAXSShell project used only to choose the " + "default save/load folder in " + "exported_results/data/clusterdynamics." + ) + self.project_dir_edit.textChanged.connect( + lambda _text: self.settings_changed.emit() + ) + form.addRow( + "Project for defaults", + self._make_dir_row(self.project_dir_edit), + ) + layout.addLayout(form) + + button_row = QHBoxLayout() + self.save_dataset_button = QPushButton("Save Current Result") + self.save_dataset_button.setToolTip( + "Write the current plotted analysis result to a reloadable " + "dataset file." + ) + self.save_dataset_button.clicked.connect( + lambda _checked=False: self.save_dataset_requested.emit() + ) + self.load_dataset_button = QPushButton("Open Saved Result") + self.load_dataset_button.setToolTip( + "Load a previously saved cluster-dynamics dataset without " + "rerunning the frame analysis." + ) + self.load_dataset_button.clicked.connect( + lambda _checked=False: self.load_dataset_requested.emit() + ) + button_row.addWidget(self.save_dataset_button) + button_row.addWidget(self.load_dataset_button) + layout.addLayout(button_row) + + export_row = QHBoxLayout() + self.save_colormap_button = QPushButton("Save Colormap Data") + self.save_colormap_button.setToolTip( + "Write the currently plotted heatmap data to a CSV file using " + "the active display mode and time-unit selections." + ) + self.save_colormap_button.clicked.connect( + lambda _checked=False: self.save_colormap_requested.emit() + ) + self.save_lifetime_button = QPushButton("Save Lifetime Table") + self.save_lifetime_button.setToolTip( + "Write the observed lifetime summary table to a CSV file." + ) + self.save_lifetime_button.clicked.connect( + lambda _checked=False: self.save_lifetime_requested.emit() + ) + export_row.addWidget(self.save_colormap_button) + export_row.addWidget(self.save_lifetime_button) + layout.addLayout(export_row) + + report_row = QHBoxLayout() + self.save_powerpoint_button = QPushButton("Save PowerPoint Report") + self.save_powerpoint_button.setToolTip( + "Generate a PowerPoint summary of the current result and append it " + "to the existing project report when you save over that file." + ) + self.save_powerpoint_button.clicked.connect( + lambda _checked=False: self.save_powerpoint_requested.emit() + ) + report_row.addWidget(self.save_powerpoint_button) + report_row.addStretch(1) + layout.addLayout(report_row) + + def _make_dir_row(self, line_edit: QLineEdit) -> QWidget: + widget = QWidget() + row = QHBoxLayout(widget) + row.setContentsMargins(0, 0, 0, 0) + button = QPushButton("Browse") + button.clicked.connect( + lambda _checked=False: self._choose_directory(line_edit) + ) + row.addWidget(line_edit) + row.addWidget(button) + return widget + + def _choose_directory(self, line_edit: QLineEdit) -> None: + path = QFileDialog.getExistingDirectory( + self, + "Select SAXSShell project directory", + line_edit.text().strip(), + ) + if path: + line_edit.setText(path) + + def project_dir(self) -> Path | None: + text = self.project_dir_edit.text().strip() + return Path(text) if text else None + + def set_project_dir(self, path: Path | None) -> None: + self.project_dir_edit.setText("" if path is None else str(path)) + + +class ClusterDynamicsMainWindow(QMainWindow): + """Main Qt window for time-binned cluster-distribution analysis.""" + + def __init__( + self, + initial_frames_dir: Path | None = None, + initial_energy_file: Path | None = None, + initial_project_dir: Path | None = None, + ) -> None: + super().__init__() + self._project_manager = SAXSProjectManager() + self._last_summary: dict[str, object] | None = None + self._frame_format: str | None = None + self._run_thread: QThread | None = None + self._run_worker: ClusterDynamicsWorker | None = None + self._last_result: ClusterDynamicsResult | None = None + self._last_dataset_file: Path | None = None + self._suspend_preview_refresh = False + self._build_ui() + + if initial_frames_dir is not None: + self.trajectory_panel.frames_dir_edit.setText( + str(initial_frames_dir) + ) + if initial_energy_file is not None: + self.run_panel.energy_path_edit.setText(str(initial_energy_file)) + if initial_project_dir is not None: + self.dataset_panel.set_project_dir(initial_project_dir) + + def _build_ui(self) -> None: + self.setWindowTitle("SAXSShell (clusterdynamics)") + self.setWindowIcon(load_saxshell_icon()) + self.resize(1540, 920) + + central = QWidget() + root = QHBoxLayout(central) + root.setContentsMargins(8, 8, 8, 8) + + splitter = QSplitter(Qt.Orientation.Horizontal) + + left = QWidget() + left_layout = QVBoxLayout(left) + left_layout.setContentsMargins(0, 0, 0, 0) + left_layout.setSpacing(12) + + self.trajectory_panel = ClusterTrajectoryPanel() + self.time_panel = ClusterDynamicsTimePanel() + self.definitions_panel = ClusterDefinitionsPanel() + self.run_panel = ClusterDynamicsRunPanel() + self.dataset_panel = ClusterDynamicsDatasetPanel() + + left_layout.addWidget(self.trajectory_panel) + left_layout.addWidget(self.time_panel) + left_layout.addWidget(self.definitions_panel) + left_layout.addWidget(self.run_panel) + left_layout.addWidget(self.dataset_panel) + left_layout.addStretch(1) + + right = QWidget() + right_layout = QVBoxLayout(right) + right_layout.setContentsMargins(0, 0, 0, 0) + right_layout.setSpacing(12) + + self.plot_panel = ClusterDynamicsPlotPanel() + right_layout.addWidget(self.plot_panel, stretch=3) + + self.results_tabs = QTabWidget() + self.summary_box = QTextEdit() + self.summary_box.setReadOnly(True) + self.label_table = self._build_lifetime_table( + headers=( + "Label", + "Size", + "Mean lifetime (fs)", + "Std lifetime (fs)", + "Completed", + "Window-truncated", + "Assoc. rate (1/ps)", + "Dissoc. rate (1/ps)", + "Occupancy (%)", + "Mean count/frame", + ) + ) + self.results_tabs.addTab(self.summary_box, "Summary") + self.results_tabs.addTab(self.label_table, "Lifetime") + right_layout.addWidget(self.results_tabs, stretch=2) + + splitter.addWidget(self._wrap_scroll_area(left)) + splitter.addWidget(right) + splitter.setSizes([500, 1040]) + + root.addWidget(splitter) + self.setCentralWidget(central) + self.statusBar().showMessage("Ready") + + self.trajectory_panel.inspect_requested.connect( + self.inspect_frames_folder + ) + self.trajectory_panel.frames_dir_changed.connect( + self._on_frames_dir_changed + ) + self.trajectory_panel.settings_changed.connect( + self._refresh_selection_preview + ) + self.time_panel.settings_changed.connect( + self._refresh_selection_preview + ) + self.definitions_panel.settings_changed.connect( + self._refresh_selection_preview + ) + self.run_panel.settings_changed.connect( + self._refresh_selection_preview + ) + self.run_panel.analyze_requested.connect(self.run_analysis) + self.dataset_panel.save_dataset_requested.connect(self.save_dataset) + self.dataset_panel.load_dataset_requested.connect(self.load_dataset) + self.dataset_panel.save_colormap_requested.connect( + self.save_colormap_data + ) + self.dataset_panel.save_lifetime_requested.connect( + self.save_lifetime_table + ) + self.dataset_panel.save_powerpoint_requested.connect( + self.save_powerpoint_report + ) + + self.run_panel.set_selection_summary( + "Select an extracted PDB or XYZ frames folder to preview the " + "time-binned cluster analysis." + ) + self.run_panel.set_log( + "Ready. Load a split frame folder from mdtrajectory, define the " + "cluster rules, then run the time-binned analysis to build the " + "cluster-distribution heatmap and lifetime table." + ) + self._set_frame_format(None) + + def inspect_frames_folder(self) -> None: + frames_dir = self.trajectory_panel.get_frames_dir() + if frames_dir is None: + self._show_error("No extracted frames folder selected.") + return + self._inspect_frames_dir(frames_dir, announce=True) + + def run_analysis(self) -> None: + try: + if self._run_thread is not None: + return + config = self._build_job_config() + self.run_panel.reset_progress() + self.run_panel.set_log( + "Time-binned cluster analysis request received.\n" + f"Frames folder: {config.frames_dir}\n" + f"Mode: {frame_folder_label(self._frame_format or 'pdb')}\n" + f"PBC: {'on' if config.use_pbc else 'off'}\n" + "Search mode: " + f"{format_search_mode_label(config.search_mode)}\n" + f"Frame timestep: {config.frame_timestep_fs:.3f} fs\n" + "Frames per colormap timestep: " + f"{config.frames_per_colormap_timestep}\n" + f"Colormap timestep: {config.colormap_timestep_fs:.3f} fs" + ) + self.plot_panel.set_result(None) + self.summary_box.clear() + self.label_table.setRowCount(0) + self.statusBar().showMessage("Analyzing time-binned clusters...") + self._start_worker(config) + except Exception as exc: + self._handle_error("Cluster dynamics analysis failed", str(exc)) + + def _build_job_config(self) -> ClusterDynamicsJobConfig: + frames_dir = self.trajectory_panel.get_frames_dir() + if frames_dir is None: + raise ValueError("No extracted frames folder selected.") + + atom_type_definitions = self.definitions_panel.atom_type_definitions() + if not atom_type_definitions: + raise ValueError( + "Add at least one atom-type definition before running the analysis." + ) + if not ( + atom_type_definitions.get("node") + or atom_type_definitions.get("linker") + ): + raise ValueError("Define at least one node or linker atom type.") + + pair_cutoff_definitions = ( + self.definitions_panel.pair_cutoff_definitions() + ) + default_cutoff = self.definitions_panel.default_cutoff() + if not pair_cutoff_definitions and default_cutoff is None: + raise ValueError( + "Add at least one pair-cutoff definition or specify a default cutoff." + ) + + use_pbc = self.definitions_panel.use_pbc() + manual_box_dimensions = self.definitions_panel.box_dimensions() + resolved_box_dimensions = manual_box_dimensions + if use_pbc and resolved_box_dimensions is None: + resolved_box_dimensions = self._detected_box_dimensions() + if resolved_box_dimensions is None: + raise ValueError( + "Periodic boundary conditions are enabled, but no box " + "dimensions are available. Enter a manual box or inspect " + "a frames folder with a usable coordinate extent." + ) + + return ClusterDynamicsJobConfig( + frames_dir=frames_dir, + energy_file=self.run_panel.energy_file(), + atom_type_definitions=atom_type_definitions, + pair_cutoff_definitions=pair_cutoff_definitions, + box_dimensions=resolved_box_dimensions, + use_pbc=use_pbc, + default_cutoff=default_cutoff, + shell_levels=self.definitions_panel.shell_growth_levels(), + shared_shells=self.definitions_panel.shared_shells(), + include_shell_atoms_in_stoichiometry=( + self.definitions_panel.include_shell_atoms_in_stoichiometry() + ), + search_mode=self.definitions_panel.search_mode(), + folder_start_time_fs=self.time_panel.folder_start_time_fs(), + first_frame_time_fs=self.time_panel.first_frame_time_fs(), + frame_timestep_fs=self.time_panel.frame_timestep_fs(), + frames_per_colormap_timestep=( + self.time_panel.frames_per_colormap_timestep() + ), + analysis_start_fs=self.time_panel.analysis_start_fs(), + analysis_stop_fs=self.time_panel.analysis_stop_fs(), + ) + + def _start_worker(self, config: ClusterDynamicsJobConfig) -> None: + self._run_thread = QThread(self) + self._run_worker = ClusterDynamicsWorker(config) + self._run_worker.moveToThread(self._run_thread) + + self._run_thread.started.connect(self._run_worker.run) + self._run_worker.progress.connect(self.run_panel.append_log) + self._run_worker.progress_count.connect(self._on_run_progress) + self._run_worker.finished.connect(self._on_run_finished) + self._run_worker.failed.connect(self._on_run_failed) + self._run_worker.finished.connect(self._run_thread.quit) + self._run_worker.failed.connect(self._run_thread.quit) + self._run_thread.finished.connect(self._cleanup_run_thread) + self._run_thread.finished.connect(self._run_thread.deleteLater) + self._run_thread.finished.connect(self._run_worker.deleteLater) + self._run_thread.start() + + def _on_run_progress(self, processed: int, total: int) -> None: + self.run_panel.update_progress(processed, total) + self.statusBar().showMessage( + f"Analyzing time-binned clusters... {processed}/{max(total, 1)} frames" + ) + + def _on_run_finished(self, result: ClusterDynamicsResult) -> None: + self._last_result = result + self.plot_panel.set_result(result) + self.run_panel.update_progress( + result.analyzed_frames, result.analyzed_frames + ) + self.run_panel.append_log( + "Time-binned analysis complete.\n" + f"Frames analyzed: {result.analyzed_frames}\n" + f"Time bins: {result.bin_count}\n" + f"Unique cluster labels: {len(result.cluster_labels)}" + ) + self._populate_summary_box(result) + self._populate_label_table(result) + self.statusBar().showMessage("Cluster dynamics analysis complete") + + def _on_run_failed(self, message: str) -> None: + self.statusBar().showMessage("Cluster dynamics analysis failed") + self._handle_error("Cluster dynamics analysis failed", message) + + def _cleanup_run_thread(self) -> None: + self._run_worker = None + self._run_thread = None + + def save_dataset(self) -> None: + if self._last_result is None: + self._show_error( + "Run an analysis or load a saved dataset before exporting." + ) + return + + default_path = self._default_dataset_file() + path, _selected_filter = QFileDialog.getSaveFileName( + self, + "Save cluster dynamics dataset", + str(default_path), + "JSON Files (*.json);;All Files (*)", + ) + if not path: + return + + saved = save_cluster_dynamics_dataset( + self._last_result, + path, + analysis_settings=self._analysis_settings_payload(), + ) + self._last_dataset_file = saved.dataset_file + self.run_panel.append_log( + "Saved cluster-dynamics dataset to " + f"{saved.dataset_file}\n" + f"Wrote {len(saved.written_files)} file(s)." + ) + self.statusBar().showMessage( + f"Saved cluster dynamics dataset to {saved.dataset_file}" + ) + + def load_dataset(self) -> None: + default_path = ( + self._last_dataset_file + if self._last_dataset_file is not None + else self._default_dataset_file() + ) + path, _selected_filter = QFileDialog.getOpenFileName( + self, + "Load cluster dynamics dataset", + str(default_path), + "JSON Files (*.json);;All Files (*)", + ) + if not path: + return + + loaded = load_cluster_dynamics_dataset(path) + self._last_dataset_file = loaded.dataset_file + self._apply_analysis_settings(loaded.analysis_settings) + self._last_result = loaded.result + self.plot_panel.set_result(loaded.result) + self.run_panel.set_selection_summary( + self._format_preview_text(loaded.result.preview) + ) + self._populate_summary_box(loaded.result) + self._populate_label_table(loaded.result) + self.run_panel.append_log( + "Loaded cluster-dynamics dataset from " + f"{loaded.dataset_file}\n" + f"Frames analyzed: {loaded.result.analyzed_frames}\n" + f"Time bins: {loaded.result.bin_count}" + ) + self.statusBar().showMessage( + f"Loaded cluster dynamics dataset from {loaded.dataset_file}" + ) + + def save_colormap_data(self) -> None: + if self._last_result is None: + self._show_error( + "Run an analysis or load a saved dataset before exporting." + ) + return + + default_path = self._default_export_file("cluster_distribution") + path, _selected_filter = QFileDialog.getSaveFileName( + self, + "Save colormap data", + str(default_path), + "CSV Files (*.csv);;All Files (*)", + ) + if not path: + return + + display_mode = self.plot_panel.display_mode_combo.currentData() + time_unit = self.plot_panel.time_unit_combo.currentData() + saved_path = export_cluster_dynamics_colormap_csv( + self._last_result, + path, + display_mode=( + "fraction" if display_mode is None else str(display_mode) + ), + time_unit="fs" if time_unit is None else str(time_unit), + ) + row_count = ( + len(self._last_result.cluster_labels) * self._last_result.bin_count + ) + self.run_panel.append_log( + "Saved cluster-dynamics colormap data to " + f"{saved_path}\n" + f"Rows written: {row_count}" + ) + self.statusBar().showMessage(f"Saved colormap data to {saved_path}") + + def save_lifetime_table(self) -> None: + if self._last_result is None: + self._show_error( + "Run an analysis or load a saved dataset before exporting." + ) + return + + default_path = self._default_export_file("lifetime") + path, _selected_filter = QFileDialog.getSaveFileName( + self, + "Save lifetime table", + str(default_path), + "CSV Files (*.csv);;All Files (*)", + ) + if not path: + return + + saved_path = export_cluster_dynamics_lifetime_csv( + self._last_result, + path, + ) + self.run_panel.append_log( + "Saved cluster-dynamics lifetime table to " + f"{saved_path}\n" + f"Rows written: {len(self._last_result.lifetime_by_label)}" + ) + self.statusBar().showMessage(f"Saved lifetime table to {saved_path}") + + def save_powerpoint_report(self) -> None: + if self._last_result is None: + self._show_error( + "Run an analysis or load a saved dataset before exporting." + ) + return + + self.plot_panel.set_result(self._last_result) + selection_summary = self.run_panel.selection_box.toPlainText().strip() + if not selection_summary: + selection_summary = self._format_preview_text( + self._last_result.preview + ) + summary_text = self.summary_box.toPlainText().strip() + if not summary_text: + self._populate_summary_box(self._last_result) + summary_text = self.summary_box.toPlainText().strip() + + default_path = self._default_powerpoint_report_file() + path, _selected_filter = QFileDialog.getSaveFileName( + self, + "Save cluster dynamics PowerPoint report", + str(default_path), + "PowerPoint Files (*.pptx);;All Files (*)", + ) + if not path: + return + + self.run_panel.progress_label.setText( + "Progress: generating PowerPoint report" + ) + self.run_panel.progress_bar.setRange(0, 1) + self.run_panel.progress_bar.setValue(0) + self.run_panel.progress_bar.setFormat("%v / %m steps") + try: + export_result = export_cluster_dynamics_report_pptx( + result=self._last_result, + selection_summary=selection_summary, + result_summary=summary_text, + figure=self.plot_panel.figure, + output_path=path, + settings=self._powerpoint_export_settings(), + project_dir=self.dataset_panel.project_dir(), + frames_dir=self.trajectory_panel.get_frames_dir(), + progress_callback=self._on_powerpoint_report_progress, + ) + except Exception as exc: + self.run_panel.progress_label.setText( + "Progress: PowerPoint export failed" + ) + self.run_panel.progress_bar.setRange(0, 1) + self.run_panel.progress_bar.setValue(0) + self.run_panel.progress_bar.setFormat("%v / %m steps") + self._handle_error( + "Cluster dynamics PowerPoint export failed", str(exc) + ) + return + + self.run_panel.progress_label.setText( + "Progress: PowerPoint report saved" + ) + self.run_panel.progress_bar.setValue( + self.run_panel.progress_bar.maximum() + ) + self.run_panel.progress_bar.setFormat("%v / %m steps") + if export_result.appended_to_existing: + self.run_panel.append_log( + "Appended cluster-dynamics report slides to " + f"{export_result.report_path}\n" + f"Slides added: {export_result.added_slide_count}" + ) + else: + self.run_panel.append_log( + "Saved cluster-dynamics PowerPoint report to " + f"{export_result.report_path}\n" + f"Slides written: {export_result.added_slide_count}" + ) + self.statusBar().showMessage( + f"Saved PowerPoint report to {export_result.report_path}" + ) + + def _inspect_frames_dir(self, frames_dir: Path, *, announce: bool) -> None: + self._last_summary = None + try: + analyzer = ExtractedFrameFolderClusterAnalyzer( + frames_dir=frames_dir, + atom_type_definitions={}, + pair_cutoffs_def={}, + ) + self._last_summary = analyzer.inspect() + self._sync_box_dimensions_from_summary(self._last_summary) + self._set_frame_format(self._last_summary.get("frame_format")) + self.trajectory_panel.set_summary(self._last_summary) + if announce: + self.run_panel.append_log( + "Inspection complete. " + f"Detected {self._last_summary['n_frames']} extracted " + "frame(s) in the selected folder." + ) + self.statusBar().showMessage("Inspection complete") + except ValueError as exc: + self._sync_box_dimensions_from_summary(None) + frame_format, detail = self._detect_frame_format(frames_dir) + self._set_frame_format(frame_format) + self.trajectory_panel.set_summary_text(str(exc)) + if detail is not None: + self.trajectory_panel.set_frame_mode(None, detail=detail) + if announce: + self._handle_error("Frames-folder inspection failed", str(exc)) + self._refresh_selection_preview() + + def _on_frames_dir_changed(self, frames_dir: Path | None) -> None: + if self._suspend_preview_refresh: + return + self._last_summary = None + self._last_result = None + self.plot_panel.set_result(None) + self.summary_box.clear() + self.label_table.setRowCount(0) + self.time_panel.set_folder_start_time_fs(None, emit_signal=False) + if frames_dir is None: + self._sync_box_dimensions_from_summary(None) + self._set_frame_format(None) + self.trajectory_panel.set_summary_text("") + self._refresh_selection_preview() + return + self._inspect_frames_dir(frames_dir, announce=False) + + def _refresh_selection_preview(self) -> None: + if self._suspend_preview_refresh: + return + frames_dir = self.trajectory_panel.get_frames_dir() + if frames_dir is None: + self.run_panel.set_selection_summary( + "Select an extracted PDB or XYZ frames folder to preview the " + "time-binned cluster analysis." + ) + return + + warning: str | None = None + try: + workflow = self._build_preview_workflow() + preview = workflow.preview_selection() + if ( + self.time_panel.folder_start_time_fs() is None + and preview.folder_start_time_fs is not None + and preview.folder_start_time_source != "manual field" + ): + self.time_panel.set_folder_start_time_fs( + preview.folder_start_time_fs, + emit_signal=False, + ) + text = self._format_preview_text(preview) + except Exception as exc: + warning = str(exc) + text = ( + "Adjust the current settings to preview the time-binned " + f"analysis.\nValidation warning: {warning}" + ) + self.run_panel.set_selection_summary(text) + + def _build_preview_workflow(self) -> ClusterDynamicsWorkflow: + frames_dir = self.trajectory_panel.get_frames_dir() + if frames_dir is None: + raise ValueError("No extracted frames folder selected.") + + manual_box_dimensions = self.definitions_panel.box_dimensions() + resolved_box_dimensions = manual_box_dimensions + if ( + self.definitions_panel.use_pbc() + and resolved_box_dimensions is None + ): + resolved_box_dimensions = self._detected_box_dimensions() + + return ClusterDynamicsWorkflow( + frames_dir, + atom_type_definitions=self.definitions_panel.atom_type_definitions(), + pair_cutoff_definitions=self.definitions_panel.pair_cutoff_definitions(), + box_dimensions=resolved_box_dimensions, + use_pbc=self.definitions_panel.use_pbc(), + default_cutoff=self.definitions_panel.default_cutoff(), + shell_levels=self.definitions_panel.shell_growth_levels(), + shared_shells=self.definitions_panel.shared_shells(), + include_shell_atoms_in_stoichiometry=( + self.definitions_panel.include_shell_atoms_in_stoichiometry() + ), + search_mode=self.definitions_panel.search_mode(), + folder_start_time_fs=self.time_panel.folder_start_time_fs(), + first_frame_time_fs=self.time_panel.first_frame_time_fs(), + frame_timestep_fs=self.time_panel.frame_timestep_fs(), + frames_per_colormap_timestep=( + self.time_panel.frames_per_colormap_timestep() + ), + analysis_start_fs=self.time_panel.analysis_start_fs(), + analysis_stop_fs=self.time_panel.analysis_stop_fs(), + energy_file=self.run_panel.energy_file(), + ) + + def _format_preview_text( + self, + preview: ClusterDynamicsSelectionPreview, + ) -> str: + box_label = self._box_dimensions_label() + resolved_box_text = format_box_dimensions( + preview.resolved_box_dimensions + ) + lines = [ + f"Mode: {frame_folder_label(preview.frame_format)}", + f"PBC: {'on' if preview.use_pbc else 'off'}", + "Search mode: " + f"{format_search_mode_label(self.definitions_panel.search_mode())}", + f"Frames in folder: {preview.total_frames}", + f"Frames selected: {preview.selected_frames}", + f"First frame time: {preview.first_frame_time_fs:.3f} fs", + "Time source: " f"{preview.time_source_label}", + f"Frame timestep: {preview.frame_timestep_fs:.3f} fs", + f"Colormap timestep: {preview.colormap_timestep_fs:.3f} fs", + f"Time window: {preview.analysis_start_fs:.3f} to " + f"{preview.analysis_stop_fs:.3f} fs", + f"Time bins: {preview.bin_count}", + f"Shell growth: {self._shell_growth_text()}", + "Stoichiometry bins: " + + ( + "solute + shell atoms" + if self.definitions_panel.include_shell_atoms_in_stoichiometry() + else "solute only" + ), + f"{box_label}: {resolved_box_text}", + ] + if preview.frames_per_colormap_timestep is not None: + lines.insert( + 8, + "Frames per colormap timestep: " + f"{preview.frames_per_colormap_timestep}", + ) + box_source = self._box_dimensions_source() + if box_source is not None: + lines.append(f"Box source: {box_source}") + if preview.folder_start_time_fs is not None: + source_label = ( + f" ({preview.folder_start_time_source})" + if preview.folder_start_time_source + else "" + ) + lines.append( + "Folder/start time: " + f"{preview.folder_start_time_fs:.3f} fs{source_label}" + ) + if preview.first_selected_frame is not None: + lines.append( + "Frame file range: " + f"{preview.first_selected_frame} to {preview.last_selected_frame}" + ) + if preview.first_selected_source_frame_index is not None: + lines.append( + "Source frame index range: " + f"{preview.first_selected_source_frame_index} to " + f"{preview.last_selected_source_frame_index}" + ) + if preview.first_selected_time_fs is not None: + lines.append( + "Selected frame times: " + f"{preview.first_selected_time_fs:.3f} to " + f"{preview.last_selected_time_fs:.3f} fs" + ) + if preview.energy_file is not None: + lines.append(f"Energy overlay: {preview.energy_file}") + if preview.time_warnings: + lines.extend( + f"Warning: {message}" for message in preview.time_warnings + ) + return "\n".join(lines) + + def _populate_summary_box(self, result: ClusterDynamicsResult) -> None: + preview = result.preview + lines = [ + f"Mode: {frame_folder_label(preview.frame_format)}", + f"Frames analyzed: {result.analyzed_frames}", + f"Time bins: {result.bin_count}", + f"Unique cluster labels: {len(result.cluster_labels)}", + f"Frame timestep: {preview.frame_timestep_fs:.3f} fs", + f"Colormap timestep: {preview.colormap_timestep_fs:.3f} fs", + f"Time source: {preview.time_source_label}", + f"Time window: {preview.analysis_start_fs:.3f} to " + f"{preview.analysis_stop_fs:.3f} fs", + "Resolved box dimensions: " + f"{format_box_dimensions(preview.resolved_box_dimensions)}", + f"Total clusters sampled: {int(result.total_clusters_per_frame.sum())}", + ] + if preview.frames_per_colormap_timestep is not None: + lines.insert( + 5, + "Frames per colormap timestep: " + f"{preview.frames_per_colormap_timestep}", + ) + if preview.folder_start_time_fs is not None: + lines.append( + "Folder/start time: " f"{preview.folder_start_time_fs:.3f} fs" + ) + if result.energy_data is not None: + lines.append( + f"Energy points in view: {len(result.energy_series('temperature')[0])}" + ) + if preview.time_warnings: + lines.extend( + f"Warning: {message}" for message in preview.time_warnings + ) + self.summary_box.setPlainText("\n".join(lines)) + + def _populate_label_table(self, result: ClusterDynamicsResult) -> None: + self.label_table.setSortingEnabled(False) + self.label_table.setRowCount(len(result.lifetime_by_label)) + for row, entry in enumerate(result.lifetime_by_label): + values = ( + entry.label, + str(entry.cluster_size), + _format_optional_float(entry.mean_lifetime_fs), + _format_optional_float(entry.std_lifetime_fs), + str(entry.completed_lifetime_count), + str(entry.window_truncated_lifetime_count), + f"{entry.association_rate_per_ps:.3f}", + f"{entry.dissociation_rate_per_ps:.3f}", + f"{entry.occupancy_fraction * 100.0:.1f}", + f"{entry.mean_count_per_frame:.3f}", + ) + for column, value in enumerate(values): + self.label_table.setItem(row, column, QTableWidgetItem(value)) + self.label_table.resizeColumnsToContents() + self.label_table.setSortingEnabled(True) + + @staticmethod + def _build_lifetime_table(headers: tuple[str, ...]) -> QTableWidget: + table = QTableWidget(0, len(headers)) + table.setHorizontalHeaderLabels(list(headers)) + table.verticalHeader().setVisible(False) + table.setAlternatingRowColors(True) + header = table.horizontalHeader() + header.setSectionResizeMode(QHeaderView.ResizeMode.ResizeToContents) + if len(headers) > 0: + header.setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) + return table + + def _default_dataset_dir(self) -> Path: + project_dir = self.dataset_panel.project_dir() + if project_dir is not None: + paths = build_project_paths(project_dir) + target_dir = paths.exported_data_dir / "clusterdynamics" + target_dir.mkdir(parents=True, exist_ok=True) + return target_dir + frames_dir = self.trajectory_panel.get_frames_dir() + if frames_dir is not None: + return frames_dir.parent + return Path.cwd() + + def _default_dataset_file(self) -> Path: + if self._last_dataset_file is not None: + return self._last_dataset_file + frames_dir = self.trajectory_panel.get_frames_dir() + folder_label = "cluster_dynamics" + if frames_dir is not None: + folder_label = frames_dir.name or folder_label + return ( + self._default_dataset_dir() + / f"{folder_label}_cluster_dynamics.json" + ) + + def _default_export_file(self, suffix_label: str) -> Path: + dataset_file = self._default_dataset_file() + return dataset_file.with_name( + f"{dataset_file.stem}_{suffix_label}.csv" + ) + + def _default_powerpoint_report_file(self) -> Path: + frames_dir = self.trajectory_panel.get_frames_dir() + fallback_label = "cluster_dynamics_report" + if frames_dir is not None: + fallback_label = f"{frames_dir.name or 'cluster_dynamics'}_report" + return default_powerpoint_report_path( + project_dir=self.dataset_panel.project_dir(), + fallback_dir=self._default_dataset_dir(), + fallback_stem=fallback_label, + ) + + def _powerpoint_export_settings(self) -> PowerPointExportSettings: + project_dir = self.dataset_panel.project_dir() + if project_dir is None: + return PowerPointExportSettings() + project_file = build_project_paths(project_dir).project_file + if not project_file.is_file(): + return PowerPointExportSettings() + try: + settings = self._project_manager.load_project(project_dir) + except Exception: + return PowerPointExportSettings() + return PowerPointExportSettings.from_dict( + settings.powerpoint_export_settings.to_dict() + ) + + def _on_powerpoint_report_progress( + self, + processed: int, + total: int, + message: str, + ) -> None: + total_steps = max(int(total), 1) + processed_steps = max(0, min(int(processed), total_steps)) + self.run_panel.progress_label.setText( + f"Progress: PowerPoint report {processed_steps}/{total_steps}" + ) + self.run_panel.progress_bar.setRange(0, total_steps) + self.run_panel.progress_bar.setValue(processed_steps) + self.run_panel.progress_bar.setFormat("%v / %m steps") + self.statusBar().showMessage(message) + QApplication.processEvents() + + def _analysis_settings_payload(self) -> dict[str, object]: + frames_dir = self.trajectory_panel.get_frames_dir() + energy_file = self.run_panel.energy_file() + project_dir = self.dataset_panel.project_dir() + return { + "frames_dir": None if frames_dir is None else str(frames_dir), + "energy_file": None if energy_file is None else str(energy_file), + "project_dir": None if project_dir is None else str(project_dir), + "atom_type_definitions": { + atom_type: [ + [element, residue] for element, residue in criteria + ] + for atom_type, criteria in self.definitions_panel.atom_type_definitions().items() + }, + "pair_cutoff_definitions": [ + { + "atom1": atom1, + "atom2": atom2, + "shell_cutoffs": { + str(level): float(cutoff) + for level, cutoff in shell_cutoffs.items() + }, + } + for (atom1, atom2), shell_cutoffs in sorted( + self.definitions_panel.pair_cutoff_definitions().items() + ) + ], + "box_dimensions": ( + None + if self.definitions_panel.box_dimensions() is None + else list(self.definitions_panel.box_dimensions()) + ), + "use_pbc": self.definitions_panel.use_pbc(), + "default_cutoff": self.definitions_panel.default_cutoff(), + "shell_levels": list(self.definitions_panel.shell_growth_levels()), + "shared_shells": self.definitions_panel.shared_shells(), + "include_shell_atoms_in_stoichiometry": ( + self.definitions_panel.include_shell_atoms_in_stoichiometry() + ), + "search_mode": self.definitions_panel.search_mode(), + "folder_start_time_fs": self.time_panel.folder_start_time_fs(), + "first_frame_time_fs": self.time_panel.first_frame_time_fs(), + "frame_timestep_fs": self.time_panel.frame_timestep_fs(), + "frames_per_colormap_timestep": ( + self.time_panel.frames_per_colormap_timestep() + ), + "colormap_timestep_fs": self.time_panel.colormap_timestep_fs(), + "analysis_start_fs": self.time_panel.analysis_start_fs(), + "analysis_stop_fs": self.time_panel.analysis_stop_fs(), + } + + def _apply_analysis_settings(self, payload: dict[str, object]) -> None: + self._suspend_preview_refresh = True + try: + frames_dir = _optional_path(payload.get("frames_dir")) + energy_file = _optional_path(payload.get("energy_file")) + project_dir = _optional_path(payload.get("project_dir")) + + self.trajectory_panel.frames_dir_edit.setText( + "" if frames_dir is None else str(frames_dir) + ) + self.run_panel.energy_path_edit.setText( + "" if energy_file is None else str(energy_file) + ) + self.dataset_panel.set_project_dir(project_dir) + + atom_type_definitions = { + str(atom_type): [ + ( + str(entry[0]), + ( + None + if len(entry) < 2 or entry[1] in {None, ""} + else str(entry[1]) + ), + ) + for entry in criteria + if isinstance(entry, (list, tuple)) and entry + ] + for atom_type, criteria in dict( + payload.get("atom_type_definitions", {}) + ).items() + } + pair_cutoff_definitions = { + (str(entry.get("atom1", "")), str(entry.get("atom2", ""))): { + int(level): float(cutoff) + for level, cutoff in dict( + entry.get("shell_cutoffs", {}) + ).items() + } + for entry in payload.get("pair_cutoff_definitions", []) + if isinstance(entry, dict) + } + self.definitions_panel.load_atom_type_definitions( + atom_type_definitions, + emit_signal=False, + ) + self.definitions_panel.load_pair_cutoff_definitions( + pair_cutoff_definitions, + emit_signal=False, + ) + self.definitions_panel.set_box_dimensions( + _optional_box_dimensions(payload.get("box_dimensions")), + emit_signal=False, + ) + self.definitions_panel.set_use_pbc( + bool(payload.get("use_pbc", False)), + emit_signal=False, + ) + self.definitions_panel.set_default_cutoff( + _optional_float(payload.get("default_cutoff")), + emit_signal=False, + ) + self.definitions_panel.set_shell_growth_levels( + tuple(int(value) for value in payload.get("shell_levels", [])), + emit_signal=False, + ) + self.definitions_panel.set_shared_shells( + bool(payload.get("shared_shells", False)), + emit_signal=False, + ) + self.definitions_panel.set_include_shell_atoms_in_stoichiometry( + bool( + payload.get( + "include_shell_atoms_in_stoichiometry", + False, + ) + ), + emit_signal=False, + ) + self.definitions_panel.set_search_mode( + str(payload.get("search_mode", "kdtree")), + emit_signal=False, + ) + + self.time_panel.set_folder_start_time_fs( + _optional_float(payload.get("folder_start_time_fs")), + emit_signal=False, + ) + self.time_panel.set_first_frame_time_fs( + float(payload.get("first_frame_time_fs", 0.0)), + emit_signal=False, + ) + frame_timestep_fs = float(payload.get("frame_timestep_fs", 0.5)) + self.time_panel.set_frame_timestep_fs( + frame_timestep_fs, + emit_signal=False, + ) + ( + frames_per_colormap_timestep, + colormap_timestep_fs, + ) = _resolve_colormap_timestep_settings( + frame_timestep_fs=frame_timestep_fs, + frames_per_colormap_timestep=payload.get( + "frames_per_colormap_timestep" + ), + colormap_timestep_fs=_optional_float( + payload.get("colormap_timestep_fs") + ), + legacy_bin_size_fs=_optional_float(payload.get("bin_size_fs")), + require_integral_ratio=False, + ) + if frames_per_colormap_timestep is None: + frames_per_colormap_timestep = max( + int( + ( + float(colormap_timestep_fs) + / float(frame_timestep_fs) + ) + + 0.5 + ), + 1, + ) + self.time_panel.set_frames_per_colormap_timestep( + frames_per_colormap_timestep, + emit_signal=False, + ) + self.time_panel.set_analysis_start_fs( + _optional_float(payload.get("analysis_start_fs")), + emit_signal=False, + ) + self.time_panel.set_analysis_stop_fs( + _optional_float(payload.get("analysis_stop_fs")), + emit_signal=False, + ) + finally: + self._suspend_preview_refresh = False + self._refresh_selection_preview() + + def _detected_box_dimensions( + self, + ) -> tuple[float, float, float] | None: + if self._last_summary is None: + return None + value = self._last_summary.get("box_dimensions") + if value is None: + value = self._last_summary.get("estimated_box_dimensions") + if value is None: + return None + return tuple(float(component) for component in value) + + def _box_dimensions_source_kind(self) -> str | None: + if self._last_summary is None: + return None + value = self._last_summary.get("box_dimensions_source_kind") + return None if value is None else str(value) + + def _box_dimensions_source(self) -> str | None: + if self._last_summary is None: + return None + value = self._last_summary.get("box_dimensions_source") + return None if value is None else str(value) + + def _box_dimensions_label(self) -> str: + if self._box_dimensions_source_kind() == "source_filename": + return "Source box dimensions" + return "Estimated box dimensions" + + def _sync_box_dimensions_from_summary( + self, + summary: dict[str, object] | None, + ) -> None: + if summary is None: + self.definitions_panel.set_box_dimensions(None, emit_signal=False) + return + if summary.get("box_dimensions_source_kind") == "source_filename": + value = summary.get("box_dimensions") + if value is not None: + self.definitions_panel.set_box_dimensions( + tuple(float(component) for component in value), + emit_signal=False, + ) + return + self.definitions_panel.set_box_dimensions(None, emit_signal=False) + + def _set_frame_format(self, frame_format: object | None) -> None: + normalized = None if frame_format is None else str(frame_format) + self._frame_format = normalized + self.trajectory_panel.set_frame_mode(normalized) + self.definitions_panel.set_frame_mode(normalized) + + def _detect_frame_format( + self, + frames_dir: Path | None, + ) -> tuple[str | None, str | None]: + if frames_dir is None: + return None, None + try: + frame_format, _frame_paths = detect_frame_folder_mode(frames_dir) + except ValueError as exc: + return None, str(exc) + return frame_format, None + + def _shell_growth_text(self) -> str: + levels = self.definitions_panel.shell_growth_levels() + if not levels: + return "core only" + return ", ".join(str(level) for level in levels) + + def _handle_error(self, title: str, message: str) -> None: + self.run_panel.append_log(f"{title}: {message}") + QMessageBox.critical(self, title, message) + + def _show_error(self, message: str) -> None: + QMessageBox.critical(self, "Error", message) + + @staticmethod + def _wrap_scroll_area(widget: QWidget) -> QScrollArea: + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setWidget(widget) + return scroll_area + + +def _format_optional_float(value: float | None) -> str: + return "n/a" if value is None else f"{value:.3f}" + + +def _optional_float(value: object) -> float | None: + return None if value is None else float(value) + + +def _optional_path(value: object) -> Path | None: + if value is None: + return None + text = str(value).strip() + return None if not text else Path(text) + + +def _optional_box_dimensions( + value: object, +) -> tuple[float, float, float] | None: + if value is None: + return None + components = tuple(float(component) for component in value) + if len(components) != 3: + raise ValueError("Saved box dimensions must contain three values.") + return components + + +def launch_clusterdynamics_ui( + frames_dir: str | Path | None = None, + *, + energy_file: str | Path | None = None, + project_dir: str | Path | None = None, +) -> int: + """Launch the Qt6 cluster-dynamics UI.""" + app = QApplication.instance() + owns_app = app is None + if app is None: + prepare_saxshell_application_identity() + app = QApplication(sys.argv) + configure_saxshell_application(app) + + window = ClusterDynamicsMainWindow( + initial_frames_dir=(None if frames_dir is None else Path(frames_dir)), + initial_energy_file=( + None if energy_file is None else Path(energy_file) + ), + initial_project_dir=( + None if project_dir is None else Path(project_dir) + ), + ) + _OPEN_WINDOWS.append(window) + window.show() + if owns_app: + return app.exec() + return 0 + + +def main(argv: list[str] | None = None) -> int: + """Entry point for launching the cluster-dynamics UI.""" + parser = argparse.ArgumentParser( + prog="clusterdynamics-ui", + description=( + "Launch the SAXSShell clusterdynamics UI for time-binned " + "cluster-distribution analysis on extracted frame folders." + ), + ) + parser.add_argument( + "frames_dir", + nargs="?", + help="Optional extracted frames directory to prefill in the UI.", + ) + parser.add_argument( + "--energy-file", + help="Optional CP2K .ener file to prefill in the UI.", + ) + parser.add_argument( + "--project-dir", + help="Optional SAXSShell project directory to prefill in the UI.", + ) + args = parser.parse_args(argv) + return launch_clusterdynamics_ui( + args.frames_dir, + energy_file=args.energy_file, + project_dir=args.project_dir, + ) + + +__all__ = [ + "ClusterDynamicsJobConfig", + "ClusterDynamicsMainWindow", + "ClusterDynamicsWorker", + "launch_clusterdynamics_ui", + "main", +] diff --git a/src/saxshell/clusterdynamics/ui/plot_panel.py b/src/saxshell/clusterdynamics/ui/plot_panel.py new file mode 100644 index 0000000..0392504 --- /dev/null +++ b/src/saxshell/clusterdynamics/ui/plot_panel.py @@ -0,0 +1,311 @@ +from __future__ import annotations + +import math + +import numpy as np +from matplotlib import colormaps +from matplotlib import colors as mcolors +from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.backends.backend_qtagg import ( + NavigationToolbar2QT as NavigationToolbar, +) +from matplotlib.figure import Figure +from PySide6.QtWidgets import ( + QComboBox, + QDoubleSpinBox, + QHBoxLayout, + QLabel, + QVBoxLayout, + QWidget, +) + +from saxshell.clusterdynamics.workflow import ClusterDynamicsResult + +PLOT_COLORMAPS = ("viridis", "magma", "cividis", "inferno", "turbo") +DISPLAY_MODE_LABELS = { + "count": "Counts / bin", + "fraction": "Fraction / bin", + "mean_count": "Mean count / frame", +} +DISPLAY_MODE_COLORBAR_LABELS = { + "count": "Clusters in bin", + "fraction": "Cluster fraction", + "mean_count": "Mean clusters per frame", +} +OVERLAY_SERIES = ( + ("None", None), + ("Temperature", "temperature"), + ("Potential Energy", "potential"), + ("Kinetic Energy", "kinetic"), +) +OVERLAY_COLORS = { + "temperature": "#1f77b4", + "potential": "#2e8b57", + "kinetic": "#c0392b", +} + + +class ClusterDynamicsPlotPanel(QWidget): + """Interactive time-binned cluster heatmap panel.""" + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self._result: ClusterDynamicsResult | None = None + self._build_ui() + self.refresh_plot() + + def _build_ui(self) -> None: + root = QVBoxLayout(self) + root.setContentsMargins(0, 0, 0, 0) + root.setSpacing(8) + + controls_widget = QWidget() + controls = QHBoxLayout(controls_widget) + controls.setContentsMargins(0, 0, 0, 0) + controls.setSpacing(8) + + controls.addWidget(QLabel("Heatmap")) + self.display_mode_combo = QComboBox() + for mode, label in DISPLAY_MODE_LABELS.items(): + self.display_mode_combo.addItem(label, mode) + self.display_mode_combo.setCurrentIndex(1) + self.display_mode_combo.currentIndexChanged.connect( + lambda _index: self.refresh_plot() + ) + controls.addWidget(self.display_mode_combo) + + controls.addWidget(QLabel("Units")) + self.time_unit_combo = QComboBox() + self.time_unit_combo.addItem("fs", "fs") + self.time_unit_combo.addItem("ps", "ps") + self.time_unit_combo.currentIndexChanged.connect( + lambda _index: self.refresh_plot() + ) + controls.addWidget(self.time_unit_combo) + + controls.addWidget(QLabel("Colormap")) + self.colormap_combo = QComboBox() + for cmap_name in PLOT_COLORMAPS: + self.colormap_combo.addItem(cmap_name, cmap_name) + self.colormap_combo.currentIndexChanged.connect( + lambda _index: self.refresh_plot() + ) + controls.addWidget(self.colormap_combo) + + controls.addWidget(QLabel("Lower q")) + self.lower_quantile_spin = QDoubleSpinBox() + self.lower_quantile_spin.setDecimals(2) + self.lower_quantile_spin.setRange(0.0, 0.95) + self.lower_quantile_spin.setSingleStep(0.05) + self.lower_quantile_spin.setValue(0.05) + self.lower_quantile_spin.valueChanged.connect( + self._on_quantile_changed + ) + controls.addWidget(self.lower_quantile_spin) + + controls.addWidget(QLabel("Upper q")) + self.upper_quantile_spin = QDoubleSpinBox() + self.upper_quantile_spin.setDecimals(2) + self.upper_quantile_spin.setRange(0.05, 1.0) + self.upper_quantile_spin.setSingleStep(0.05) + self.upper_quantile_spin.setValue(0.95) + self.upper_quantile_spin.valueChanged.connect( + self._on_quantile_changed + ) + controls.addWidget(self.upper_quantile_spin) + + controls.addWidget(QLabel("Overlay")) + self.overlay_combo = QComboBox() + for label, data in OVERLAY_SERIES: + self.overlay_combo.addItem(label, data) + self.overlay_combo.currentIndexChanged.connect( + lambda _index: self.refresh_plot() + ) + controls.addWidget(self.overlay_combo) + controls.addStretch(1) + + root.addWidget(controls_widget) + + self.figure = Figure(figsize=(9.2, 7.2)) + self.canvas = FigureCanvas(self.figure) + root.addWidget(NavigationToolbar(self.canvas, self)) + root.addWidget(self.canvas, stretch=1) + + def set_result(self, result: ClusterDynamicsResult | None) -> None: + self._result = result + has_energy = bool( + result is not None and result.energy_data is not None + ) + self.overlay_combo.setEnabled(has_energy) + if not has_energy: + self.overlay_combo.setCurrentIndex(0) + self.refresh_plot() + + def refresh_plot(self) -> None: + self.figure.clear() + if self._result is None: + axis = self.figure.add_subplot(111) + self._draw_placeholder( + axis, + "Run the analysis to render the cluster-distribution heatmap.", + ) + self.canvas.draw_idle() + return + + if self._result.bin_count == 0: + axis = self.figure.add_subplot(111) + self._draw_placeholder( + axis, + "No time bins are available for the current selection.", + ) + self.canvas.draw_idle() + return + + matrix = self._result.matrix(self._display_mode()) + if matrix.size == 0 or len(self._result.cluster_labels) == 0: + axis = self.figure.add_subplot(111) + self._draw_placeholder( + axis, + "No clusters were detected in the selected time window.", + ) + self.canvas.draw_idle() + return + + overlay_name = self.overlay_combo.currentData() + show_overlay = bool( + overlay_name is not None and self._result.energy_data is not None + ) + + if show_overlay: + grid = self.figure.add_gridspec( + 2, + 1, + height_ratios=[4.0, 1.2], + hspace=0.08, + ) + heatmap_axis = self.figure.add_subplot(grid[0, 0]) + overlay_axis = self.figure.add_subplot( + grid[1, 0], + sharex=heatmap_axis, + ) + else: + heatmap_axis = self.figure.add_subplot(111) + overlay_axis = None + + time_unit = self.time_unit_combo.currentData() + time_edges = self._result.time_edges(time_unit) + cmap = colormaps[self.colormap_combo.currentData()] + norm = self._quantile_norm(matrix) + + image = heatmap_axis.imshow( + matrix, + aspect="auto", + origin="lower", + interpolation="nearest", + extent=( + float(time_edges[0]), + float(time_edges[-1]), + -0.5, + len(self._result.cluster_labels) - 0.5, + ), + cmap=cmap, + norm=norm, + ) + colorbar = self.figure.colorbar(image, ax=heatmap_axis, pad=0.02) + colorbar.set_label(DISPLAY_MODE_COLORBAR_LABELS[self._display_mode()]) + + label_step = max( + 1, + int(math.ceil(len(self._result.cluster_labels) / 24)), + ) + tick_positions = np.arange( + 0, len(self._result.cluster_labels), label_step + ) + heatmap_axis.set_yticks(tick_positions) + heatmap_axis.set_yticklabels( + [self._result.cluster_labels[index] for index in tick_positions] + ) + heatmap_axis.set_ylabel("Cluster label") + heatmap_axis.set_xlim(float(time_edges[0]), float(time_edges[-1])) + heatmap_axis.set_title( + "Time-Binned Cluster Distribution " + f"({DISPLAY_MODE_LABELS[self._display_mode()]})" + ) + if overlay_axis is None: + heatmap_axis.set_xlabel(f"Time ({time_unit})") + else: + heatmap_axis.tick_params(labelbottom=False) + + if overlay_axis is not None and overlay_name is not None: + x_values, y_values, y_label = self._result.energy_series( + overlay_name, + unit=time_unit, + ) + overlay_axis.plot( + x_values, + y_values, + color=OVERLAY_COLORS.get(overlay_name, "#333333"), + linewidth=1.5, + ) + overlay_axis.set_ylabel(y_label) + overlay_axis.set_xlabel(f"Time ({time_unit})") + overlay_axis.grid(alpha=0.25, linestyle=":") + + self.figure.tight_layout() + self.canvas.draw_idle() + + def _display_mode(self) -> str: + value = self.display_mode_combo.currentData() + return "fraction" if value is None else str(value) + + def _on_quantile_changed(self) -> None: + lower = self.lower_quantile_spin.value() + upper = self.upper_quantile_spin.value() + if lower >= upper: + if self.sender() is self.lower_quantile_spin: + self.upper_quantile_spin.blockSignals(True) + self.upper_quantile_spin.setValue(min(lower + 0.05, 1.0)) + self.upper_quantile_spin.blockSignals(False) + else: + self.lower_quantile_spin.blockSignals(True) + self.lower_quantile_spin.setValue(max(upper - 0.05, 0.0)) + self.lower_quantile_spin.blockSignals(False) + self.refresh_plot() + + def _quantile_norm(self, matrix: np.ndarray) -> mcolors.Normalize: + values = np.asarray(matrix, dtype=float) + finite = values[np.isfinite(values)] + if finite.size == 0: + return mcolors.Normalize(vmin=0.0, vmax=1.0) + + positive = finite[finite > 0.0] + if positive.size: + finite = positive + + lower_q = float(self.lower_quantile_spin.value()) + upper_q = float(self.upper_quantile_spin.value()) + vmin = float(np.quantile(finite, lower_q)) + vmax = float(np.quantile(finite, upper_q)) + if vmax <= vmin: + vmin = float(np.min(finite)) + vmax = float(np.max(finite)) + if vmax <= vmin: + vmax = vmin + 1.0 + return mcolors.Normalize(vmin=vmin, vmax=vmax) + + @staticmethod + def _draw_placeholder(axis, message: str) -> None: + axis.text( + 0.5, + 0.5, + message, + ha="center", + va="center", + transform=axis.transAxes, + ) + axis.set_xticks([]) + axis.set_yticks([]) + axis.set_frame_on(False) + + +__all__ = ["ClusterDynamicsPlotPanel"] diff --git a/src/saxshell/clusterdynamics/workflow.py b/src/saxshell/clusterdynamics/workflow.py new file mode 100644 index 0000000..87bb47a --- /dev/null +++ b/src/saxshell/clusterdynamics/workflow.py @@ -0,0 +1,1233 @@ +from __future__ import annotations + +import json +import re +from collections import Counter, defaultdict, deque +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import numpy as np + +from saxshell.cluster import ( + SEARCH_MODE_KDTREE, + ClusterNetwork, + FrameClusterResult, + PairCutoffDefinitions, + XYZClusterNetwork, + XYZStructure, + detect_frame_folder_mode, + normalize_search_mode, +) +from saxshell.cluster.clusternetwork import stoichiometry_label +from saxshell.cluster.workflow import ClusterWorkflow +from saxshell.mdtrajectory.frame.cp2k_ener import CP2KEnergyData +from saxshell.mdtrajectory.workflow import EXPORT_METADATA_FILENAME +from saxshell.structure import AtomTypeDefinitions, PDBStructure + +DisplayMode = Literal["count", "fraction", "mean_count"] +TimeUnit = Literal["fs", "ps"] +EnergySeriesName = Literal["kinetic", "temperature", "potential"] +_FRAME_FILENAME_PATTERN = re.compile( + r"frame_(?P\d+)\.(?:xyz|pdb)$", + re.IGNORECASE, +) +_FOLDER_START_TIME_PATTERN = re.compile( + r"(?:^|_)f(?P\d+(?:_\d+)?)fs(?:\d{4})?$", + re.IGNORECASE, +) + + +@dataclass(slots=True) +class ClusterLifetimeSummary: + """Lifetime summary for one stoichiometry label.""" + + label: str + cluster_size: int + total_observations: int + occupied_frames: int + mean_count_per_frame: float + occupancy_fraction: float + association_events: int + dissociation_events: int + association_rate_per_ps: float + dissociation_rate_per_ps: float + completed_lifetime_count: int + window_truncated_lifetime_count: int + mean_lifetime_fs: float | None + std_lifetime_fs: float | None + + +@dataclass(slots=True) +class ClusterSizeLifetimeSummary: + """Lifetime summary aggregated by cluster size.""" + + cluster_size: int + total_observations: int + occupied_frames: int + mean_count_per_frame: float + occupancy_fraction: float + association_events: int + dissociation_events: int + association_rate_per_ps: float + dissociation_rate_per_ps: float + completed_lifetime_count: int + window_truncated_lifetime_count: int + mean_lifetime_fs: float | None + std_lifetime_fs: float | None + + +@dataclass(slots=True) +class ClusterDynamicsSelectionPreview: + """Preview of the extracted frames selected for time-bin + analysis.""" + + summary: dict[str, object] + frame_format: str + resolved_box_dimensions: tuple[float, float, float] | None + use_pbc: bool + first_frame_time_fs: float + frame_timestep_fs: float + frames_per_colormap_timestep: int | None + colormap_timestep_fs: float + analysis_start_fs: float + analysis_stop_fs: float + first_selected_time_fs: float | None + last_selected_time_fs: float | None + selected_frame_indices: tuple[int, ...] + selected_frame_names: tuple[str, ...] + selected_source_frame_indices: tuple[int | None, ...] + energy_file: Path | None + folder_start_time_fs: float | None = None + folder_start_time_source: str | None = None + time_source_label: str = "Sequential frame order" + time_warnings: tuple[str, ...] = () + + @property + def total_frames(self) -> int: + return int(self.summary["n_frames"]) + + @property + def selected_frames(self) -> int: + return len(self.selected_frame_indices) + + @property + def first_selected_frame(self) -> str | None: + return ( + self.selected_frame_names[0] if self.selected_frame_names else None + ) + + @property + def last_selected_frame(self) -> str | None: + return ( + self.selected_frame_names[-1] + if self.selected_frame_names + else None + ) + + @property + def first_selected_source_frame_index(self) -> int | None: + return ( + self.selected_source_frame_indices[0] + if self.selected_source_frame_indices + else None + ) + + @property + def last_selected_source_frame_index(self) -> int | None: + return ( + self.selected_source_frame_indices[-1] + if self.selected_source_frame_indices + else None + ) + + @property + def bin_count(self) -> int: + return max( + len( + _build_bin_edges( + self.analysis_start_fs, + self.analysis_stop_fs, + self.colormap_timestep_fs, + ) + ) + - 1, + 0, + ) + + @property + def bin_size_fs(self) -> float: + return self.colormap_timestep_fs + + def to_dict(self) -> dict[str, object]: + return { + "frame_format": self.frame_format, + "total_frames": self.total_frames, + "selected_frames": self.selected_frames, + "resolved_box_dimensions": self.resolved_box_dimensions, + "use_pbc": self.use_pbc, + "first_frame_time_fs": self.first_frame_time_fs, + "frame_timestep_fs": self.frame_timestep_fs, + "frames_per_colormap_timestep": self.frames_per_colormap_timestep, + "colormap_timestep_fs": self.colormap_timestep_fs, + "bin_size_fs": self.colormap_timestep_fs, + "analysis_start_fs": self.analysis_start_fs, + "analysis_stop_fs": self.analysis_stop_fs, + "folder_start_time_fs": self.folder_start_time_fs, + "folder_start_time_source": self.folder_start_time_source, + "time_source_label": self.time_source_label, + "time_warnings": list(self.time_warnings), + "first_selected_time_fs": self.first_selected_time_fs, + "last_selected_time_fs": self.last_selected_time_fs, + "first_selected_frame": self.first_selected_frame, + "last_selected_frame": self.last_selected_frame, + "first_selected_source_frame_index": ( + self.first_selected_source_frame_index + ), + "last_selected_source_frame_index": ( + self.last_selected_source_frame_index + ), + "bin_count": self.bin_count, + "energy_file": ( + None if self.energy_file is None else str(self.energy_file) + ), + } + + +@dataclass(slots=True) +class _SeriesLifetimeMetrics: + total_observations: int + occupied_frames: int + mean_count_per_frame: float + occupancy_fraction: float + association_events: int + dissociation_events: int + association_rate_per_ps: float + dissociation_rate_per_ps: float + completed_lifetimes_fs: tuple[float, ...] + window_truncated_lifetimes_fs: tuple[float, ...] + mean_lifetime_fs: float | None + std_lifetime_fs: float | None + + +@dataclass(slots=True) +class _FrameTimeAxis: + frame_times_fs: np.ndarray + source_frame_indices: tuple[int | None, ...] + time_source_label: str + folder_start_time_fs: float | None + folder_start_time_source: str | None + warnings: tuple[str, ...] + + +@dataclass(slots=True) +class ClusterDynamicsResult: + """Computed time-binned cluster-distribution analysis.""" + + preview: ClusterDynamicsSelectionPreview + frame_results: tuple[FrameClusterResult, ...] + bin_edges_fs: np.ndarray + frames_per_bin: np.ndarray + total_clusters_per_bin: np.ndarray + cluster_labels: tuple[str, ...] + cluster_sizes: dict[str, int] + raw_count_matrix: np.ndarray + fraction_matrix: np.ndarray + mean_count_matrix: np.ndarray + frame_count_matrix: np.ndarray + total_clusters_per_frame: np.ndarray + lifetime_by_label: tuple[ClusterLifetimeSummary, ...] + lifetime_by_size: tuple[ClusterSizeLifetimeSummary, ...] + energy_data: CP2KEnergyData | None = None + + @property + def analyzed_frames(self) -> int: + return len(self.frame_results) + + @property + def bin_count(self) -> int: + return int(len(self.bin_edges_fs) - 1) + + @property + def bin_centers_fs(self) -> np.ndarray: + return (self.bin_edges_fs[:-1] + self.bin_edges_fs[1:]) / 2.0 + + @property + def frame_times_fs(self) -> np.ndarray: + if not self.frame_results: + return np.zeros(0, dtype=float) + return np.asarray( + [float(frame.time_fs or 0.0) for frame in self.frame_results], + dtype=float, + ) + + def matrix(self, mode: DisplayMode = "fraction") -> np.ndarray: + if mode == "count": + return np.asarray(self.raw_count_matrix, dtype=float) + if mode == "mean_count": + return np.asarray(self.mean_count_matrix, dtype=float) + return np.asarray(self.fraction_matrix, dtype=float) + + def time_edges(self, unit: TimeUnit = "fs") -> np.ndarray: + return self.bin_edges_fs / _time_scale(unit) + + def bin_centers(self, unit: TimeUnit = "fs") -> np.ndarray: + return self.bin_centers_fs / _time_scale(unit) + + def energy_series( + self, + series_name: EnergySeriesName, + *, + unit: TimeUnit = "fs", + ) -> tuple[np.ndarray, np.ndarray, str]: + if self.energy_data is None: + raise ValueError("No CP2K .ener file was loaded for this result.") + + time_fs = np.asarray(self.energy_data.time_fs, dtype=float) + mask = (time_fs >= float(self.bin_edges_fs[0])) & ( + time_fs <= float(self.bin_edges_fs[-1]) + ) + label_map = { + "kinetic": "Kinetic Energy", + "temperature": "Temperature (K)", + "potential": "Potential Energy", + } + value = np.asarray(getattr(self.energy_data, series_name), dtype=float) + return ( + time_fs[mask] / _time_scale(unit), + value[mask], + label_map[series_name], + ) + + +class ClusterDynamicsWorkflow: + """Headless workflow for time-binned cluster-distribution + analysis.""" + + def __init__( + self, + frames_dir: str | Path, + *, + atom_type_definitions: AtomTypeDefinitions, + pair_cutoff_definitions: PairCutoffDefinitions, + box_dimensions: tuple[float, float, float] | None = None, + use_pbc: bool = False, + default_cutoff: float | None = None, + shell_levels: tuple[int, ...] = (), + shared_shells: bool = False, + include_shell_atoms_in_stoichiometry: bool = False, + search_mode: str = SEARCH_MODE_KDTREE, + folder_start_time_fs: float | None = None, + first_frame_time_fs: float = 0.0, + frame_timestep_fs: float = 0.5, + frames_per_colormap_timestep: int | None = None, + colormap_timestep_fs: float | None = None, + bin_size_fs: float | None = None, + analysis_start_fs: float | None = None, + analysis_stop_fs: float | None = None, + energy_file: str | Path | None = None, + ) -> None: + self.frames_dir = Path(frames_dir) + self.atom_type_definitions = atom_type_definitions + self.pair_cutoff_definitions = pair_cutoff_definitions + self.box_dimensions = box_dimensions + self.use_pbc = bool(use_pbc) + self.default_cutoff = default_cutoff + self.shell_levels = shell_levels + self.shared_shells = bool(shared_shells) + self.include_shell_atoms_in_stoichiometry = bool( + include_shell_atoms_in_stoichiometry + ) + self.search_mode = normalize_search_mode(search_mode) + self.folder_start_time_fs = ( + None + if folder_start_time_fs is None + else float(folder_start_time_fs) + ) + self.fallback_first_frame_time_fs = float(first_frame_time_fs) + self.frame_timestep_fs = _validate_positive_number( + frame_timestep_fs, + label="Frame timestep", + ) + ( + self.frames_per_colormap_timestep, + self.colormap_timestep_fs, + ) = _resolve_colormap_timestep_settings( + frame_timestep_fs=self.frame_timestep_fs, + frames_per_colormap_timestep=frames_per_colormap_timestep, + colormap_timestep_fs=colormap_timestep_fs, + legacy_bin_size_fs=bin_size_fs, + require_integral_ratio=True, + ) + self.bin_size_fs = self.colormap_timestep_fs + self.analysis_start_fs = ( + None if analysis_start_fs is None else float(analysis_start_fs) + ) + self.analysis_stop_fs = ( + None if analysis_stop_fs is None else float(analysis_stop_fs) + ) + self.energy_file = None if energy_file is None else Path(energy_file) + self._cluster_workflow = ClusterWorkflow( + frames_dir=self.frames_dir, + atom_type_definitions=self.atom_type_definitions, + pair_cutoff_definitions=self.pair_cutoff_definitions, + box_dimensions=self.box_dimensions, + use_pbc=self.use_pbc, + default_cutoff=self.default_cutoff, + shell_levels=self.shell_levels, + shared_shells=self.shared_shells, + include_shell_atoms_in_stoichiometry=( + self.include_shell_atoms_in_stoichiometry + ), + search_mode=self.search_mode, + ) + self._cached_preview: ClusterDynamicsSelectionPreview | None = None + self._cached_energy: CP2KEnergyData | None = None + + def inspect(self) -> dict[str, object]: + return self._cluster_workflow.inspect() + + def load_energy(self) -> CP2KEnergyData: + if self.energy_file is None: + raise ValueError("No CP2K .ener file was provided.") + if self._cached_energy is None: + self._cached_energy = CP2KEnergyData.from_file(self.energy_file) + return self._cached_energy + + def preview_selection(self) -> ClusterDynamicsSelectionPreview: + if self._cached_preview is not None: + return self._cached_preview + + summary = self.inspect() + frame_format, frame_paths = detect_frame_folder_mode(self.frames_dir) + time_axis = _infer_frame_time_axis( + self.frames_dir, + frame_paths, + frame_timestep_fs=self.frame_timestep_fs, + folder_start_time_fs=self.folder_start_time_fs, + fallback_first_frame_time_fs=self.fallback_first_frame_time_fs, + ) + frame_times_fs = time_axis.frame_times_fs + resolved_box = self._cluster_workflow.resolve_box_dimensions( + box_dimensions=self.box_dimensions, + use_pbc=self.use_pbc, + ) + ( + analysis_start_fs, + analysis_stop_fs, + selected_indices, + ) = _resolve_time_window( + frame_times_fs=frame_times_fs, + analysis_start_fs=self.analysis_start_fs, + analysis_stop_fs=self.analysis_stop_fs, + frame_timestep_fs=self.frame_timestep_fs, + ) + selected_names = tuple( + frame_paths[index].name for index in selected_indices + ) + selected_source_frame_indices = tuple( + time_axis.source_frame_indices[index] for index in selected_indices + ) + first_selected_time_fs = ( + None + if not selected_indices + else float(frame_times_fs[selected_indices[0]]) + ) + last_selected_time_fs = ( + None + if not selected_indices + else float(frame_times_fs[selected_indices[-1]]) + ) + self._cached_preview = ClusterDynamicsSelectionPreview( + summary=summary, + frame_format=str(frame_format), + resolved_box_dimensions=resolved_box, + use_pbc=self.use_pbc, + first_frame_time_fs=( + float(frame_times_fs[0]) + if frame_times_fs.size + else ( + time_axis.folder_start_time_fs + if time_axis.folder_start_time_fs is not None + else self.fallback_first_frame_time_fs + ) + ), + frame_timestep_fs=self.frame_timestep_fs, + frames_per_colormap_timestep=self.frames_per_colormap_timestep, + colormap_timestep_fs=self.colormap_timestep_fs, + analysis_start_fs=analysis_start_fs, + analysis_stop_fs=analysis_stop_fs, + first_selected_time_fs=first_selected_time_fs, + last_selected_time_fs=last_selected_time_fs, + selected_frame_indices=selected_indices, + selected_frame_names=selected_names, + selected_source_frame_indices=selected_source_frame_indices, + energy_file=self.energy_file, + folder_start_time_fs=time_axis.folder_start_time_fs, + folder_start_time_source=time_axis.folder_start_time_source, + time_source_label=time_axis.time_source_label, + time_warnings=time_axis.warnings, + ) + return self._cached_preview + + def analyze( + self, + *, + progress_callback: callable | None = None, + ) -> ClusterDynamicsResult: + preview = self.preview_selection() + if preview.selected_frames == 0: + raise ValueError( + "No extracted frames fall within the selected time window." + ) + + frame_format, frame_paths = detect_frame_folder_mode(self.frames_dir) + selected_paths = [ + frame_paths[index] for index in preview.selected_frame_indices + ] + frame_times_fs = _infer_frame_time_axis( + self.frames_dir, + frame_paths, + frame_timestep_fs=preview.frame_timestep_fs, + folder_start_time_fs=preview.folder_start_time_fs, + fallback_first_frame_time_fs=self.fallback_first_frame_time_fs, + ).frame_times_fs + selected_times = frame_times_fs[list(preview.selected_frame_indices)] + + frame_results: list[FrameClusterResult] = [] + per_frame_counts: list[Counter[str]] = [] + label_sizes: dict[str, int] = {} + total_frames = len(selected_paths) + + for processed, (frame_index, frame_path, time_fs) in enumerate( + zip( + preview.selected_frame_indices, + selected_paths, + selected_times, + strict=False, + ), + start=1, + ): + network = self._build_network( + frame_format=str(frame_format), + frame_path=frame_path, + resolved_box_dimensions=preview.resolved_box_dimensions, + ) + clusters = network.find_clusters( + shell_levels=self.shell_levels, + shared_shells=self.shared_shells, + ) + frame_results.append( + FrameClusterResult( + frame_index=int(frame_index), + time_fs=float(time_fs), + clusters=clusters, + ) + ) + counts = Counter[str]() + for cluster in clusters: + label = stoichiometry_label(cluster.stoichiometry) + counts[label] += 1 + label_sizes[label] = max( + label_sizes.get(label, 0), + sum( + int(value) for value in cluster.stoichiometry.values() + ), + ) + per_frame_counts.append(counts) + if progress_callback is not None: + progress_callback(processed, total_frames, frame_path.name) + + cluster_labels = tuple( + sorted(label_sizes, key=lambda label: (label_sizes[label], label)) + ) + label_index = { + label: index for index, label in enumerate(cluster_labels) + } + bin_edges_fs = _build_bin_edges( + preview.analysis_start_fs, + preview.analysis_stop_fs, + preview.colormap_timestep_fs, + ) + raw_count_matrix = np.zeros( + (len(cluster_labels), len(bin_edges_fs) - 1), + dtype=float, + ) + frame_count_matrix = np.zeros( + (len(cluster_labels), len(frame_results)), + dtype=float, + ) + frames_per_bin = np.zeros(len(bin_edges_fs) - 1, dtype=float) + total_clusters_per_bin = np.zeros(len(bin_edges_fs) - 1, dtype=float) + total_clusters_per_frame = np.zeros(len(frame_results), dtype=float) + + if frame_results: + selected_frame_times_fs = np.asarray( + [float(frame.time_fs or 0.0) for frame in frame_results], + dtype=float, + ) + bin_indices = _assign_bins(selected_frame_times_fs, bin_edges_fs) + else: + selected_frame_times_fs = np.zeros(0, dtype=float) + bin_indices = np.zeros(0, dtype=int) + + for frame_position, counts in enumerate(per_frame_counts): + bin_index = int(bin_indices[frame_position]) + frames_per_bin[bin_index] += 1.0 + total_clusters = float(sum(counts.values())) + total_clusters_per_bin[bin_index] += total_clusters + total_clusters_per_frame[frame_position] = total_clusters + for label, count in counts.items(): + row = label_index[label] + raw_count_matrix[row, bin_index] += float(count) + frame_count_matrix[row, frame_position] = float(count) + + fraction_matrix = _safe_divide( + raw_count_matrix, + total_clusters_per_bin[np.newaxis, :], + ) + mean_count_matrix = _safe_divide( + raw_count_matrix, + frames_per_bin[np.newaxis, :], + ) + lifetime_by_label = tuple( + self._summarize_label_series( + label=label, + cluster_size=label_sizes[label], + count_series=frame_count_matrix[label_index[label], :], + frame_times_fs=selected_frame_times_fs, + observation_start_fs=preview.analysis_start_fs, + observation_stop_fs=preview.analysis_stop_fs, + ) + for label in cluster_labels + ) + + size_series: dict[int, np.ndarray] = defaultdict( + lambda: np.zeros(len(frame_results), dtype=float) + ) + for label, size in label_sizes.items(): + size_series[size] += frame_count_matrix[label_index[label], :] + lifetime_by_size = tuple( + self._summarize_size_series( + cluster_size=size, + count_series=size_series[size], + frame_times_fs=selected_frame_times_fs, + observation_start_fs=preview.analysis_start_fs, + observation_stop_fs=preview.analysis_stop_fs, + ) + for size in sorted(size_series) + ) + + energy_data = ( + self.load_energy() if self.energy_file is not None else None + ) + return ClusterDynamicsResult( + preview=preview, + frame_results=tuple(frame_results), + bin_edges_fs=bin_edges_fs, + frames_per_bin=frames_per_bin, + total_clusters_per_bin=total_clusters_per_bin, + cluster_labels=cluster_labels, + cluster_sizes=dict(label_sizes), + raw_count_matrix=raw_count_matrix, + fraction_matrix=fraction_matrix, + mean_count_matrix=mean_count_matrix, + frame_count_matrix=frame_count_matrix, + total_clusters_per_frame=total_clusters_per_frame, + lifetime_by_label=lifetime_by_label, + lifetime_by_size=lifetime_by_size, + energy_data=energy_data, + ) + + def _build_network( + self, + *, + frame_format: str, + frame_path: Path, + resolved_box_dimensions: tuple[float, float, float] | None, + ) -> ClusterNetwork | XYZClusterNetwork: + if frame_format == "pdb": + structure = PDBStructure( + filepath=frame_path, + atom_type_definitions=self.atom_type_definitions, + source_name=frame_path.stem, + ) + return ClusterNetwork( + pdb_structure=structure, + atom_type_definitions=self.atom_type_definitions, + pair_cutoffs_def=self.pair_cutoff_definitions, + box_dimensions=resolved_box_dimensions, + default_cutoff=self.default_cutoff, + use_pbc=self.use_pbc, + include_shell_atoms_in_stoichiometry=( + self.include_shell_atoms_in_stoichiometry + ), + search_mode=self.search_mode, + ) + + structure = XYZStructure( + filepath=frame_path, + atom_type_definitions=self.atom_type_definitions, + source_name=frame_path.stem, + ) + return XYZClusterNetwork( + xyz_structure=structure, + atom_type_definitions=self.atom_type_definitions, + pair_cutoffs_def=self.pair_cutoff_definitions, + box_dimensions=resolved_box_dimensions, + default_cutoff=self.default_cutoff, + use_pbc=self.use_pbc, + include_shell_atoms_in_stoichiometry=( + self.include_shell_atoms_in_stoichiometry + ), + search_mode=self.search_mode, + ) + + def _summarize_label_series( + self, + *, + label: str, + cluster_size: int, + count_series: np.ndarray, + frame_times_fs: np.ndarray, + observation_start_fs: float, + observation_stop_fs: float, + ) -> ClusterLifetimeSummary: + metrics = _summarize_series_lifetimes( + count_series, + frame_times_fs=frame_times_fs, + observation_start_fs=observation_start_fs, + observation_stop_fs=observation_stop_fs, + ) + return ClusterLifetimeSummary( + label=label, + cluster_size=cluster_size, + total_observations=metrics.total_observations, + occupied_frames=metrics.occupied_frames, + mean_count_per_frame=metrics.mean_count_per_frame, + occupancy_fraction=metrics.occupancy_fraction, + association_events=metrics.association_events, + dissociation_events=metrics.dissociation_events, + association_rate_per_ps=metrics.association_rate_per_ps, + dissociation_rate_per_ps=metrics.dissociation_rate_per_ps, + completed_lifetime_count=len(metrics.completed_lifetimes_fs), + window_truncated_lifetime_count=len( + metrics.window_truncated_lifetimes_fs + ), + mean_lifetime_fs=metrics.mean_lifetime_fs, + std_lifetime_fs=metrics.std_lifetime_fs, + ) + + def _summarize_size_series( + self, + *, + cluster_size: int, + count_series: np.ndarray, + frame_times_fs: np.ndarray, + observation_start_fs: float, + observation_stop_fs: float, + ) -> ClusterSizeLifetimeSummary: + metrics = _summarize_series_lifetimes( + count_series, + frame_times_fs=frame_times_fs, + observation_start_fs=observation_start_fs, + observation_stop_fs=observation_stop_fs, + ) + return ClusterSizeLifetimeSummary( + cluster_size=cluster_size, + total_observations=metrics.total_observations, + occupied_frames=metrics.occupied_frames, + mean_count_per_frame=metrics.mean_count_per_frame, + occupancy_fraction=metrics.occupancy_fraction, + association_events=metrics.association_events, + dissociation_events=metrics.dissociation_events, + association_rate_per_ps=metrics.association_rate_per_ps, + dissociation_rate_per_ps=metrics.dissociation_rate_per_ps, + completed_lifetime_count=len(metrics.completed_lifetimes_fs), + window_truncated_lifetime_count=len( + metrics.window_truncated_lifetimes_fs + ), + mean_lifetime_fs=metrics.mean_lifetime_fs, + std_lifetime_fs=metrics.std_lifetime_fs, + ) + + +def _validate_positive_number(value: float, *, label: str) -> float: + normalized = float(value) + if normalized <= 0.0: + raise ValueError(f"{label} must be greater than zero.") + return normalized + + +def _validate_positive_integer(value: int | float, *, label: str) -> int: + if isinstance(value, bool): + raise ValueError(f"{label} must be a whole number greater than zero.") + normalized = float(value) + if normalized <= 0.0 or not normalized.is_integer(): + raise ValueError(f"{label} must be a whole number greater than zero.") + return int(normalized) + + +def _resolve_colormap_timestep_settings( + *, + frame_timestep_fs: float, + frames_per_colormap_timestep: int | float | None, + colormap_timestep_fs: float | None, + legacy_bin_size_fs: float | None, + require_integral_ratio: bool, +) -> tuple[int | None, float]: + resolved_colormap_timestep_fs = None + if colormap_timestep_fs is not None: + resolved_colormap_timestep_fs = _validate_positive_number( + colormap_timestep_fs, + label="Colormap timestep", + ) + elif legacy_bin_size_fs is not None: + resolved_colormap_timestep_fs = _validate_positive_number( + legacy_bin_size_fs, + label="Colormap timestep", + ) + + if frames_per_colormap_timestep is not None: + resolved_frames = _validate_positive_integer( + frames_per_colormap_timestep, + label="Frames per colormap timestep", + ) + expected_colormap_timestep_fs = float(frame_timestep_fs) * float( + resolved_frames + ) + if resolved_colormap_timestep_fs is None: + return resolved_frames, expected_colormap_timestep_fs + tolerance_fs = ( + max( + abs(expected_colormap_timestep_fs), + abs(resolved_colormap_timestep_fs), + 1.0, + ) + * 1.0e-9 + ) + if ( + abs(expected_colormap_timestep_fs - resolved_colormap_timestep_fs) + > tolerance_fs + ): + raise ValueError( + "Frames per colormap timestep does not match the resolved " + "colormap timestep." + ) + return resolved_frames, expected_colormap_timestep_fs + + if resolved_colormap_timestep_fs is None: + return 1, float(frame_timestep_fs) + + ratio = resolved_colormap_timestep_fs / float(frame_timestep_fs) + nearest_ratio = int(round(ratio)) + expected_colormap_timestep_fs = float(frame_timestep_fs) * float( + max(nearest_ratio, 1) + ) + tolerance_fs = ( + max( + abs(expected_colormap_timestep_fs), + abs(resolved_colormap_timestep_fs), + 1.0, + ) + * 1.0e-9 + ) + if ( + abs(expected_colormap_timestep_fs - resolved_colormap_timestep_fs) + <= tolerance_fs + ): + return max(nearest_ratio, 1), expected_colormap_timestep_fs + if require_integral_ratio: + raise ValueError( + "Colormap timestep must be an integer multiple of the frame " + "timestep." + ) + return None, resolved_colormap_timestep_fs + + +def _time_scale(unit: TimeUnit) -> float: + if unit == "ps": + return 1000.0 + return 1.0 + + +def _frame_times( + total_frames: int, + *, + first_frame_time_fs: float, + frame_timestep_fs: float, +) -> np.ndarray: + return first_frame_time_fs + ( + np.arange(max(int(total_frames), 0), dtype=float) * frame_timestep_fs + ) + + +def _infer_frame_time_axis( + frames_dir: Path, + frame_paths: tuple[Path, ...] | list[Path], + *, + frame_timestep_fs: float, + folder_start_time_fs: float | None, + fallback_first_frame_time_fs: float, +) -> _FrameTimeAxis: + metadata_payload = _load_mdtrajectory_export_metadata(frames_dir) + resolved_folder_start_time_fs, folder_start_source = ( + _resolve_folder_start_time( + frames_dir=frames_dir, + explicit_folder_start_time_fs=folder_start_time_fs, + metadata_payload=metadata_payload, + ) + ) + warnings: list[str] = [] + + metadata_time_axis = _frame_times_from_mdtrajectory_metadata( + metadata_payload, + frame_paths, + ) + if metadata_time_axis is not None: + frame_times_fs, source_frame_indices = metadata_time_axis + time_source_label = "mdtrajectory export metadata" + else: + if metadata_payload is not None: + warnings.append( + "Found mdtrajectory export metadata, but it did not provide " + "usable times for every extracted frame. Falling back to the " + "frame filenames and timestep." + ) + source_frame_indices = tuple( + _parse_frame_filename_index(path.name) for path in frame_paths + ) + if source_frame_indices and all( + index is not None for index in source_frame_indices + ): + frame_times_fs = np.asarray( + [ + float(index) * float(frame_timestep_fs) + for index in source_frame_indices + ], + dtype=float, + ) + time_source_label = "Frame filenames x timestep" + else: + sequential_start_time_fs = ( + resolved_folder_start_time_fs + if resolved_folder_start_time_fs is not None + else float(fallback_first_frame_time_fs) + ) + frame_times_fs = _frame_times( + len(frame_paths), + first_frame_time_fs=sequential_start_time_fs, + frame_timestep_fs=frame_timestep_fs, + ) + time_source_label = "Sequential frames from start time" + if resolved_folder_start_time_fs is None: + warnings.append( + "Start/cutoff time metadata was not found in the folder " + "name or mdtrajectory export metadata. Sequential frame " + "times are being generated from the fallback start time." + ) + else: + warnings.append( + "Using the folder/start time as the first extracted " + "frame time because the frame filenames do not expose " + "their original source-frame indices." + ) + + if frame_times_fs.size and resolved_folder_start_time_fs is not None: + first_resolved_time_fs = float(frame_times_fs[0]) + mismatch_tolerance_fs = max(float(frame_timestep_fs) * 0.5, 1.0e-9) + if ( + abs(first_resolved_time_fs - resolved_folder_start_time_fs) + > mismatch_tolerance_fs + ): + warnings.append( + "Folder cutoff/start time is " + f"{resolved_folder_start_time_fs:.3f} fs, but the first " + "resolved extracted-frame time is " + f"{first_resolved_time_fs:.3f} fs. The analysis uses the " + "resolved frame times for plotting and kinetics." + ) + + return _FrameTimeAxis( + frame_times_fs=frame_times_fs, + source_frame_indices=source_frame_indices, + time_source_label=time_source_label, + folder_start_time_fs=resolved_folder_start_time_fs, + folder_start_time_source=folder_start_source, + warnings=tuple(warnings), + ) + + +def _resolve_folder_start_time( + *, + frames_dir: Path, + explicit_folder_start_time_fs: float | None, + metadata_payload: dict[str, object] | None, +) -> tuple[float | None, str | None]: + if explicit_folder_start_time_fs is not None: + return float(explicit_folder_start_time_fs), "manual field" + + if metadata_payload is not None: + selection_payload = metadata_payload.get("selection") + if isinstance(selection_payload, dict): + applied_cutoff_fs = selection_payload.get("applied_cutoff_fs") + if applied_cutoff_fs is not None: + return float(applied_cutoff_fs), "mdtrajectory export metadata" + + folder_start_time_fs = _parse_folder_start_time_from_name(frames_dir.name) + if folder_start_time_fs is not None: + return folder_start_time_fs, "folder name" + return None, None + + +def _load_mdtrajectory_export_metadata( + frames_dir: Path, +) -> dict[str, object] | None: + metadata_path = frames_dir / EXPORT_METADATA_FILENAME + if not metadata_path.is_file(): + return None + try: + payload = json.loads(metadata_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError, UnicodeDecodeError): + return None + return payload if isinstance(payload, dict) else None + + +def _frame_times_from_mdtrajectory_metadata( + metadata_payload: dict[str, object] | None, + frame_paths: tuple[Path, ...] | list[Path], +) -> tuple[np.ndarray, tuple[int | None, ...]] | None: + if metadata_payload is None: + return None + + written_frames = metadata_payload.get("written_frames") + if not isinstance(written_frames, list): + return None + + mapping: dict[str, tuple[float, int | None]] = {} + for entry in written_frames: + if not isinstance(entry, dict): + continue + filename = entry.get("filename") + time_fs = entry.get("time_fs") + if not isinstance(filename, str) or time_fs is None: + continue + frame_index = entry.get("frame_index") + mapping[filename] = ( + float(time_fs), + None if frame_index is None else int(frame_index), + ) + + if not mapping: + return None + + resolved_times: list[float] = [] + resolved_indices: list[int | None] = [] + for frame_path in frame_paths: + payload = mapping.get(frame_path.name) + if payload is None: + return None + time_fs, frame_index = payload + resolved_times.append(float(time_fs)) + resolved_indices.append(frame_index) + + return np.asarray(resolved_times, dtype=float), tuple(resolved_indices) + + +def _parse_frame_filename_index(filename: str) -> int | None: + match = _FRAME_FILENAME_PATTERN.match(filename.strip()) + if match is None: + return None + return int(match.group("index")) + + +def _parse_folder_start_time_from_name(folder_name: str) -> float | None: + match = _FOLDER_START_TIME_PATTERN.search(folder_name.strip()) + if match is None: + return None + return float(match.group("value").replace("_", ".")) + + +def _resolve_time_window( + *, + frame_times_fs: np.ndarray, + analysis_start_fs: float | None, + analysis_stop_fs: float | None, + frame_timestep_fs: float, +) -> tuple[float, float, tuple[int, ...]]: + if frame_times_fs.size == 0: + start = 0.0 if analysis_start_fs is None else float(analysis_start_fs) + stop = start + frame_timestep_fs + return start, stop, () + + requested_start = ( + float(frame_times_fs[0]) + if analysis_start_fs is None + else float(analysis_start_fs) + ) + requested_stop = ( + float(frame_times_fs[-1]) + float(frame_timestep_fs) + if analysis_stop_fs is None + else float(analysis_stop_fs) + ) + if requested_stop <= requested_start: + requested_stop = requested_start + float(frame_timestep_fs) + + selected_indices = tuple( + int(index) + for index, time_fs in enumerate(frame_times_fs) + if float(time_fs) >= requested_start + and ( + float(time_fs) < requested_stop + if analysis_stop_fs is None + else float(time_fs) <= requested_stop + ) + ) + return requested_start, requested_stop, selected_indices + + +def _build_bin_edges( + analysis_start_fs: float, + analysis_stop_fs: float, + colormap_timestep_fs: float, +) -> np.ndarray: + start = float(analysis_start_fs) + stop = float(analysis_stop_fs) + if stop <= start: + stop = start + float(colormap_timestep_fs) + n_bins = max( + int(np.ceil((stop - start) / float(colormap_timestep_fs))), + 1, + ) + edges = start + ( + np.arange(n_bins + 1, dtype=float) * float(colormap_timestep_fs) + ) + if edges[-1] < stop: + edges = np.append(edges, stop) + else: + edges[-1] = max(edges[-1], stop) + return edges + + +def _assign_bins( + frame_times_fs: np.ndarray, bin_edges_fs: np.ndarray +) -> np.ndarray: + bin_indices = ( + np.searchsorted(bin_edges_fs, frame_times_fs, side="right") - 1 + ) + return np.clip(bin_indices, 0, len(bin_edges_fs) - 2) + + +def _safe_divide(numerator: np.ndarray, denominator: np.ndarray) -> np.ndarray: + return np.divide( + numerator, + denominator, + out=np.zeros_like(numerator, dtype=float), + where=denominator > 0.0, + ) + + +def _summarize_series_lifetimes( + count_series: np.ndarray, + *, + frame_times_fs: np.ndarray, + observation_start_fs: float, + observation_stop_fs: float, +) -> _SeriesLifetimeMetrics: + normalized = np.asarray(count_series, dtype=int) + resolved_frame_times_fs = np.asarray(frame_times_fs, dtype=float) + total_frames = len(normalized) + if total_frames != len(resolved_frame_times_fs): + raise ValueError( + "Count-series length does not match the resolved frame-time axis." + ) + total_observations = int(normalized.sum()) + occupied_frames = int(np.count_nonzero(normalized > 0)) + mean_count_per_frame = float(normalized.mean()) if total_frames else 0.0 + occupancy_fraction = ( + float(occupied_frames / total_frames) if total_frames else 0.0 + ) + association_events = 0 + dissociation_events = 0 + active_instances: deque[tuple[float, bool]] = deque() + completed_lifetimes_fs: list[float] = [] + window_truncated_lifetimes_fs: list[float] = [] + + if total_frames: + for _ in range(int(normalized[0])): + active_instances.append((float(resolved_frame_times_fs[0]), True)) + + for frame_index in range(1, total_frames): + previous_count = int(normalized[frame_index - 1]) + current_count = int(normalized[frame_index]) + current_time_fs = float(resolved_frame_times_fs[frame_index]) + if current_count > previous_count: + births = current_count - previous_count + association_events += births + for _ in range(births): + active_instances.append((current_time_fs, False)) + elif current_count < previous_count: + deaths = previous_count - current_count + dissociation_events += deaths + for _ in range(deaths): + if not active_instances: + break + start_time_fs, left_censored = active_instances.popleft() + lifetime_fs = max(current_time_fs - start_time_fs, 0.0) + if left_censored: + window_truncated_lifetimes_fs.append(lifetime_fs) + else: + completed_lifetimes_fs.append(lifetime_fs) + + if total_frames: + observation_window_stop_fs = max( + float(observation_stop_fs), + float(resolved_frame_times_fs[-1]), + ) + else: + observation_window_stop_fs = float(observation_stop_fs) + + for start_time_fs, _left_censored in active_instances: + lifetime_fs = max(observation_window_stop_fs - start_time_fs, 0.0) + window_truncated_lifetimes_fs.append(lifetime_fs) + + completed_array = np.asarray(completed_lifetimes_fs, dtype=float) + if completed_array.size: + mean_lifetime_fs = float(completed_array.mean()) + std_lifetime_fs = float(completed_array.std(ddof=0)) + else: + mean_lifetime_fs = None + std_lifetime_fs = None + + observation_time_ps = max( + (float(observation_window_stop_fs) - float(observation_start_fs)) + / 1000.0, + 1e-12, + ) + return _SeriesLifetimeMetrics( + total_observations=total_observations, + occupied_frames=occupied_frames, + mean_count_per_frame=mean_count_per_frame, + occupancy_fraction=occupancy_fraction, + association_events=association_events, + dissociation_events=dissociation_events, + association_rate_per_ps=float( + association_events / observation_time_ps + ), + dissociation_rate_per_ps=float( + dissociation_events / observation_time_ps + ), + completed_lifetimes_fs=tuple(completed_lifetimes_fs), + window_truncated_lifetimes_fs=tuple(window_truncated_lifetimes_fs), + mean_lifetime_fs=mean_lifetime_fs, + std_lifetime_fs=std_lifetime_fs, + ) + + +__all__ = [ + "ClusterDynamicsResult", + "ClusterDynamicsSelectionPreview", + "ClusterDynamicsWorkflow", + "ClusterLifetimeSummary", + "ClusterSizeLifetimeSummary", +] diff --git a/src/saxshell/clusterdynamicsml/__init__.py b/src/saxshell/clusterdynamicsml/__init__.py new file mode 100644 index 0000000..e95ca62 --- /dev/null +++ b/src/saxshell/clusterdynamicsml/__init__.py @@ -0,0 +1,33 @@ +"""Experimental larger-cluster surrogate prediction tools.""" + +from .dataset import ( + LoadedClusterDynamicsMLDataset, + SavedClusterDynamicsMLDataset, + load_cluster_dynamicsai_dataset, + save_cluster_dynamicsai_dataset, +) +from .workflow import ( + ClusterDynamicsMLPreview, + ClusterDynamicsMLResult, + ClusterDynamicsMLSAXSComparison, + ClusterDynamicsMLTrainingObservation, + ClusterDynamicsMLWorkflow, + ClusterStructureObservation, + PredictedClusterCandidate, + SAXSComponentWeight, +) + +__all__ = [ + "ClusterDynamicsMLResult", + "ClusterDynamicsMLPreview", + "ClusterDynamicsMLSAXSComparison", + "ClusterDynamicsMLTrainingObservation", + "ClusterDynamicsMLWorkflow", + "ClusterStructureObservation", + "PredictedClusterCandidate", + "SAXSComponentWeight", + "LoadedClusterDynamicsMLDataset", + "SavedClusterDynamicsMLDataset", + "load_cluster_dynamicsai_dataset", + "save_cluster_dynamicsai_dataset", +] diff --git a/src/saxshell/clusterdynamicsml/__main__.py b/src/saxshell/clusterdynamicsml/__main__.py new file mode 100644 index 0000000..bfdcd0c --- /dev/null +++ b/src/saxshell/clusterdynamicsml/__main__.py @@ -0,0 +1,4 @@ +from .cli import main + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/saxshell/clusterdynamicsml/cli.py b/src/saxshell/clusterdynamicsml/cli.py new file mode 100644 index 0000000..9bf2e4f --- /dev/null +++ b/src/saxshell/clusterdynamicsml/cli.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from .ui.main_window import main + +__all__ = ["main"] diff --git a/src/saxshell/clusterdynamicsml/dataset.py b/src/saxshell/clusterdynamicsml/dataset.py new file mode 100644 index 0000000..2123053 --- /dev/null +++ b/src/saxshell/clusterdynamicsml/dataset.py @@ -0,0 +1,1273 @@ +from __future__ import annotations + +import csv +import json +import shutil +from dataclasses import asdict, dataclass +from pathlib import Path + +import numpy as np + +from saxshell.cluster import FrameClusterResult +from saxshell.clusterdynamics.dataset import ( + export_cluster_dynamics_colormap_csv, + export_cluster_dynamics_lifetime_csv, +) +from saxshell.clusterdynamics.workflow import ( + ClusterDynamicsResult, + ClusterDynamicsSelectionPreview, + ClusterLifetimeSummary, + ClusterSizeLifetimeSummary, + _resolve_colormap_timestep_settings, +) +from saxshell.mdtrajectory.frame.cp2k_ener import CP2KEnergyData + +from .workflow import ( + ClusterDynamicsMLPreview, + ClusterDynamicsMLResult, + ClusterDynamicsMLSAXSComparison, + ClusterDynamicsMLTrainingObservation, + ClusterStructureObservation, + PredictedClusterCandidate, + SAXSComponentWeight, + _resolved_population_weights, +) + +DATASET_VERSION = 1 + + +@dataclass(slots=True) +class SavedClusterDynamicsMLDataset: + dataset_file: Path + written_files: tuple[Path, ...] + + +@dataclass(slots=True) +class LoadedClusterDynamicsMLDataset: + dataset_file: Path + result: ClusterDynamicsMLResult + analysis_settings: dict[str, object] + + +def save_cluster_dynamicsai_dataset( + result: ClusterDynamicsMLResult, + output_file: str | Path, + *, + analysis_settings: dict[str, object] | None = None, +) -> SavedClusterDynamicsMLDataset: + dataset_file = Path(output_file).expanduser().resolve() + if dataset_file.suffix.lower() != ".json": + dataset_file = dataset_file.with_suffix(".json") + dataset_file.parent.mkdir(parents=True, exist_ok=True) + payload = { + "version": DATASET_VERSION, + "analysis_settings": analysis_settings or {}, + "dynamics_result": _serialize_cluster_dynamics_result( + result.dynamics_result + ), + "preview": _serialize_preview(result.preview), + "structure_observations": [ + _serialize_structure_observation(entry) + for entry in result.structure_observations + ], + "training_observations": [ + _serialize_training_observation(entry) + for entry in result.training_observations + ], + "predictions": [ + _serialize_prediction_candidate(entry) + for entry in result.predictions + ], + "saxs_comparison": _serialize_saxs_comparison(result.saxs_comparison), + "max_observed_node_count": int(result.max_observed_node_count), + "max_predicted_node_count": ( + None + if result.max_predicted_node_count is None + else int(result.max_predicted_node_count) + ), + "prediction_population_share_threshold": float( + result.prediction_population_share_threshold + ), + } + dataset_file.write_text( + json.dumps(payload, indent=2) + "\n", + encoding="utf-8", + ) + + written_files = [dataset_file] + written_files.append( + export_cluster_dynamics_colormap_csv( + result.dynamics_result, + dataset_file.with_name( + f"{dataset_file.stem}_cluster_distribution.csv" + ), + ) + ) + written_files.append( + export_cluster_dynamics_lifetime_csv( + result.dynamics_result, + dataset_file.with_name(f"{dataset_file.stem}_lifetime.csv"), + ) + ) + written_files.append(_write_prediction_csv(result, dataset_file)) + written_files.extend(_write_histogram_csvs(result, dataset_file)) + if result.saxs_comparison is not None: + written_files.append( + _write_saxs_csv(result.saxs_comparison, dataset_file) + ) + written_files.extend( + _write_saxs_component_profiles( + result.saxs_comparison, dataset_file + ) + ) + written_files.extend( + _write_prediction_structures(result.predictions, dataset_file) + ) + return SavedClusterDynamicsMLDataset( + dataset_file=dataset_file, + written_files=tuple(written_files), + ) + + +def load_cluster_dynamicsai_dataset( + dataset_file: str | Path, +) -> LoadedClusterDynamicsMLDataset: + resolved_file = Path(dataset_file).expanduser().resolve() + payload = json.loads(resolved_file.read_text(encoding="utf-8")) + if int(payload.get("version", 0)) != DATASET_VERSION: + raise ValueError( + "This clusterdynamicsml dataset uses an unsupported format version." + ) + result = ClusterDynamicsMLResult( + dynamics_result=_deserialize_cluster_dynamics_result( + payload.get("dynamics_result", {}), + fallback_path=resolved_file, + ), + preview=_deserialize_preview(payload.get("preview", {})), + structure_observations=tuple( + _deserialize_structure_observation(entry) + for entry in payload.get("structure_observations", []) + if isinstance(entry, dict) + ), + training_observations=tuple( + _deserialize_training_observation(entry) + for entry in payload.get("training_observations", []) + if isinstance(entry, dict) + ), + predictions=tuple( + _deserialize_prediction_candidate(entry) + for entry in payload.get("predictions", []) + if isinstance(entry, dict) + ), + saxs_comparison=_deserialize_saxs_comparison( + payload.get("saxs_comparison"), + fallback_path=resolved_file, + ), + max_observed_node_count=int(payload.get("max_observed_node_count", 0)), + max_predicted_node_count=_optional_int( + payload.get("max_predicted_node_count") + ), + prediction_population_share_threshold=float( + payload.get("prediction_population_share_threshold", 0.02) + ), + ) + return LoadedClusterDynamicsMLDataset( + dataset_file=resolved_file, + result=result, + analysis_settings=dict(payload.get("analysis_settings", {})), + ) + + +def _serialize_preview(preview: ClusterDynamicsMLPreview) -> dict[str, object]: + return { + "dynamics_preview": preview.dynamics_preview.to_dict(), + "preview_summary": dict(preview.dynamics_preview.summary), + "selected_frame_indices": list( + preview.dynamics_preview.selected_frame_indices + ), + "selected_frame_names": list( + preview.dynamics_preview.selected_frame_names + ), + "selected_source_frame_indices": list( + preview.dynamics_preview.selected_source_frame_indices + ), + "clusters_dir": ( + None if preview.clusters_dir is None else str(preview.clusters_dir) + ), + "project_dir": ( + None if preview.project_dir is None else str(preview.project_dir) + ), + "experimental_data_path": ( + None + if preview.experimental_data_path is None + else str(preview.experimental_data_path) + ), + "structure_label_count": int(preview.structure_label_count), + "total_structure_files": int(preview.total_structure_files), + "observed_node_counts": list(preview.observed_node_counts), + "target_node_counts": list(preview.target_node_counts), + "warnings": list(preview.warnings), + } + + +def _deserialize_preview(payload: object) -> ClusterDynamicsMLPreview: + preview_payload = dict(payload if isinstance(payload, dict) else {}) + dynamics_preview_payload = dict( + preview_payload.get("dynamics_preview", {}) + ) + preview_summary = dict(preview_payload.get("preview_summary", {})) + frame_timestep_fs = float( + dynamics_preview_payload.get("frame_timestep_fs", 0.5) + ) + ( + frames_per_colormap_timestep, + colormap_timestep_fs, + ) = _resolve_colormap_timestep_settings( + frame_timestep_fs=frame_timestep_fs, + frames_per_colormap_timestep=dynamics_preview_payload.get( + "frames_per_colormap_timestep" + ), + colormap_timestep_fs=_optional_float( + dynamics_preview_payload.get("colormap_timestep_fs") + ), + legacy_bin_size_fs=_optional_float( + dynamics_preview_payload.get("bin_size_fs") + ), + require_integral_ratio=False, + ) + dynamics_preview = ClusterDynamicsSelectionPreview( + summary=preview_summary, + frame_format=str(dynamics_preview_payload.get("frame_format", "xyz")), + resolved_box_dimensions=_coerce_box_dimensions( + dynamics_preview_payload.get("resolved_box_dimensions") + ), + use_pbc=bool(dynamics_preview_payload.get("use_pbc", False)), + first_frame_time_fs=float( + dynamics_preview_payload.get("first_frame_time_fs", 0.0) + ), + frame_timestep_fs=frame_timestep_fs, + frames_per_colormap_timestep=frames_per_colormap_timestep, + colormap_timestep_fs=colormap_timestep_fs, + analysis_start_fs=float( + dynamics_preview_payload.get("analysis_start_fs", 0.0) + ), + analysis_stop_fs=float( + dynamics_preview_payload.get("analysis_stop_fs", 0.0) + ), + first_selected_time_fs=_optional_float( + dynamics_preview_payload.get("first_selected_time_fs") + ), + last_selected_time_fs=_optional_float( + dynamics_preview_payload.get("last_selected_time_fs") + ), + selected_frame_indices=tuple( + int(value) + for value in preview_payload.get("selected_frame_indices", []) + ), + selected_frame_names=tuple( + str(value) + for value in preview_payload.get("selected_frame_names", []) + ), + selected_source_frame_indices=tuple( + None if value is None else int(value) + for value in preview_payload.get( + "selected_source_frame_indices", + [], + ) + ), + energy_file=( + None + if dynamics_preview_payload.get("energy_file") is None + else Path(str(dynamics_preview_payload.get("energy_file"))) + ), + folder_start_time_fs=_optional_float( + dynamics_preview_payload.get("folder_start_time_fs") + ), + folder_start_time_source=_optional_str( + dynamics_preview_payload.get("folder_start_time_source") + ), + time_source_label=str( + dynamics_preview_payload.get("time_source_label", "Saved dataset") + ), + time_warnings=tuple( + str(value) + for value in dynamics_preview_payload.get("time_warnings", []) + ), + ) + return ClusterDynamicsMLPreview( + dynamics_preview=dynamics_preview, + clusters_dir=_optional_path(preview_payload.get("clusters_dir")), + project_dir=_optional_path(preview_payload.get("project_dir")), + experimental_data_path=_optional_path( + preview_payload.get("experimental_data_path") + ), + structure_label_count=int( + preview_payload.get("structure_label_count", 0) + ), + total_structure_files=int( + preview_payload.get("total_structure_files", 0) + ), + observed_node_counts=tuple( + int(value) + for value in preview_payload.get("observed_node_counts", []) + ), + target_node_counts=tuple( + int(value) + for value in preview_payload.get("target_node_counts", []) + ), + warnings=tuple( + str(value) for value in preview_payload.get("warnings", []) + ), + ) + + +def _serialize_structure_observation( + observation: ClusterStructureObservation, +) -> dict[str, object]: + return { + "label": observation.label, + "node_count": int(observation.node_count), + "element_counts": dict(observation.element_counts), + "file_count": int(observation.file_count), + "representative_path": ( + None + if observation.representative_path is None + else str(observation.representative_path) + ), + "structure_dir": str(observation.structure_dir), + "motifs": list(observation.motifs), + "mean_atom_count": float(observation.mean_atom_count), + "mean_radius_of_gyration": float(observation.mean_radius_of_gyration), + "mean_max_radius": float(observation.mean_max_radius), + "mean_semiaxis_a": float(observation.mean_semiaxis_a), + "mean_semiaxis_b": float(observation.mean_semiaxis_b), + "mean_semiaxis_c": float(observation.mean_semiaxis_c), + } + + +def _deserialize_structure_observation( + payload: dict[str, object], +) -> ClusterStructureObservation: + return ClusterStructureObservation( + label=str(payload.get("label", "")), + node_count=int(payload.get("node_count", 0)), + element_counts={ + str(element): int(count) + for element, count in dict( + payload.get("element_counts", {}) + ).items() + }, + file_count=int(payload.get("file_count", 0)), + representative_path=_optional_path(payload.get("representative_path")), + structure_dir=Path(str(payload.get("structure_dir", "."))), + motifs=tuple(str(value) for value in payload.get("motifs", [])), + mean_atom_count=float(payload.get("mean_atom_count", 0.0)), + mean_radius_of_gyration=float( + payload.get("mean_radius_of_gyration", 0.0) + ), + mean_max_radius=float(payload.get("mean_max_radius", 0.0)), + mean_semiaxis_a=float(payload.get("mean_semiaxis_a", 0.0)), + mean_semiaxis_b=float(payload.get("mean_semiaxis_b", 0.0)), + mean_semiaxis_c=float(payload.get("mean_semiaxis_c", 0.0)), + ) + + +def _serialize_training_observation( + observation: ClusterDynamicsMLTrainingObservation, +) -> dict[str, object]: + return { + **_serialize_structure_observation( + ClusterStructureObservation( + label=observation.label, + node_count=observation.node_count, + element_counts=observation.element_counts, + file_count=observation.file_count, + representative_path=observation.representative_path, + structure_dir=observation.structure_dir, + motifs=observation.motifs, + mean_atom_count=observation.mean_atom_count, + mean_radius_of_gyration=observation.mean_radius_of_gyration, + mean_max_radius=observation.mean_max_radius, + mean_semiaxis_a=observation.mean_semiaxis_a, + mean_semiaxis_b=observation.mean_semiaxis_b, + mean_semiaxis_c=observation.mean_semiaxis_c, + ) + ), + "cluster_size": int(observation.cluster_size), + "total_observations": int(observation.total_observations), + "occupied_frames": int(observation.occupied_frames), + "mean_count_per_frame": float(observation.mean_count_per_frame), + "occupancy_fraction": float(observation.occupancy_fraction), + "association_events": int(observation.association_events), + "dissociation_events": int(observation.dissociation_events), + "association_rate_per_ps": float(observation.association_rate_per_ps), + "dissociation_rate_per_ps": float( + observation.dissociation_rate_per_ps + ), + "completed_lifetime_count": int(observation.completed_lifetime_count), + "window_truncated_lifetime_count": int( + observation.window_truncated_lifetime_count + ), + "mean_lifetime_fs": _optional_float(observation.mean_lifetime_fs), + "std_lifetime_fs": _optional_float(observation.std_lifetime_fs), + } + + +def _deserialize_training_observation( + payload: dict[str, object], +) -> ClusterDynamicsMLTrainingObservation: + structure_observation = _deserialize_structure_observation(payload) + return ClusterDynamicsMLTrainingObservation( + label=structure_observation.label, + node_count=structure_observation.node_count, + cluster_size=int(payload.get("cluster_size", 0)), + element_counts=structure_observation.element_counts, + file_count=structure_observation.file_count, + representative_path=structure_observation.representative_path, + structure_dir=structure_observation.structure_dir, + motifs=structure_observation.motifs, + mean_atom_count=structure_observation.mean_atom_count, + mean_radius_of_gyration=structure_observation.mean_radius_of_gyration, + mean_max_radius=structure_observation.mean_max_radius, + mean_semiaxis_a=structure_observation.mean_semiaxis_a, + mean_semiaxis_b=structure_observation.mean_semiaxis_b, + mean_semiaxis_c=structure_observation.mean_semiaxis_c, + total_observations=int(payload.get("total_observations", 0)), + occupied_frames=int(payload.get("occupied_frames", 0)), + mean_count_per_frame=float(payload.get("mean_count_per_frame", 0.0)), + occupancy_fraction=float(payload.get("occupancy_fraction", 0.0)), + association_events=int(payload.get("association_events", 0)), + dissociation_events=int(payload.get("dissociation_events", 0)), + association_rate_per_ps=float( + payload.get("association_rate_per_ps", 0.0) + ), + dissociation_rate_per_ps=float( + payload.get("dissociation_rate_per_ps", 0.0) + ), + completed_lifetime_count=int( + payload.get("completed_lifetime_count", 0) + ), + window_truncated_lifetime_count=int( + payload.get("window_truncated_lifetime_count", 0) + ), + mean_lifetime_fs=_optional_float(payload.get("mean_lifetime_fs")), + std_lifetime_fs=_optional_float(payload.get("std_lifetime_fs")), + ) + + +def _serialize_prediction_candidate( + candidate: PredictedClusterCandidate, +) -> dict[str, object]: + return { + "target_node_count": int(candidate.target_node_count), + "rank": int(candidate.rank), + "label": candidate.label, + "element_counts": dict(candidate.element_counts), + "predicted_mean_count_per_frame": float( + candidate.predicted_mean_count_per_frame + ), + "predicted_occupancy_fraction": float( + candidate.predicted_occupancy_fraction + ), + "predicted_mean_lifetime_fs": float( + candidate.predicted_mean_lifetime_fs + ), + "predicted_association_rate_per_ps": float( + candidate.predicted_association_rate_per_ps + ), + "predicted_dissociation_rate_per_ps": float( + candidate.predicted_dissociation_rate_per_ps + ), + "predicted_mean_radius_of_gyration": float( + candidate.predicted_mean_radius_of_gyration + ), + "predicted_mean_max_radius": float( + candidate.predicted_mean_max_radius + ), + "predicted_mean_semiaxis_a": float( + candidate.predicted_mean_semiaxis_a + ), + "predicted_mean_semiaxis_b": float( + candidate.predicted_mean_semiaxis_b + ), + "predicted_mean_semiaxis_c": float( + candidate.predicted_mean_semiaxis_c + ), + "predicted_population_share": float( + candidate.predicted_population_share + ), + "predicted_stability_score": float( + candidate.predicted_stability_score + ), + "source_label": candidate.source_label, + "notes": candidate.notes, + "generated_elements": list(candidate.generated_elements), + "generated_coordinates": candidate.generated_coordinates.tolist(), + } + + +def _deserialize_prediction_candidate( + payload: dict[str, object], +) -> PredictedClusterCandidate: + return PredictedClusterCandidate( + target_node_count=int(payload.get("target_node_count", 0)), + rank=int(payload.get("rank", 0)), + label=str(payload.get("label", "")), + element_counts={ + str(element): int(count) + for element, count in dict( + payload.get("element_counts", {}) + ).items() + }, + predicted_mean_count_per_frame=float( + payload.get("predicted_mean_count_per_frame", 0.0) + ), + predicted_occupancy_fraction=float( + payload.get("predicted_occupancy_fraction", 0.0) + ), + predicted_mean_lifetime_fs=float( + payload.get("predicted_mean_lifetime_fs", 0.0) + ), + predicted_association_rate_per_ps=float( + payload.get("predicted_association_rate_per_ps", 0.0) + ), + predicted_dissociation_rate_per_ps=float( + payload.get("predicted_dissociation_rate_per_ps", 0.0) + ), + predicted_mean_radius_of_gyration=float( + payload.get("predicted_mean_radius_of_gyration", 0.0) + ), + predicted_mean_max_radius=float( + payload.get("predicted_mean_max_radius", 0.0) + ), + predicted_mean_semiaxis_a=float( + payload.get("predicted_mean_semiaxis_a", 0.0) + ), + predicted_mean_semiaxis_b=float( + payload.get("predicted_mean_semiaxis_b", 0.0) + ), + predicted_mean_semiaxis_c=float( + payload.get("predicted_mean_semiaxis_c", 0.0) + ), + predicted_population_share=float( + payload.get("predicted_population_share", 0.0) + ), + predicted_stability_score=float( + payload.get("predicted_stability_score", 0.0) + ), + source_label=_optional_str(payload.get("source_label")), + notes=str(payload.get("notes", "")), + generated_elements=tuple( + str(value) for value in payload.get("generated_elements", []) + ), + generated_coordinates=np.asarray( + payload.get("generated_coordinates", []), + dtype=float, + ), + ) + + +def _serialize_saxs_comparison( + comparison: ClusterDynamicsMLSAXSComparison | None, +) -> dict[str, object] | None: + if comparison is None: + return None + return { + "q_values": comparison.q_values.tolist(), + "observed_raw_model_intensity": ( + None + if comparison.observed_raw_model_intensity is None + else comparison.observed_raw_model_intensity.tolist() + ), + "observed_fitted_model_intensity": ( + None + if comparison.observed_fitted_model_intensity is None + else comparison.observed_fitted_model_intensity.tolist() + ), + "observed_rmse": _optional_float(comparison.observed_rmse), + "raw_model_intensity": comparison.raw_model_intensity.tolist(), + "fitted_model_intensity": comparison.fitted_model_intensity.tolist(), + "experimental_intensity": ( + None + if comparison.experimental_intensity is None + else comparison.experimental_intensity.tolist() + ), + "residuals": ( + None + if comparison.residuals is None + else comparison.residuals.tolist() + ), + "scale_factor": float(comparison.scale_factor), + "offset": float(comparison.offset), + "rmse": _optional_float(comparison.rmse), + "component_weights": [ + { + "label": entry.label, + "weight": float(entry.weight), + "source": entry.source, + "profile_path": ( + None + if entry.profile_path is None + else str(entry.profile_path) + ), + "structure_path": ( + None + if entry.structure_path is None + else str(entry.structure_path) + ), + } + for entry in comparison.component_weights + ], + "experimental_data_path": ( + None + if comparison.experimental_data_path is None + else str(comparison.experimental_data_path) + ), + "component_output_dir": ( + None + if comparison.component_output_dir is None + else str(comparison.component_output_dir) + ), + "surrogate_structure_dir": ( + None + if comparison.surrogate_structure_dir is None + else str(comparison.surrogate_structure_dir) + ), + } + + +def _deserialize_saxs_comparison( + payload: object, + *, + fallback_path: Path, +) -> ClusterDynamicsMLSAXSComparison | None: + if not isinstance(payload, dict): + return None + del fallback_path + experimental_intensity = payload.get("experimental_intensity") + observed_raw_model_intensity = payload.get("observed_raw_model_intensity") + observed_fitted_model_intensity = payload.get( + "observed_fitted_model_intensity" + ) + residuals = payload.get("residuals") + return ClusterDynamicsMLSAXSComparison( + q_values=np.asarray(payload.get("q_values", []), dtype=float), + observed_raw_model_intensity=( + None + if observed_raw_model_intensity is None + else np.asarray(observed_raw_model_intensity, dtype=float) + ), + observed_fitted_model_intensity=( + None + if observed_fitted_model_intensity is None + else np.asarray(observed_fitted_model_intensity, dtype=float) + ), + observed_rmse=_optional_float(payload.get("observed_rmse")), + raw_model_intensity=np.asarray( + payload.get("raw_model_intensity", []), + dtype=float, + ), + fitted_model_intensity=np.asarray( + payload.get("fitted_model_intensity", []), + dtype=float, + ), + experimental_intensity=( + None + if experimental_intensity is None + else np.asarray(experimental_intensity, dtype=float) + ), + residuals=( + None if residuals is None else np.asarray(residuals, dtype=float) + ), + scale_factor=float(payload.get("scale_factor", 1.0)), + offset=float(payload.get("offset", 0.0)), + rmse=_optional_float(payload.get("rmse")), + component_weights=tuple( + SAXSComponentWeight( + label=str(entry.get("label", "")), + weight=float(entry.get("weight", 0.0)), + source=str(entry.get("source", "")), + profile_path=_optional_path(entry.get("profile_path")), + structure_path=_optional_path(entry.get("structure_path")), + ) + for entry in payload.get("component_weights", []) + if isinstance(entry, dict) + ), + experimental_data_path=_optional_path( + payload.get("experimental_data_path") + ), + component_output_dir=_optional_path( + payload.get("component_output_dir") + ), + surrogate_structure_dir=_optional_path( + payload.get("surrogate_structure_dir") + ), + ) + + +def _serialize_cluster_dynamics_result( + result: ClusterDynamicsResult, +) -> dict[str, object]: + return { + "preview_summary": dict(result.preview.summary), + "preview": result.preview.to_dict(), + "selected_frame_indices": list(result.preview.selected_frame_indices), + "selected_frame_names": list(result.preview.selected_frame_names), + "selected_source_frame_indices": list( + result.preview.selected_source_frame_indices + ), + "frame_results": [ + { + "frame_index": int(frame_result.frame_index), + "time_fs": ( + None + if frame_result.time_fs is None + else float(frame_result.time_fs) + ), + } + for frame_result in result.frame_results + ], + "cluster_labels": list(result.cluster_labels), + "cluster_sizes": { + str(label): int(size) + for label, size in result.cluster_sizes.items() + }, + "bin_edges_fs": result.bin_edges_fs.tolist(), + "frames_per_bin": result.frames_per_bin.tolist(), + "total_clusters_per_bin": result.total_clusters_per_bin.tolist(), + "raw_count_matrix": result.raw_count_matrix.tolist(), + "fraction_matrix": result.fraction_matrix.tolist(), + "mean_count_matrix": result.mean_count_matrix.tolist(), + "frame_count_matrix": result.frame_count_matrix.tolist(), + "total_clusters_per_frame": result.total_clusters_per_frame.tolist(), + "lifetime_by_label": [ + asdict(entry) for entry in result.lifetime_by_label + ], + "lifetime_by_size": [ + asdict(entry) for entry in result.lifetime_by_size + ], + "energy_data": _serialize_energy_data(result.energy_data), + } + + +def _deserialize_cluster_dynamics_result( + payload: object, + *, + fallback_path: Path, +) -> ClusterDynamicsResult: + result_payload = dict(payload if isinstance(payload, dict) else {}) + preview_payload = dict(result_payload.get("preview", {})) + preview_summary = dict(result_payload.get("preview_summary", {})) + frame_results_payload = result_payload.get("frame_results", []) + frame_results = tuple( + FrameClusterResult( + frame_index=int(entry["frame_index"]), + time_fs=( + None + if entry.get("time_fs") is None + else float(entry.get("time_fs")) + ), + clusters=[], + ) + for entry in frame_results_payload + if isinstance(entry, dict) and "frame_index" in entry + ) + frame_timestep_fs = float(preview_payload.get("frame_timestep_fs", 0.5)) + ( + frames_per_colormap_timestep, + colormap_timestep_fs, + ) = _resolve_colormap_timestep_settings( + frame_timestep_fs=frame_timestep_fs, + frames_per_colormap_timestep=preview_payload.get( + "frames_per_colormap_timestep" + ), + colormap_timestep_fs=_optional_float( + preview_payload.get("colormap_timestep_fs") + ), + legacy_bin_size_fs=_optional_float(preview_payload.get("bin_size_fs")), + require_integral_ratio=False, + ) + preview = ClusterDynamicsSelectionPreview( + summary=preview_summary, + frame_format=str(preview_payload.get("frame_format", "xyz")), + resolved_box_dimensions=_coerce_box_dimensions( + preview_payload.get("resolved_box_dimensions") + ), + use_pbc=bool(preview_payload.get("use_pbc", False)), + first_frame_time_fs=float( + preview_payload.get("first_frame_time_fs", 0.0) + ), + frame_timestep_fs=frame_timestep_fs, + frames_per_colormap_timestep=frames_per_colormap_timestep, + colormap_timestep_fs=colormap_timestep_fs, + analysis_start_fs=float(preview_payload.get("analysis_start_fs", 0.0)), + analysis_stop_fs=float(preview_payload.get("analysis_stop_fs", 0.0)), + first_selected_time_fs=_optional_float( + preview_payload.get("first_selected_time_fs") + ), + last_selected_time_fs=_optional_float( + preview_payload.get("last_selected_time_fs") + ), + selected_frame_indices=tuple( + int(value) + for value in result_payload.get("selected_frame_indices", []) + ), + selected_frame_names=tuple( + str(value) + for value in result_payload.get("selected_frame_names", []) + ), + selected_source_frame_indices=tuple( + None if value is None else int(value) + for value in result_payload.get( + "selected_source_frame_indices", [] + ) + ), + energy_file=_optional_path(preview_payload.get("energy_file")), + folder_start_time_fs=_optional_float( + preview_payload.get("folder_start_time_fs") + ), + folder_start_time_source=_optional_str( + preview_payload.get("folder_start_time_source") + ), + time_source_label=str( + preview_payload.get("time_source_label", "Saved dataset") + ), + time_warnings=tuple( + str(value) for value in preview_payload.get("time_warnings", []) + ), + ) + return ClusterDynamicsResult( + preview=preview, + frame_results=frame_results, + bin_edges_fs=np.asarray( + result_payload.get("bin_edges_fs", []), dtype=float + ), + frames_per_bin=np.asarray( + result_payload.get("frames_per_bin", []), + dtype=float, + ), + total_clusters_per_bin=np.asarray( + result_payload.get("total_clusters_per_bin", []), + dtype=float, + ), + cluster_labels=tuple( + str(value) for value in result_payload.get("cluster_labels", []) + ), + cluster_sizes={ + str(label): int(size) + for label, size in dict( + result_payload.get("cluster_sizes", {}) + ).items() + }, + raw_count_matrix=np.asarray( + result_payload.get("raw_count_matrix", []), + dtype=float, + ), + fraction_matrix=np.asarray( + result_payload.get("fraction_matrix", []), + dtype=float, + ), + mean_count_matrix=np.asarray( + result_payload.get("mean_count_matrix", []), + dtype=float, + ), + frame_count_matrix=np.asarray( + result_payload.get("frame_count_matrix", []), + dtype=float, + ), + total_clusters_per_frame=np.asarray( + result_payload.get("total_clusters_per_frame", []), + dtype=float, + ), + lifetime_by_label=tuple( + ClusterLifetimeSummary(**_normalize_summary_payload(entry)) + for entry in result_payload.get("lifetime_by_label", []) + if isinstance(entry, dict) + ), + lifetime_by_size=tuple( + ClusterSizeLifetimeSummary(**_normalize_summary_payload(entry)) + for entry in result_payload.get("lifetime_by_size", []) + if isinstance(entry, dict) + ), + energy_data=_deserialize_energy_data( + result_payload.get("energy_data"), + fallback_path=fallback_path, + ), + ) + + +def _write_prediction_csv( + result: ClusterDynamicsMLResult, + dataset_file: Path, +) -> Path: + output_path = dataset_file.with_name( + f"{dataset_file.stem}_predictions.csv" + ) + with output_path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.writer(handle) + writer.writerow( + [ + "target_node_count", + "rank", + "label", + "predicted_population_share", + "predicted_mean_count_per_frame", + "predicted_occupancy_fraction", + "predicted_mean_lifetime_fs", + "predicted_association_rate_per_ps", + "predicted_dissociation_rate_per_ps", + "predicted_mean_radius_of_gyration", + "predicted_mean_max_radius", + "predicted_stability_score", + "source_label", + "notes", + ] + ) + for entry in result.predictions: + writer.writerow( + [ + int(entry.target_node_count), + int(entry.rank), + entry.label, + float(entry.predicted_population_share), + float(entry.predicted_mean_count_per_frame), + float(entry.predicted_occupancy_fraction), + float(entry.predicted_mean_lifetime_fs), + float(entry.predicted_association_rate_per_ps), + float(entry.predicted_dissociation_rate_per_ps), + float(entry.predicted_mean_radius_of_gyration), + float(entry.predicted_mean_max_radius), + float(entry.predicted_stability_score), + "" if entry.source_label is None else entry.source_label, + entry.notes, + ] + ) + return output_path + + +def _write_histogram_csvs( + result: ClusterDynamicsMLResult, + dataset_file: Path, +) -> list[Path]: + output_paths = [ + ( + dataset_file.with_name( + f"{dataset_file.stem}_observed_histogram.csv" + ), + False, + ), + ( + dataset_file.with_name( + f"{dataset_file.stem}_observed_plus_surrogate_histogram.csv" + ), + True, + ), + ] + written_files: list[Path] = [] + for output_path, include_predictions in output_paths: + written_files.append( + _write_histogram_csv( + result, + output_path, + include_predictions=include_predictions, + ) + ) + return written_files + + +def _write_histogram_csv( + result: ClusterDynamicsMLResult, + output_path: Path, + *, + include_predictions: bool, +) -> Path: + rows = _distribution_rows( + result, + include_predictions=include_predictions, + ) + with output_path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.writer(handle) + writer.writerow( + [ + "label", + "source", + "node_count", + "mean_lifetime_fs", + "mean_count_per_frame", + "mean_max_radius", + "raw_weight", + "normalized_weight", + ] + ) + for row in rows: + writer.writerow( + [ + row["label"], + row["source"], + float(row["node_count"]), + ( + "" + if row["mean_lifetime_fs"] is None + else float(row["mean_lifetime_fs"]) + ), + float(row["mean_count_per_frame"]), + float(row["mean_max_radius"]), + float(row["raw_weight"]), + float(row["normalized_weight"]), + ] + ) + return output_path + + +def _write_saxs_csv( + comparison: ClusterDynamicsMLSAXSComparison, + dataset_file: Path, +) -> Path: + output_path = dataset_file.with_name(f"{dataset_file.stem}_saxs.csv") + with output_path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.writer(handle) + writer.writerow( + [ + "q", + "raw_model_intensity", + "fitted_model_intensity", + "experimental_intensity", + "residual", + ] + ) + experimental = comparison.experimental_intensity + residuals = comparison.residuals + for index, q_value in enumerate(comparison.q_values): + writer.writerow( + [ + float(q_value), + float(comparison.raw_model_intensity[index]), + float(comparison.fitted_model_intensity[index]), + ( + "" + if experimental is None + else float(experimental[index]) + ), + "" if residuals is None else float(residuals[index]), + ] + ) + return output_path + + +def _write_saxs_component_profiles( + comparison: ClusterDynamicsMLSAXSComparison, + dataset_file: Path, +) -> list[Path]: + output_dir = dataset_file.with_name(f"{dataset_file.stem}_saxs_components") + output_dir.mkdir(parents=True, exist_ok=True) + copied_files: list[Path] = [] + seen_sources: set[Path] = set() + for entry in comparison.component_weights: + source_path = entry.profile_path + if source_path is None: + continue + resolved_source = source_path.expanduser().resolve() + if resolved_source in seen_sources or not resolved_source.is_file(): + continue + seen_sources.add(resolved_source) + target_path = _unique_child_path(output_dir, resolved_source.name) + shutil.copy2(resolved_source, target_path) + copied_files.append(target_path) + return copied_files + + +def _write_prediction_structures( + predictions: tuple[PredictedClusterCandidate, ...], + dataset_file: Path, +) -> list[Path]: + output_dir = dataset_file.with_name( + f"{dataset_file.stem}_predicted_structures" + ) + output_dir.mkdir(parents=True, exist_ok=True) + written_files: list[Path] = [] + for entry in predictions: + output_path = output_dir / ( + f"{entry.target_node_count:02d}_rank{entry.rank:02d}_{entry.label}.xyz" + ) + lines = [ + f"{len(entry.generated_elements)}", + ( + f"label={entry.label} target_node_count={entry.target_node_count} " + f"rank={entry.rank}" + ), + ] + for element, coords in zip( + entry.generated_elements, + entry.generated_coordinates, + strict=False, + ): + lines.append( + f"{element} {float(coords[0]):.8f} {float(coords[1]):.8f} " + f"{float(coords[2]):.8f}" + ) + output_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + written_files.append(output_path) + return written_files + + +def _distribution_rows( + result: ClusterDynamicsMLResult, + *, + include_predictions: bool, +) -> list[dict[str, object]]: + rows: list[dict[str, object]] = [] + observed_weights, predicted_weights = _resolved_population_weights( + result.training_observations, + result.predictions, + frame_timestep_fs=float( + result.dynamics_result.preview.frame_timestep_fs + ), + ) + for row, weight in zip( + result.training_observations, + observed_weights, + strict=False, + ): + if weight <= 0.0: + continue + rows.append( + { + "label": row.label, + "source": "observed", + "node_count": float(row.node_count), + "mean_lifetime_fs": row.mean_lifetime_fs, + "mean_count_per_frame": float(row.mean_count_per_frame), + "mean_max_radius": float(row.mean_max_radius), + "raw_weight": weight, + } + ) + if include_predictions: + for item, weight in zip( + result.predictions, + predicted_weights, + strict=False, + ): + if weight <= 0.0: + continue + rows.append( + { + "label": item.label, + "source": "predicted", + "node_count": float(item.target_node_count), + "mean_lifetime_fs": float(item.predicted_mean_lifetime_fs), + "mean_count_per_frame": float( + item.predicted_mean_count_per_frame + ), + "mean_max_radius": float(item.predicted_mean_max_radius), + "raw_weight": weight, + } + ) + total_weight = sum(float(row["raw_weight"]) for row in rows) + if total_weight <= 0.0: + return [] + for row in rows: + row["normalized_weight"] = float(row["raw_weight"]) / total_weight + return rows + + +def _unique_child_path(directory: Path, filename: str) -> Path: + candidate = directory / filename + if not candidate.exists(): + return candidate + stem = candidate.stem + suffix = candidate.suffix + counter = 2 + while True: + alternative = directory / f"{stem}_{counter}{suffix}" + if not alternative.exists(): + return alternative + counter += 1 + + +def _serialize_energy_data( + energy_data: CP2KEnergyData | None, +) -> dict[str, object] | None: + if energy_data is None: + return None + return { + "filepath": str(energy_data.filepath), + "step": energy_data.step.tolist(), + "time_fs": energy_data.time_fs.tolist(), + "kinetic": energy_data.kinetic.tolist(), + "temperature": energy_data.temperature.tolist(), + "potential": energy_data.potential.tolist(), + } + + +def _deserialize_energy_data( + payload: object, + *, + fallback_path: Path, +) -> CP2KEnergyData | None: + if not isinstance(payload, dict): + return None + filepath_value = payload.get("filepath") + filepath = ( + fallback_path if filepath_value is None else Path(str(filepath_value)) + ) + return CP2KEnergyData( + filepath=filepath, + step=np.asarray(payload.get("step", []), dtype=float), + time_fs=np.asarray(payload.get("time_fs", []), dtype=float), + kinetic=np.asarray(payload.get("kinetic", []), dtype=float), + temperature=np.asarray(payload.get("temperature", []), dtype=float), + potential=np.asarray(payload.get("potential", []), dtype=float), + ) + + +def _normalize_summary_payload( + payload: dict[str, object] +) -> dict[str, object]: + normalized = dict(payload) + if ( + "censored_lifetime_count" in normalized + and "window_truncated_lifetime_count" not in normalized + ): + normalized["window_truncated_lifetime_count"] = normalized.pop( + "censored_lifetime_count" + ) + return normalized + + +def _coerce_box_dimensions( + value: object, +) -> tuple[float, float, float] | None: + if value is None: + return None + components = tuple(float(component) for component in value) + if len(components) != 3: + raise ValueError("Saved box dimensions must contain three values.") + return components + + +def _optional_float(value: object) -> float | None: + return None if value is None else float(value) + + +def _optional_int(value: object) -> int | None: + return None if value is None else int(value) + + +def _optional_str(value: object) -> str | None: + if value is None: + return None + text = str(value).strip() + return text or None + + +def _optional_path(value: object) -> Path | None: + text = _optional_str(value) + return None if text is None else Path(text) + + +__all__ = [ + "LoadedClusterDynamicsMLDataset", + "SavedClusterDynamicsMLDataset", + "load_cluster_dynamicsai_dataset", + "save_cluster_dynamicsai_dataset", +] diff --git a/src/saxshell/clusterdynamicsml/ui/__init__.py b/src/saxshell/clusterdynamicsml/ui/__init__.py new file mode 100644 index 0000000..c53ce12 --- /dev/null +++ b/src/saxshell/clusterdynamicsml/ui/__init__.py @@ -0,0 +1,6 @@ +from .main_window import ( + ClusterDynamicsMLMainWindow, + launch_clusterdynamicsml_ui, +) + +__all__ = ["ClusterDynamicsMLMainWindow", "launch_clusterdynamicsml_ui"] diff --git a/src/saxshell/clusterdynamicsml/ui/main_window.py b/src/saxshell/clusterdynamicsml/ui/main_window.py new file mode 100644 index 0000000..942e37a --- /dev/null +++ b/src/saxshell/clusterdynamicsml/ui/main_window.py @@ -0,0 +1,2602 @@ +from __future__ import annotations + +import argparse +import csv +import json +import shutil +import sys +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path + +import numpy as np +from PySide6.QtCore import QObject, Qt, QThread, Signal, Slot +from PySide6.QtWidgets import ( + QAbstractItemView, + QApplication, + QCheckBox, + QDoubleSpinBox, + QFileDialog, + QFormLayout, + QGroupBox, + QHBoxLayout, + QHeaderView, + QLabel, + QLineEdit, + QMainWindow, + QMessageBox, + QPushButton, + QScrollArea, + QSpinBox, + QSplitter, + QTableWidget, + QTableWidgetItem, + QTabWidget, + QTextEdit, + QVBoxLayout, + QWidget, +) + +from saxshell.cluster import ( + ExtractedFrameFolderClusterAnalyzer, + detect_frame_folder_mode, + format_box_dimensions, + format_search_mode_label, + frame_folder_label, +) +from saxshell.cluster.ui.definitions_panel import ClusterDefinitionsPanel +from saxshell.cluster.ui.trajectory_panel import ClusterTrajectoryPanel +from saxshell.clusterdynamics.dataset import ( + export_cluster_dynamics_colormap_csv, +) +from saxshell.clusterdynamics.report import ( + default_powerpoint_report_path, + export_cluster_dynamicsai_report_pptx, +) +from saxshell.clusterdynamics.ui.main_window import ( + ClusterDynamicsDatasetPanel, + ClusterDynamicsRunPanel, + ClusterDynamicsTimePanel, +) +from saxshell.clusterdynamics.ui.plot_panel import ClusterDynamicsPlotPanel +from saxshell.saxs.project_manager import ( + PowerPointExportSettings, + SAXSProjectManager, + build_project_paths, +) +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) + +from ..dataset import ( + LoadedClusterDynamicsMLDataset, + SavedClusterDynamicsMLDataset, + load_cluster_dynamicsai_dataset, + save_cluster_dynamicsai_dataset, +) +from ..workflow import ( + ClusterDynamicsMLResult, + ClusterDynamicsMLWorkflow, + _resolved_population_weights, +) +from .plot_panel import ( + ClusterDynamicsMLHistogramPanel, + ClusterDynamicsMLPlotPanel, +) + +_OPEN_WINDOWS: list["ClusterDynamicsMLMainWindow"] = [] + + +@dataclass(slots=True) +class ClusterDynamicsMLJobConfig: + frames_dir: Path + clusters_dir: Path | None + project_dir: Path | None + experimental_data_file: Path | None + energy_file: Path | None + atom_type_definitions: dict[str, list[tuple[str, str | None]]] + pair_cutoff_definitions: dict[tuple[str, str], dict[int, float]] + box_dimensions: tuple[float, float, float] | None + use_pbc: bool + default_cutoff: float | None + shell_levels: tuple[int, ...] + shared_shells: bool + include_shell_atoms_in_stoichiometry: bool + search_mode: str + folder_start_time_fs: float | None + first_frame_time_fs: float + frame_timestep_fs: float + frames_per_colormap_timestep: int + analysis_start_fs: float | None + analysis_stop_fs: float | None + target_node_counts: tuple[int, ...] + candidates_per_size: int + prediction_population_share_threshold: float + q_min: float | None + q_max: float | None + q_points: int + + +@dataclass(slots=True) +class _ProjectPredictionHistoryEntry: + dataset_file: Path + modified_time: float + saved_label: str + target_node_counts: tuple[int, ...] + candidates_per_size: int + prediction_population_share_threshold: float + prediction_count: int + max_predicted_node_count: int | None + rmse: float | None + + +class ClusterDynamicsMLWorker(QObject): + progress = Signal(str) + finished = Signal(object) + failed = Signal(str) + + def __init__(self, config: ClusterDynamicsMLJobConfig) -> None: + super().__init__() + self.config = config + + @Slot() + def run(self) -> None: + try: + workflow = ClusterDynamicsMLWorkflow( + self.config.frames_dir, + atom_type_definitions=self.config.atom_type_definitions, + pair_cutoff_definitions=self.config.pair_cutoff_definitions, + clusters_dir=self.config.clusters_dir, + project_dir=self.config.project_dir, + experimental_data_file=self.config.experimental_data_file, + box_dimensions=self.config.box_dimensions, + use_pbc=self.config.use_pbc, + default_cutoff=self.config.default_cutoff, + shell_levels=self.config.shell_levels, + shared_shells=self.config.shared_shells, + include_shell_atoms_in_stoichiometry=( + self.config.include_shell_atoms_in_stoichiometry + ), + search_mode=self.config.search_mode, + folder_start_time_fs=self.config.folder_start_time_fs, + first_frame_time_fs=self.config.first_frame_time_fs, + frame_timestep_fs=self.config.frame_timestep_fs, + frames_per_colormap_timestep=( + self.config.frames_per_colormap_timestep + ), + analysis_start_fs=self.config.analysis_start_fs, + analysis_stop_fs=self.config.analysis_stop_fs, + energy_file=self.config.energy_file, + target_node_counts=self.config.target_node_counts, + candidates_per_size=self.config.candidates_per_size, + prediction_population_share_threshold=( + self.config.prediction_population_share_threshold + ), + q_min=self.config.q_min, + q_max=self.config.q_max, + q_points=self.config.q_points, + ) + preview = workflow.preview_selection() + self.progress.emit( + "Preparing clusterdynamicsml analysis.\n" + f"Frames selected: {preview.dynamics_preview.selected_frames}\n" + f"Observed node counts: {preview.observed_node_counts or ('n/a',)}\n" + f"Target node counts: {preview.target_node_counts or ('n/a',)}" + ) + result = workflow.analyze(progress_callback=self.progress.emit) + self.finished.emit(result) + except Exception as exc: + self.failed.emit(str(exc)) + + +class ClusterDynamicsMLSettingsPanel(QGroupBox): + settings_changed = Signal() + + def __init__(self) -> None: + super().__init__("Prediction Inputs") + self._build_ui() + + def _build_ui(self) -> None: + layout = QFormLayout(self) + self.setToolTip( + "Choose the observed smaller-cluster structures to learn from, " + "optionally load experimental SAXS data for comparison, and set " + "the larger node counts to predict." + ) + + clusters_tooltip = ( + "Folder of observed smaller-cluster structures. Use the cluster " + "extraction output organized by stoichiometry label, for example " + "Pb/, Pb2I/, Pb3I2/ (optionally with motif_* subfolders). " + "When this tool is opened from the main SAXSShell UI, the active " + "project's clusters folder can be filled automatically." + ) + experimental_tooltip = ( + "Optional experimental SAXS data file used to fit and compare the " + "cluster-only surrogate SAXS trace. Leave this blank to run the " + "prediction workflow without an experimental comparison." + ) + target_start_tooltip = ( + "Lowest node count to predict. clusterdynamicsml extrapolates " + "beyond the observed node counts in the smaller-cluster training " + "set." + ) + target_stop_tooltip = ( + "Highest node count to predict. Every integer node count between " + "the start and stop values is included." + ) + candidates_tooltip = ( + "Number of ranked candidate stoichiometries to keep for each " + "predicted node count." + ) + share_threshold_tooltip = ( + "Minimum share among the predicted candidates used when " + "highlighting the largest practical predicted node count and " + "filtering out tiny predicted populations." + ) + store_history_tooltip = ( + "When enabled and a project folder is set, each clusterdynamicsml " + "run is cached as its own timestamped result bundle so you can " + "compare prediction settings and fitted SAXS models later." + ) + q_min_tooltip = ( + "Fallback minimum q value for the surrogate SAXS comparison when " + "no experimental data file is loaded." + ) + q_max_tooltip = ( + "Fallback maximum q value for the surrogate SAXS comparison when " + "no experimental data file is loaded." + ) + q_points_tooltip = ( + "Number of q samples in the fallback SAXS grid when no " + "experimental data file is loaded." + ) + + self.clusters_dir_edit = QLineEdit() + self.clusters_dir_edit.setToolTip(clusters_tooltip) + self.clusters_dir_edit.textChanged.connect( + lambda _text: self.settings_changed.emit() + ) + self._add_tooltipped_row( + layout, + "Clusters folder", + self._make_dir_row( + self.clusters_dir_edit, "Select clusters folder" + ), + clusters_tooltip, + buddy=self.clusters_dir_edit, + ) + + self.experimental_data_edit = QLineEdit() + self.experimental_data_edit.setToolTip(experimental_tooltip) + self.experimental_data_edit.textChanged.connect( + lambda _text: self.settings_changed.emit() + ) + self._add_tooltipped_row( + layout, + "Experimental data", + self._make_file_row( + self.experimental_data_edit, + "Select experimental SAXS data", + ), + experimental_tooltip, + buddy=self.experimental_data_edit, + ) + + self.target_start_spin = QSpinBox() + self.target_start_spin.setRange(1, 999) + self.target_start_spin.setValue(4) + self.target_start_spin.setToolTip(target_start_tooltip) + self.target_start_spin.valueChanged.connect( + lambda _value: self.settings_changed.emit() + ) + self._add_tooltipped_row( + layout, + "Predict from node count", + self.target_start_spin, + target_start_tooltip, + ) + + self.target_stop_spin = QSpinBox() + self.target_stop_spin.setRange(1, 999) + self.target_stop_spin.setValue(5) + self.target_stop_spin.setToolTip(target_stop_tooltip) + self.target_stop_spin.valueChanged.connect( + lambda _value: self.settings_changed.emit() + ) + self._add_tooltipped_row( + layout, + "Predict through node count", + self.target_stop_spin, + target_stop_tooltip, + ) + + self.candidates_spin = QSpinBox() + self.candidates_spin.setRange(1, 12) + self.candidates_spin.setValue(3) + self.candidates_spin.setToolTip(candidates_tooltip) + self.candidates_spin.valueChanged.connect( + lambda _value: self.settings_changed.emit() + ) + self._add_tooltipped_row( + layout, + "Candidates / size", + self.candidates_spin, + candidates_tooltip, + ) + + self.share_threshold_spin = QDoubleSpinBox() + self.share_threshold_spin.setDecimals(3) + self.share_threshold_spin.setRange(0.0, 1.0) + self.share_threshold_spin.setSingleStep(0.01) + self.share_threshold_spin.setValue(0.02) + self.share_threshold_spin.setToolTip(share_threshold_tooltip) + self.share_threshold_spin.valueChanged.connect( + lambda _value: self.settings_changed.emit() + ) + self._add_tooltipped_row( + layout, + "Share threshold", + self.share_threshold_spin, + share_threshold_tooltip, + ) + + self.store_history_checkbox = QCheckBox( + "Keep timestamped prediction history" + ) + self.store_history_checkbox.setChecked(True) + self.store_history_checkbox.setToolTip(store_history_tooltip) + self.store_history_checkbox.toggled.connect( + lambda _checked: self.settings_changed.emit() + ) + self._add_tooltipped_row( + layout, + "Store history", + self.store_history_checkbox, + store_history_tooltip, + ) + + self.q_min_spin = QDoubleSpinBox() + self.q_min_spin.setDecimals(4) + self.q_min_spin.setRange(0.0, 100.0) + self.q_min_spin.setSingleStep(0.01) + self.q_min_spin.setValue(0.02) + self.q_min_spin.setToolTip(q_min_tooltip) + self.q_min_spin.valueChanged.connect( + lambda _value: self.settings_changed.emit() + ) + self._add_tooltipped_row( + layout, + "Fallback q min", + self.q_min_spin, + q_min_tooltip, + ) + + self.q_max_spin = QDoubleSpinBox() + self.q_max_spin.setDecimals(4) + self.q_max_spin.setRange(0.0, 100.0) + self.q_max_spin.setSingleStep(0.05) + self.q_max_spin.setValue(1.20) + self.q_max_spin.setToolTip(q_max_tooltip) + self.q_max_spin.valueChanged.connect( + lambda _value: self.settings_changed.emit() + ) + self._add_tooltipped_row( + layout, + "Fallback q max", + self.q_max_spin, + q_max_tooltip, + ) + + self.q_points_spin = QSpinBox() + self.q_points_spin.setRange(10, 20000) + self.q_points_spin.setValue(250) + self.q_points_spin.setToolTip(q_points_tooltip) + self.q_points_spin.valueChanged.connect( + lambda _value: self.settings_changed.emit() + ) + self._add_tooltipped_row( + layout, + "Fallback q points", + self.q_points_spin, + q_points_tooltip, + ) + + def _add_tooltipped_row( + self, + layout: QFormLayout, + label_text: str, + field_widget: QWidget, + tooltip: str, + *, + buddy: QWidget | None = None, + ) -> None: + label = QLabel(label_text) + label.setToolTip(tooltip) + label.setBuddy(field_widget if buddy is None else buddy) + field_widget.setToolTip(tooltip) + layout.addRow(label, field_widget) + + def _make_dir_row(self, line_edit: QLineEdit, title: str) -> QWidget: + widget = QWidget() + row = QHBoxLayout(widget) + row.setContentsMargins(0, 0, 0, 0) + button = QPushButton("Browse") + button.setToolTip(line_edit.toolTip()) + button.clicked.connect( + lambda _checked=False: self._choose_directory(line_edit, title) + ) + row.addWidget(line_edit) + row.addWidget(button) + return widget + + def _make_file_row(self, line_edit: QLineEdit, title: str) -> QWidget: + widget = QWidget() + row = QHBoxLayout(widget) + row.setContentsMargins(0, 0, 0, 0) + button = QPushButton("Browse") + button.setToolTip(line_edit.toolTip()) + button.clicked.connect( + lambda _checked=False: self._choose_file(line_edit, title) + ) + row.addWidget(line_edit) + row.addWidget(button) + return widget + + def _choose_directory(self, line_edit: QLineEdit, title: str) -> None: + path = QFileDialog.getExistingDirectory( + self, title, line_edit.text().strip() + ) + if path: + line_edit.setText(path) + + def _choose_file(self, line_edit: QLineEdit, title: str) -> None: + path, _selected_filter = QFileDialog.getOpenFileName( + self, + title, + line_edit.text().strip(), + "Data Files (*.dat *.txt *.csv);;All Files (*)", + ) + if path: + line_edit.setText(path) + + def clusters_dir(self) -> Path | None: + text = self.clusters_dir_edit.text().strip() + return None if not text else Path(text) + + def set_clusters_dir( + self, path: Path | None, *, emit_signal: bool = True + ) -> None: + self.clusters_dir_edit.blockSignals(True) + self.clusters_dir_edit.setText("" if path is None else str(path)) + self.clusters_dir_edit.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + + def experimental_data_file(self) -> Path | None: + text = self.experimental_data_edit.text().strip() + return None if not text else Path(text) + + def set_experimental_data_file( + self, + path: Path | None, + *, + emit_signal: bool = True, + ) -> None: + self.experimental_data_edit.blockSignals(True) + self.experimental_data_edit.setText("" if path is None else str(path)) + self.experimental_data_edit.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + + def target_node_counts(self) -> tuple[int, ...]: + start = int(self.target_start_spin.value()) + stop = int(self.target_stop_spin.value()) + lower = min(start, stop) + upper = max(start, stop) + return tuple(range(lower, upper + 1)) + + def set_target_node_counts( + self, + values: tuple[int, ...], + *, + emit_signal: bool = True, + ) -> None: + if not values: + return + self.target_start_spin.blockSignals(True) + self.target_stop_spin.blockSignals(True) + self.target_start_spin.setValue(min(values)) + self.target_stop_spin.setValue(max(values)) + self.target_start_spin.blockSignals(False) + self.target_stop_spin.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + + def candidates_per_size(self) -> int: + return int(self.candidates_spin.value()) + + def set_candidates_per_size( + self, + value: int, + *, + emit_signal: bool = True, + ) -> None: + self.candidates_spin.blockSignals(True) + self.candidates_spin.setValue(max(int(value), 1)) + self.candidates_spin.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + + def prediction_population_share_threshold(self) -> float: + return float(self.share_threshold_spin.value()) + + def set_prediction_population_share_threshold( + self, + value: float, + *, + emit_signal: bool = True, + ) -> None: + self.share_threshold_spin.blockSignals(True) + self.share_threshold_spin.setValue(max(float(value), 0.0)) + self.share_threshold_spin.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + + def store_prediction_history(self) -> bool: + return bool(self.store_history_checkbox.isChecked()) + + def set_store_prediction_history( + self, + value: bool, + *, + emit_signal: bool = True, + ) -> None: + self.store_history_checkbox.blockSignals(True) + self.store_history_checkbox.setChecked(bool(value)) + self.store_history_checkbox.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + + def q_min(self) -> float: + return float(self.q_min_spin.value()) + + def q_max(self) -> float: + return float(self.q_max_spin.value()) + + def q_points(self) -> int: + return int(self.q_points_spin.value()) + + def set_q_settings( + self, + *, + q_min: float, + q_max: float, + q_points: int, + emit_signal: bool = True, + ) -> None: + self.q_min_spin.blockSignals(True) + self.q_max_spin.blockSignals(True) + self.q_points_spin.blockSignals(True) + self.q_min_spin.setValue(float(q_min)) + self.q_max_spin.setValue(float(q_max)) + self.q_points_spin.setValue(max(int(q_points), 2)) + self.q_min_spin.blockSignals(False) + self.q_max_spin.blockSignals(False) + self.q_points_spin.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + + +class ClusterDynamicsMLMainWindow(QMainWindow): + def __init__( + self, + initial_frames_dir: Path | None = None, + initial_energy_file: Path | None = None, + initial_project_dir: Path | None = None, + initial_clusters_dir: Path | None = None, + initial_experimental_data_file: Path | None = None, + ) -> None: + super().__init__() + self._project_manager = SAXSProjectManager() + self._last_summary: dict[str, object] | None = None + self._frame_format: str | None = None + self._last_result: ClusterDynamicsMLResult | None = None + self._last_dataset_file: Path | None = None + self._project_history_entries: list[_ProjectPredictionHistoryEntry] = ( + [] + ) + self._run_thread: QThread | None = None + self._run_worker: ClusterDynamicsMLWorker | None = None + self._suspend_preview_refresh = False + self._initializing = True + self._restoring_project_dataset = False + self._build_ui() + + if initial_frames_dir is not None: + self.trajectory_panel.frames_dir_edit.setText( + str(initial_frames_dir) + ) + if initial_energy_file is not None: + self.run_panel.energy_path_edit.setText(str(initial_energy_file)) + if initial_project_dir is not None: + self.dataset_panel.set_project_dir(initial_project_dir) + if initial_clusters_dir is not None: + self.prediction_panel.set_clusters_dir( + initial_clusters_dir, emit_signal=False + ) + if initial_experimental_data_file is not None: + self.prediction_panel.set_experimental_data_file( + initial_experimental_data_file, + emit_signal=False, + ) + self._sync_project_defaults() + restored = self._restore_latest_project_result(announce=False) + self._initializing = False + if not restored: + self._refresh_project_history_view() + self._refresh_selection_preview() + + def _build_ui(self) -> None: + self.setWindowTitle("SAXSShell (clusterdynamicsml)") + self.setWindowIcon(load_saxshell_icon()) + self.resize(1640, 960) + + central = QWidget() + root = QHBoxLayout(central) + root.setContentsMargins(8, 8, 8, 8) + + splitter = QSplitter(Qt.Orientation.Horizontal) + + left = QWidget() + left_layout = QVBoxLayout(left) + left_layout.setContentsMargins(0, 0, 0, 0) + left_layout.setSpacing(12) + + self.trajectory_panel = ClusterTrajectoryPanel() + self.time_panel = ClusterDynamicsTimePanel() + self.definitions_panel = ClusterDefinitionsPanel() + self.prediction_panel = ClusterDynamicsMLSettingsPanel() + self.run_panel = ClusterDynamicsRunPanel() + self.dataset_panel = ClusterDynamicsDatasetPanel() + + self.run_panel.analyze_button.setText( + "Analyze and Predict Larger Clusters" + ) + + left_layout.addWidget(self.trajectory_panel) + left_layout.addWidget(self.time_panel) + left_layout.addWidget(self.definitions_panel) + left_layout.addWidget(self.prediction_panel) + left_layout.addWidget(self.run_panel) + left_layout.addWidget(self.dataset_panel) + left_layout.addStretch(1) + + right = QWidget() + right_layout = QVBoxLayout(right) + right_layout.setContentsMargins(0, 0, 0, 0) + right_layout.setSpacing(12) + + self.dynamics_plot_panel = ClusterDynamicsPlotPanel() + right_layout.addWidget(self.dynamics_plot_panel, stretch=2) + + self.results_tabs = QTabWidget() + self.summary_tab = QWidget() + summary_layout = QVBoxLayout(self.summary_tab) + summary_layout.setContentsMargins(0, 0, 0, 0) + summary_layout.setSpacing(8) + self.summary_box = QTextEdit() + self.summary_box.setReadOnly(True) + self.summary_box.setMinimumHeight(170) + summary_layout.addWidget(self.summary_box, stretch=1) + history_group = QGroupBox("Prediction History") + history_layout = QVBoxLayout(history_group) + history_layout.setContentsMargins(8, 8, 8, 8) + history_layout.setSpacing(8) + self.history_status_label = QLabel( + "Select a project folder to compare saved prediction runs. " + "The most recent run is plotted by default." + ) + self.history_status_label.setWordWrap(True) + history_layout.addWidget(self.history_status_label) + self.history_table = self._build_table( + ( + "Loaded", + "Saved", + "Target Nodes", + "Candidates / Size", + "Share Threshold", + "Predicted", + "Max Pred Nodes", + "RMSE", + "Dataset", + ) + ) + self.history_table.setMinimumHeight(220) + self.history_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + self.history_table.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + self.history_table.setEditTriggers( + QAbstractItemView.EditTrigger.NoEditTriggers + ) + self.history_table.itemSelectionChanged.connect( + self._update_history_controls + ) + self.history_table.itemDoubleClicked.connect( + lambda _item: self._load_selected_history_entry() + ) + history_header = self.history_table.horizontalHeader() + history_header.setSectionResizeMode( + 0, + QHeaderView.ResizeMode.ResizeToContents, + ) + history_header.setSectionResizeMode( + 1, + QHeaderView.ResizeMode.ResizeToContents, + ) + history_header.setSectionResizeMode( + 8, + QHeaderView.ResizeMode.Stretch, + ) + history_layout.addWidget(self.history_table, stretch=1) + history_button_row = QHBoxLayout() + history_button_row.setContentsMargins(0, 0, 0, 0) + self.history_load_button = QPushButton("Plot Selected Prediction") + self.history_load_button.clicked.connect( + self._load_selected_history_entry + ) + history_button_row.addWidget(self.history_load_button) + self.history_refresh_button = QPushButton("Refresh History") + self.history_refresh_button.clicked.connect( + self._refresh_project_history_view + ) + history_button_row.addWidget(self.history_refresh_button) + history_button_row.addStretch(1) + history_layout.addLayout(history_button_row) + summary_layout.addWidget(history_group, stretch=1) + self.lifetime_table = self._build_table( + ( + "Type", + "Nodes", + "Size Rank", + "Candidate Rank", + "Label", + "Observed-only Weight (%)", + "Combined Weight (%)", + "Share (%)", + "Mean lifetime (fs)", + "Std lifetime (fs)", + "Completed", + "Window-truncated", + "Assoc. rate (1/ps)", + "Dissoc. rate (1/ps)", + "Occupancy (%)", + "Mean count/frame", + "Reference", + "Notes", + ) + ) + self.histogram_panel = ClusterDynamicsMLHistogramPanel() + self.saxs_panel = ClusterDynamicsMLPlotPanel() + self.observed_histogram_panel = self.histogram_panel + self.combined_histogram_panel = self.histogram_panel + self.surrogate_plot_panel = self.saxs_panel + self.results_tabs.addTab(self.summary_tab, "Summary") + self.results_tabs.addTab(self.lifetime_table, "Lifetimes") + self.results_tabs.addTab(self.histogram_panel, "Histograms") + self.results_tabs.addTab(self.saxs_panel, "SAXS") + right_layout.addWidget(self.results_tabs, stretch=2) + + splitter.addWidget(self._wrap_scroll_area(left)) + splitter.addWidget(right) + splitter.setSizes([530, 1110]) + root.addWidget(splitter) + self.setCentralWidget(central) + self.statusBar().showMessage("Ready") + + self.trajectory_panel.inspect_requested.connect( + self.inspect_frames_folder + ) + self.trajectory_panel.frames_dir_changed.connect( + self._on_frames_dir_changed + ) + self.trajectory_panel.settings_changed.connect( + self._refresh_selection_preview + ) + self.time_panel.settings_changed.connect( + self._refresh_selection_preview + ) + self.definitions_panel.settings_changed.connect( + self._refresh_selection_preview + ) + self.prediction_panel.settings_changed.connect( + self._refresh_selection_preview + ) + self.run_panel.settings_changed.connect( + self._refresh_selection_preview + ) + self.run_panel.analyze_requested.connect(self.run_analysis) + self.dataset_panel.settings_changed.connect( + self._on_project_dir_changed + ) + self.dataset_panel.save_dataset_requested.connect(self.save_dataset) + self.dataset_panel.load_dataset_requested.connect(self.load_dataset) + self.dataset_panel.save_colormap_requested.connect( + self.save_colormap_data + ) + self.dataset_panel.save_lifetime_requested.connect( + self.save_lifetime_table + ) + self.dataset_panel.save_powerpoint_requested.connect( + self.save_powerpoint_report + ) + + self.run_panel.set_selection_summary( + "Select an extracted frames folder and a smaller-cluster " + "structures folder to preview the extrapolation workflow." + ) + self.run_panel.set_log( + "Ready. clusterdynamicsml reuses the time-binned cluster analysis " + "from clusterdynamics, then fits an experimental surrogate to " + "smaller-cluster lifetimes, populations, stoichiometries, and " + "representative structures." + ) + self._set_frame_format(None) + self._update_history_controls() + + def inspect_frames_folder(self) -> None: + frames_dir = self.trajectory_panel.get_frames_dir() + if frames_dir is None: + self._show_error("No extracted frames folder selected.") + return + self._inspect_frames_dir(frames_dir, announce=True) + + def run_analysis(self) -> None: + try: + if self._run_thread is not None: + return + config = self._build_job_config() + self.run_panel.set_log( + "clusterdynamicsml request received.\n" + f"Frames folder: {config.frames_dir}\n" + f"Clusters folder: {config.clusters_dir}\n" + f"Targets: {config.target_node_counts}\n" + f"Experimental data: {config.experimental_data_file}\n" + f"Frame timestep: {config.frame_timestep_fs:.3f} fs" + ) + self.run_panel.progress_label.setText( + "Progress: running surrogate workflow" + ) + self.run_panel.progress_bar.setRange(0, 0) + self.dynamics_plot_panel.set_result(None) + self.histogram_panel.set_result(None) + self.saxs_panel.set_result(None) + self.summary_box.clear() + self.lifetime_table.setRowCount(0) + self.statusBar().showMessage( + "Analyzing and predicting larger clusters..." + ) + self._start_worker(config) + except Exception as exc: + self._handle_error("clusterdynamicsml failed", str(exc)) + + def _build_job_config(self) -> ClusterDynamicsMLJobConfig: + frames_dir = self.trajectory_panel.get_frames_dir() + if frames_dir is None: + raise ValueError("No extracted frames folder selected.") + atom_type_definitions = self.definitions_panel.atom_type_definitions() + if not atom_type_definitions: + raise ValueError( + "Add at least one atom-type definition before running the analysis." + ) + if not ( + atom_type_definitions.get("node") + or atom_type_definitions.get("linker") + ): + raise ValueError("Define at least one node or linker atom type.") + pair_cutoff_definitions = ( + self.definitions_panel.pair_cutoff_definitions() + ) + default_cutoff = self.definitions_panel.default_cutoff() + if not pair_cutoff_definitions and default_cutoff is None: + raise ValueError( + "Add at least one pair-cutoff definition or a default cutoff." + ) + manual_box_dimensions = self.definitions_panel.box_dimensions() + resolved_box_dimensions = manual_box_dimensions + use_pbc = self.definitions_panel.use_pbc() + if use_pbc and resolved_box_dimensions is None: + resolved_box_dimensions = self._detected_box_dimensions() + if resolved_box_dimensions is None: + raise ValueError( + "Periodic boundary conditions are enabled, but no box " + "dimensions are available." + ) + return ClusterDynamicsMLJobConfig( + frames_dir=frames_dir, + clusters_dir=self.prediction_panel.clusters_dir(), + project_dir=self.dataset_panel.project_dir(), + experimental_data_file=self.prediction_panel.experimental_data_file(), + energy_file=self.run_panel.energy_file(), + atom_type_definitions=atom_type_definitions, + pair_cutoff_definitions=pair_cutoff_definitions, + box_dimensions=resolved_box_dimensions, + use_pbc=use_pbc, + default_cutoff=default_cutoff, + shell_levels=self.definitions_panel.shell_growth_levels(), + shared_shells=self.definitions_panel.shared_shells(), + include_shell_atoms_in_stoichiometry=( + self.definitions_panel.include_shell_atoms_in_stoichiometry() + ), + search_mode=self.definitions_panel.search_mode(), + folder_start_time_fs=self.time_panel.folder_start_time_fs(), + first_frame_time_fs=self.time_panel.first_frame_time_fs(), + frame_timestep_fs=self.time_panel.frame_timestep_fs(), + frames_per_colormap_timestep=( + self.time_panel.frames_per_colormap_timestep() + ), + analysis_start_fs=self.time_panel.analysis_start_fs(), + analysis_stop_fs=self.time_panel.analysis_stop_fs(), + target_node_counts=self.prediction_panel.target_node_counts(), + candidates_per_size=self.prediction_panel.candidates_per_size(), + prediction_population_share_threshold=( + self.prediction_panel.prediction_population_share_threshold() + ), + q_min=self.prediction_panel.q_min(), + q_max=self.prediction_panel.q_max(), + q_points=self.prediction_panel.q_points(), + ) + + def _start_worker(self, config: ClusterDynamicsMLJobConfig) -> None: + self._run_thread = QThread(self) + self._run_worker = ClusterDynamicsMLWorker(config) + self._run_worker.moveToThread(self._run_thread) + + self._run_thread.started.connect(self._run_worker.run) + self._run_worker.progress.connect(self.run_panel.append_log) + self._run_worker.finished.connect(self._on_run_finished) + self._run_worker.failed.connect(self._on_run_failed) + self._run_worker.finished.connect(self._run_thread.quit) + self._run_worker.failed.connect(self._run_thread.quit) + self._run_thread.finished.connect(self._cleanup_run_thread) + self._run_thread.finished.connect(self._run_thread.deleteLater) + self._run_thread.finished.connect(self._run_worker.deleteLater) + self._run_thread.start() + + def _on_run_finished(self, result: ClusterDynamicsMLResult) -> None: + self._last_result = result + self.dynamics_plot_panel.set_result(result.dynamics_result) + self.histogram_panel.set_result(result) + self.saxs_panel.set_result(result) + self.run_panel.progress_bar.setRange( + 0, max(result.dynamics_result.analyzed_frames, 1) + ) + self.run_panel.progress_bar.setValue( + result.dynamics_result.analyzed_frames + ) + self.run_panel.progress_label.setText( + f"Progress: completed {result.dynamics_result.analyzed_frames} frames" + ) + self.run_panel.append_log( + "clusterdynamicsml complete.\n" + f"Observed node counts: {result.preview.observed_node_counts}\n" + f"Predicted candidates: {len(result.predictions)}\n" + f"Max predicted node count: {result.max_predicted_node_count}" + ) + self._populate_summary_box(result) + self._populate_lifetime_table(result) + autosaved = self._autosave_project_result(result) + if autosaved is not None: + self._refresh_project_history_view( + select_dataset=autosaved.dataset_file + ) + self.run_panel.append_log( + "Cached clusterdynamicsml result bundle in project folder.\n" + f"Dataset: {autosaved.dataset_file}\n" + f"Files written: {len(autosaved.written_files)}" + ) + self.statusBar().showMessage( + "clusterdynamicsml analysis complete and cached in project" + ) + else: + self._refresh_project_history_view( + select_dataset=self._last_dataset_file + ) + self.statusBar().showMessage("clusterdynamicsml analysis complete") + + def _on_run_failed(self, message: str) -> None: + self.run_panel.progress_bar.setRange(0, 1) + self.run_panel.progress_bar.setValue(0) + self.run_panel.progress_label.setText("Progress: failed") + self.statusBar().showMessage("clusterdynamicsml failed") + self._handle_error("clusterdynamicsml failed", message) + + def _cleanup_run_thread(self) -> None: + self._run_worker = None + self._run_thread = None + + def save_dataset(self) -> None: + if self._last_result is None: + self._show_error( + "Run an analysis or load a saved result before exporting." + ) + return + default_path = self._default_dataset_file() + path, _selected_filter = QFileDialog.getSaveFileName( + self, + "Save clusterdynamicsml dataset", + str(default_path), + "JSON Files (*.json);;All Files (*)", + ) + if not path: + return + saved = save_cluster_dynamicsai_dataset( + self._last_result, + path, + analysis_settings=self._analysis_settings_payload(), + ) + self._last_dataset_file = saved.dataset_file + self.run_panel.append_log( + "Saved clusterdynamicsml dataset to " + f"{saved.dataset_file}\n" + f"Wrote {len(saved.written_files)} file(s)." + ) + self.statusBar().showMessage( + f"Saved clusterdynamicsml dataset to {saved.dataset_file}" + ) + + def load_dataset(self) -> None: + default_path = ( + self._last_dataset_file + if self._last_dataset_file is not None + else self._default_dataset_file() + ) + path, _selected_filter = QFileDialog.getOpenFileName( + self, + "Load clusterdynamicsml dataset", + str(default_path), + "JSON Files (*.json);;All Files (*)", + ) + if not path: + return + loaded = load_cluster_dynamicsai_dataset(path) + self._apply_loaded_dataset( + loaded, + announce=True, + action_label="Loaded", + ) + + def save_colormap_data(self) -> None: + if self._last_result is None: + self._show_error( + "Run an analysis or load a saved result before exporting." + ) + return + default_path = self._default_export_file("cluster_distribution") + path, _selected_filter = QFileDialog.getSaveFileName( + self, + "Save colormap data", + str(default_path), + "CSV Files (*.csv);;All Files (*)", + ) + if not path: + return + display_mode = ( + self.dynamics_plot_panel.display_mode_combo.currentData() + ) + time_unit = self.dynamics_plot_panel.time_unit_combo.currentData() + saved_path = export_cluster_dynamics_colormap_csv( + self._last_result.dynamics_result, + path, + display_mode=( + "fraction" if display_mode is None else str(display_mode) + ), + time_unit="fs" if time_unit is None else str(time_unit), + ) + dynamics_result = self._last_result.dynamics_result + row_count = ( + len(dynamics_result.cluster_labels) * dynamics_result.bin_count + ) + self.run_panel.append_log( + "Saved clusterdynamicsml colormap data to " + f"{saved_path}\n" + f"Rows written: {row_count}" + ) + self.statusBar().showMessage(f"Saved colormap data to {saved_path}") + + def save_lifetime_table(self) -> None: + if self._last_result is None: + self._show_error( + "Run an analysis or load a saved result before exporting." + ) + return + default_path = self._default_export_file("lifetime") + path, _selected_filter = QFileDialog.getSaveFileName( + self, + "Save lifetime table", + str(default_path), + "CSV Files (*.csv);;All Files (*)", + ) + if not path: + return + if ( + self.lifetime_table.rowCount() == 0 + and self._last_result is not None + ): + self._populate_lifetime_table(self._last_result) + saved_path = _write_table_widget_csv(self.lifetime_table, Path(path)) + row_count = self.lifetime_table.rowCount() + self.run_panel.append_log( + "Saved clusterdynamicsml lifetime table to " + f"{saved_path}\n" + f"Rows written: {row_count}" + ) + self.statusBar().showMessage(f"Saved lifetime table to {saved_path}") + + def save_powerpoint_report(self) -> None: + if self._last_result is None: + self._show_error( + "Run an analysis or load a saved result before exporting." + ) + return + + self.dynamics_plot_panel.set_result(self._last_result.dynamics_result) + self.saxs_panel.set_result(self._last_result) + selection_summary = self.run_panel.selection_box.toPlainText().strip() + if not selection_summary: + selection_summary = self._format_preview_text( + self._last_result.preview + ) + summary_text = self.summary_box.toPlainText().strip() + if not summary_text: + self._populate_summary_box(self._last_result) + summary_text = self.summary_box.toPlainText().strip() + + default_path = self._default_powerpoint_report_file() + path, _selected_filter = QFileDialog.getSaveFileName( + self, + "Save clusterdynamicsml PowerPoint report", + str(default_path), + "PowerPoint Files (*.pptx);;All Files (*)", + ) + if not path: + return + + self.run_panel.progress_label.setText( + "Progress: generating PowerPoint report" + ) + self.run_panel.progress_bar.setRange(0, 1) + self.run_panel.progress_bar.setValue(0) + self.run_panel.progress_bar.setFormat("%v / %m steps") + try: + export_result = export_cluster_dynamicsai_report_pptx( + result=self._last_result, + selection_summary=selection_summary, + result_summary=summary_text, + dynamics_figure=self.dynamics_plot_panel.figure, + surrogate_figure=self.saxs_panel.figure, + output_path=path, + settings=self._powerpoint_export_settings(), + project_dir=self.dataset_panel.project_dir(), + frames_dir=self.trajectory_panel.get_frames_dir(), + progress_callback=self._on_powerpoint_report_progress, + ) + except Exception as exc: + self.run_panel.progress_bar.setRange(0, 1) + self.run_panel.progress_bar.setValue(0) + self.run_panel.progress_bar.setFormat("%v / %m steps") + self.run_panel.progress_label.setText( + "Progress: PowerPoint export failed" + ) + self._handle_error( + "clusterdynamicsml PowerPoint export failed", str(exc) + ) + return + + self.run_panel.progress_label.setText( + "Progress: PowerPoint report saved" + ) + self.run_panel.progress_bar.setValue( + self.run_panel.progress_bar.maximum() + ) + self.run_panel.progress_bar.setFormat("%v / %m steps") + if export_result.appended_to_existing: + self.run_panel.append_log( + "Appended clusterdynamicsml report slides to " + f"{export_result.report_path}\n" + f"Slides added: {export_result.added_slide_count}" + ) + else: + self.run_panel.append_log( + "Saved clusterdynamicsml PowerPoint report to " + f"{export_result.report_path}\n" + f"Slides written: {export_result.added_slide_count}" + ) + self.statusBar().showMessage( + f"Saved PowerPoint report to {export_result.report_path}" + ) + + def _inspect_frames_dir(self, frames_dir: Path, *, announce: bool) -> None: + self._last_summary = None + try: + analyzer = ExtractedFrameFolderClusterAnalyzer( + frames_dir=frames_dir, + atom_type_definitions={}, + pair_cutoffs_def={}, + ) + self._last_summary = analyzer.inspect() + self._sync_box_dimensions_from_summary(self._last_summary) + self._set_frame_format(self._last_summary.get("frame_format")) + self.trajectory_panel.set_summary(self._last_summary) + if announce: + self.run_panel.append_log( + "Inspection complete. " + f"Detected {self._last_summary['n_frames']} extracted frame(s)." + ) + self.statusBar().showMessage("Inspection complete") + except ValueError as exc: + self._sync_box_dimensions_from_summary(None) + frame_format, detail = self._detect_frame_format(frames_dir) + self._set_frame_format(frame_format) + self.trajectory_panel.set_summary_text(str(exc)) + if detail is not None: + self.trajectory_panel.set_frame_mode(None, detail=detail) + if announce: + self._handle_error("Frames-folder inspection failed", str(exc)) + self._refresh_selection_preview() + + def _on_frames_dir_changed(self, frames_dir: Path | None) -> None: + if self._suspend_preview_refresh: + return + self._last_summary = None + self._last_result = None + self.dynamics_plot_panel.set_result(None) + self.histogram_panel.set_result(None) + self.saxs_panel.set_result(None) + self.summary_box.clear() + self.lifetime_table.setRowCount(0) + self.time_panel.set_folder_start_time_fs(None, emit_signal=False) + if frames_dir is None: + self._sync_box_dimensions_from_summary(None) + self._set_frame_format(None) + self.trajectory_panel.set_summary_text("") + self._refresh_selection_preview() + return + self._inspect_frames_dir(frames_dir, announce=False) + + def _on_project_dir_changed(self) -> None: + if self._restoring_project_dataset: + return + self._sync_project_defaults() + if not self._initializing: + restored = self._restore_latest_project_result(announce=True) + if not restored: + self._refresh_project_history_view() + self._refresh_selection_preview() + + def _sync_project_defaults(self) -> None: + if self._suspend_preview_refresh: + return + project_dir = self.dataset_panel.project_dir() + if project_dir is None: + return + project_file = build_project_paths(project_dir).project_file + if not project_file.is_file(): + return + try: + settings = self._project_manager.load_project(project_dir) + except Exception: + return + if self.prediction_panel.clusters_dir() is None: + self.prediction_panel.set_clusters_dir( + settings.resolved_clusters_dir, + emit_signal=False, + ) + if self.prediction_panel.experimental_data_file() is None: + self.prediction_panel.set_experimental_data_file( + settings.resolved_experimental_data_path, + emit_signal=False, + ) + + def _refresh_selection_preview(self) -> None: + if self._suspend_preview_refresh: + return + frames_dir = self.trajectory_panel.get_frames_dir() + if frames_dir is None: + self.run_panel.set_selection_summary( + "Select an extracted frames folder to preview the " + "clusterdynamicsml workflow." + ) + return + try: + preview = self._build_preview_workflow().preview_selection() + if ( + self.time_panel.folder_start_time_fs() is None + and preview.dynamics_preview.folder_start_time_fs is not None + and preview.dynamics_preview.folder_start_time_source + != "manual field" + ): + self.time_panel.set_folder_start_time_fs( + preview.dynamics_preview.folder_start_time_fs, + emit_signal=False, + ) + text = self._format_preview_text(preview) + except Exception as exc: + text = ( + "Adjust the current settings to preview the extrapolation " + f"workflow.\nValidation warning: {exc}" + ) + self.run_panel.set_selection_summary(text) + + def _build_preview_workflow(self) -> ClusterDynamicsMLWorkflow: + frames_dir = self.trajectory_panel.get_frames_dir() + if frames_dir is None: + raise ValueError("No extracted frames folder selected.") + manual_box_dimensions = self.definitions_panel.box_dimensions() + resolved_box_dimensions = manual_box_dimensions + if ( + self.definitions_panel.use_pbc() + and resolved_box_dimensions is None + ): + resolved_box_dimensions = self._detected_box_dimensions() + return ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=self.definitions_panel.atom_type_definitions(), + pair_cutoff_definitions=self.definitions_panel.pair_cutoff_definitions(), + clusters_dir=self.prediction_panel.clusters_dir(), + project_dir=self.dataset_panel.project_dir(), + experimental_data_file=self.prediction_panel.experimental_data_file(), + box_dimensions=resolved_box_dimensions, + use_pbc=self.definitions_panel.use_pbc(), + default_cutoff=self.definitions_panel.default_cutoff(), + shell_levels=self.definitions_panel.shell_growth_levels(), + shared_shells=self.definitions_panel.shared_shells(), + include_shell_atoms_in_stoichiometry=( + self.definitions_panel.include_shell_atoms_in_stoichiometry() + ), + search_mode=self.definitions_panel.search_mode(), + folder_start_time_fs=self.time_panel.folder_start_time_fs(), + first_frame_time_fs=self.time_panel.first_frame_time_fs(), + frame_timestep_fs=self.time_panel.frame_timestep_fs(), + frames_per_colormap_timestep=( + self.time_panel.frames_per_colormap_timestep() + ), + analysis_start_fs=self.time_panel.analysis_start_fs(), + analysis_stop_fs=self.time_panel.analysis_stop_fs(), + energy_file=self.run_panel.energy_file(), + target_node_counts=self.prediction_panel.target_node_counts(), + candidates_per_size=self.prediction_panel.candidates_per_size(), + prediction_population_share_threshold=( + self.prediction_panel.prediction_population_share_threshold() + ), + q_min=self.prediction_panel.q_min(), + q_max=self.prediction_panel.q_max(), + q_points=self.prediction_panel.q_points(), + ) + + def _format_preview_text(self, preview) -> str: + dynamics_preview = preview.dynamics_preview + lines = [ + f"Mode: {frame_folder_label(dynamics_preview.frame_format)}", + f"PBC: {'on' if dynamics_preview.use_pbc else 'off'}", + "Search mode: " + f"{format_search_mode_label(self.definitions_panel.search_mode())}", + f"Frames in folder: {dynamics_preview.total_frames}", + f"Frames selected: {dynamics_preview.selected_frames}", + f"Frame timestep: {dynamics_preview.frame_timestep_fs:.3f} fs", + "Frames per colormap timestep: " + f"{dynamics_preview.frames_per_colormap_timestep}", + f"Colormap timestep: {dynamics_preview.colormap_timestep_fs:.3f} fs", + f"Time bins: {dynamics_preview.bin_count}", + f"Smaller-cluster labels with structures: {preview.structure_label_count}", + f"Structure files discovered: {preview.total_structure_files}", + f"Observed node counts: {preview.observed_node_counts or ('n/a',)}", + f"Target node counts: {preview.target_node_counts or ('n/a',)}", + "Stoichiometry bins: " + + ( + "solute + shell atoms" + if self.definitions_panel.include_shell_atoms_in_stoichiometry() + else "solute only" + ), + "Resolved box dimensions: " + f"{format_box_dimensions(dynamics_preview.resolved_box_dimensions)}", + ( + "Clusters folder: " + + ( + "not set" + if preview.clusters_dir is None + else str(preview.clusters_dir) + ) + ), + ( + "Experimental data: " + + ( + "not set" + if preview.experimental_data_path is None + else str(preview.experimental_data_path) + ) + ), + ] + if dynamics_preview.first_selected_frame is not None: + lines.append( + "Frame file range: " + f"{dynamics_preview.first_selected_frame} to " + f"{dynamics_preview.last_selected_frame}" + ) + if dynamics_preview.time_warnings: + lines.extend( + f"Warning: {message}" + for message in dynamics_preview.time_warnings + ) + if preview.warnings: + lines.extend(f"Warning: {message}" for message in preview.warnings) + return "\n".join(lines) + + def _populate_summary_box(self, result: ClusterDynamicsMLResult) -> None: + lines = [ + f"Frames analyzed: {result.dynamics_result.analyzed_frames}", + f"Time bins: {result.dynamics_result.bin_count}", + f"Observed node counts: {result.preview.observed_node_counts}", + f"Target node counts: {result.preview.target_node_counts}", + f"Predicted candidates: {len(result.predictions)}", + f"Max observed node count: {result.max_observed_node_count}", + ( + "Max predicted node count above threshold: " + f"{result.max_predicted_node_count}" + ), + ( + "Prediction-candidate share threshold: " + f"{result.prediction_population_share_threshold:.3f}" + ), + ( + "Store prediction history in project: " + f"{'on' if self.prediction_panel.store_prediction_history() else 'off'}" + ), + ( + "Clusters folder: " + + ( + "n/a" + if result.preview.clusters_dir is None + else str(result.preview.clusters_dir) + ) + ), + ( + "Experimental data: " + + ( + "n/a" + if result.preview.experimental_data_path is None + else str(result.preview.experimental_data_path) + ) + ), + ] + if result.saxs_comparison is not None: + lines.append( + f"SAXS components in mixture: " + f"{len(result.saxs_comparison.component_weights)}" + ) + if result.saxs_comparison.component_output_dir is not None: + lines.append( + "SAXS component files: " + f"{result.saxs_comparison.component_output_dir}" + ) + if result.saxs_comparison.surrogate_structure_dir is not None: + lines.append( + "Surrogate XYZ files: " + f"{result.saxs_comparison.surrogate_structure_dir}" + ) + if result.saxs_comparison.rmse is not None: + lines.append( + f"Cluster-only surrogate SAXS RMSE: " + f"{result.saxs_comparison.rmse:.6g}" + ) + self.summary_box.setPlainText("\n".join(lines)) + + def _populate_lifetime_table( + self, result: ClusterDynamicsMLResult + ) -> None: + self.lifetime_table.setSortingEnabled(False) + lifetime_rows = _combined_model_weight_rows(result) + self.lifetime_table.setRowCount(len(lifetime_rows)) + for row, entry in enumerate(lifetime_rows): + values = ( + str(entry["type"]), + str(entry["nodes"]), + str(entry["size_rank"]), + str(entry["candidate_rank"]), + str(entry["label"]), + _format_optional_percent( + entry["observed_only_normalized_weight"] + ), + _format_optional_percent(entry["normalized_weight"]), + _format_optional_percent(entry["predicted_population_share"]), + _format_optional_float(entry["mean_lifetime_fs"]), + _format_optional_float(entry["std_lifetime_fs"]), + _format_optional_int(entry["completed_lifetime_count"]), + _format_optional_int(entry["window_truncated_lifetime_count"]), + f"{float(entry['association_rate_per_ps']):.3f}", + f"{float(entry['dissociation_rate_per_ps']):.3f}", + f"{float(entry['occupancy_fraction']) * 100.0:.1f}", + f"{float(entry['mean_count_per_frame']):.3f}", + str(entry["reference"]), + str(entry["notes"]), + ) + for column, value in enumerate(values): + self.lifetime_table.setItem( + row, + column, + QTableWidgetItem(value), + ) + self.lifetime_table.resizeColumnsToContents() + self.lifetime_table.setSortingEnabled(True) + + def _refresh_project_history_view( + self, + *, + select_dataset: Path | None = None, + ) -> None: + project_dir = self.dataset_panel.project_dir() + self._project_history_entries = ( + self._project_history_entries_for_project(project_dir) + ) + selected_dataset = ( + None + if select_dataset is None + else Path(select_dataset).expanduser().resolve() + ) + if selected_dataset is None and self._last_dataset_file is not None: + selected_dataset = self._last_dataset_file.resolve() + + self.history_table.setSortingEnabled(False) + self.history_table.setRowCount(len(self._project_history_entries)) + selected_row: int | None = None + for row, entry in enumerate(self._project_history_entries): + is_loaded = ( + self._last_dataset_file is not None + and self._last_dataset_file.resolve() == entry.dataset_file + ) + values = ( + "Yes" if is_loaded else "", + entry.saved_label, + _format_int_sequence(entry.target_node_counts), + str(entry.candidates_per_size), + f"{entry.prediction_population_share_threshold:.3f}", + str(entry.prediction_count), + _format_optional_int(entry.max_predicted_node_count), + _format_optional_float(entry.rmse), + entry.dataset_file.name, + ) + for column, value in enumerate(values): + item = QTableWidgetItem(value) + item.setData(Qt.ItemDataRole.UserRole, str(entry.dataset_file)) + item.setToolTip(str(entry.dataset_file)) + self.history_table.setItem(row, column, item) + if ( + selected_dataset is not None + and entry.dataset_file == selected_dataset + ): + selected_row = row + + self.history_table.resizeColumnsToContents() + self.history_table.setSortingEnabled(False) + self.history_table.clearSelection() + if self._project_history_entries: + target_row = 0 if selected_row is None else selected_row + self.history_table.selectRow(target_row) + self._update_history_controls() + + def _project_history_entries_for_project( + self, + project_dir: Path | None, + ) -> list[_ProjectPredictionHistoryEntry]: + saved_results_dir = self._project_saved_results_dir(project_dir) + if saved_results_dir is None or not saved_results_dir.is_dir(): + return [] + entries: list[_ProjectPredictionHistoryEntry] = [] + for dataset_file in saved_results_dir.rglob( + "*_clusterdynamicsml.json" + ): + if not dataset_file.is_file(): + continue + entry = _read_project_history_entry(dataset_file) + if entry is not None: + entries.append(entry) + entries.sort( + key=lambda entry: ( + -float(entry.modified_time), + entry.dataset_file.name.lower(), + ) + ) + return entries + + def _selected_history_dataset_file(self) -> Path | None: + selected_ranges = self.history_table.selectedRanges() + if not selected_ranges: + return None + return self._history_dataset_file_for_row(selected_ranges[0].topRow()) + + def _history_dataset_file_for_row(self, row: int) -> Path | None: + if row < 0 or row >= self.history_table.rowCount(): + return None + item = self.history_table.item(row, 0) + if item is None: + return None + value = item.data(Qt.ItemDataRole.UserRole) + return None if value in {None, ""} else Path(str(value)) + + def _update_history_controls(self) -> None: + selected_dataset = self._selected_history_dataset_file() + has_history = bool(self._project_history_entries) + self.history_table.setEnabled(has_history) + self.history_load_button.setEnabled(selected_dataset is not None) + if not has_history: + project_dir = self.dataset_panel.project_dir() + if project_dir is None: + self.history_status_label.setText( + "Select a project folder to compare saved prediction runs. " + "The most recent run is plotted by default." + ) + else: + self.history_status_label.setText( + "No cached prediction history was found for the current project." + ) + return + if selected_dataset is None: + self.history_status_label.setText( + "Select a saved prediction run to compare its parameters and plots." + ) + return + if ( + self._last_dataset_file is not None + and self._last_dataset_file.resolve() == selected_dataset.resolve() + ): + self.history_status_label.setText( + "The selected history entry is currently plotted." + ) + return + self.history_status_label.setText( + "Select a row and click Plot Selected Prediction to compare it." + ) + + def _load_selected_history_entry(self) -> None: + dataset_file = self._selected_history_dataset_file() + if dataset_file is None: + return + try: + loaded = load_cluster_dynamicsai_dataset(dataset_file) + except Exception as exc: + self._handle_error( + "clusterdynamicsml history load failed", + f"Could not load {dataset_file}: {exc}", + ) + return + self._apply_loaded_dataset( + loaded, + announce=True, + action_label="Loaded history entry", + ) + + @staticmethod + def _build_table(headers: tuple[str, ...]) -> QTableWidget: + table = QTableWidget(0, len(headers)) + table.setHorizontalHeaderLabels(list(headers)) + table.verticalHeader().setVisible(False) + table.setAlternatingRowColors(True) + header = table.horizontalHeader() + header.setSectionResizeMode(QHeaderView.ResizeMode.ResizeToContents) + if headers: + header.setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) + return table + + def _default_dataset_dir(self) -> Path: + project_dir = self.dataset_panel.project_dir() + if project_dir is not None: + paths = build_project_paths(project_dir) + target_dir = paths.exported_data_dir / "clusterdynamicsml" + target_dir.mkdir(parents=True, exist_ok=True) + return target_dir + frames_dir = self.trajectory_panel.get_frames_dir() + if frames_dir is not None: + return frames_dir.parent + return Path.cwd() + + def _default_dataset_file(self) -> Path: + if self._last_dataset_file is not None: + return self._last_dataset_file + frames_dir = self.trajectory_panel.get_frames_dir() + folder_label = "cluster_dynamics_ml" + if frames_dir is not None: + folder_label = frames_dir.name or folder_label + return ( + self._default_dataset_dir() + / f"{folder_label}_clusterdynamicsml.json" + ) + + def _default_export_file(self, suffix_label: str) -> Path: + dataset_file = self._default_dataset_file() + return dataset_file.with_name( + f"{dataset_file.stem}_{suffix_label}.csv" + ) + + def _default_powerpoint_report_file(self) -> Path: + frames_dir = self.trajectory_panel.get_frames_dir() + fallback_label = "cluster_dynamics_ml_report" + if frames_dir is not None: + fallback_label = ( + f"{frames_dir.name or 'cluster_dynamics_ml'}_report" + ) + return default_powerpoint_report_path( + project_dir=self.dataset_panel.project_dir(), + fallback_dir=self._default_dataset_dir(), + fallback_stem=fallback_label, + ) + + def _project_saved_results_dir( + self, + project_dir: Path | None = None, + ) -> Path | None: + resolved_project_dir = ( + self.dataset_panel.project_dir() + if project_dir is None + else Path(project_dir).expanduser().resolve() + ) + if resolved_project_dir is None: + return None + saved_results_dir = ( + build_project_paths(resolved_project_dir).exported_data_dir + / "clusterdynamicsml" + / "saved_results" + ) + saved_results_dir.mkdir(parents=True, exist_ok=True) + return saved_results_dir + + def _latest_project_dataset_file( + self, + project_dir: Path | None = None, + ) -> Path | None: + saved_results_dir = self._project_saved_results_dir(project_dir) + if saved_results_dir is None or not saved_results_dir.is_dir(): + return None + candidates: list[Path] = [] + for candidate in saved_results_dir.rglob("*_clusterdynamicsml.json"): + if candidate.is_file(): + candidates.append(candidate) + if not candidates: + return None + return max( + candidates, + key=lambda path: (path.stat().st_mtime, path.name.lower()), + ) + + def _autosave_project_result( + self, + result: ClusterDynamicsMLResult, + ) -> SavedClusterDynamicsMLDataset | None: + project_dir = self.dataset_panel.project_dir() + if project_dir is None: + return None + saved_results_dir = self._project_saved_results_dir(project_dir) + if saved_results_dir is None: + return None + frames_dir = self.trajectory_panel.get_frames_dir() + stem_label = _safe_filename_stem( + project_dir.name + if frames_dir is None + else (frames_dir.name or project_dir.name) + ) + if self.prediction_panel.store_prediction_history(): + bundle_dir = saved_results_dir / ( + datetime.now().strftime("%Y%m%d_%H%M%S_%f") + f"_{stem_label}" + ) + else: + bundle_dir = saved_results_dir / f"latest_{stem_label}" + if bundle_dir.exists(): + shutil.rmtree(bundle_dir) + bundle_dir.mkdir(parents=True, exist_ok=True) + dataset_path = bundle_dir / f"{stem_label}_clusterdynamicsml.json" + saved = save_cluster_dynamicsai_dataset( + result, + dataset_path, + analysis_settings=self._analysis_settings_payload(), + ) + + selection_summary = self.run_panel.selection_box.toPlainText().strip() + if not selection_summary: + selection_summary = self._format_preview_text(result.preview) + summary_text = self.summary_box.toPlainText().strip() + if not summary_text: + self._populate_summary_box(result) + summary_text = self.summary_box.toPlainText().strip() + extra_files = [ + ( + saved.dataset_file.with_name( + f"{saved.dataset_file.stem}_selection_preview.txt" + ), + selection_summary, + ), + ( + saved.dataset_file.with_name( + f"{saved.dataset_file.stem}_summary.txt" + ), + summary_text, + ), + ] + written_files = list(saved.written_files) + for output_path, content in extra_files: + output_path.write_text(content.strip() + "\n", encoding="utf-8") + written_files.append(output_path) + + autosaved = SavedClusterDynamicsMLDataset( + dataset_file=saved.dataset_file, + written_files=tuple(written_files), + ) + self._last_dataset_file = autosaved.dataset_file + return autosaved + + def _restore_latest_project_result(self, *, announce: bool) -> bool: + project_dir = self.dataset_panel.project_dir() + if project_dir is None: + return False + dataset_file = self._latest_project_dataset_file(project_dir) + if dataset_file is None: + return False + try: + loaded = load_cluster_dynamicsai_dataset(dataset_file) + except Exception as exc: + if announce: + self.run_panel.append_log( + "Could not restore cached clusterdynamicsml result from " + f"{dataset_file}: {exc}" + ) + return False + self._apply_loaded_dataset( + loaded, + announce=announce, + action_label="Restored", + ) + return True + + def _apply_loaded_dataset( + self, + loaded: LoadedClusterDynamicsMLDataset, + *, + announce: bool, + action_label: str, + ) -> None: + self._restoring_project_dataset = True + try: + self._last_dataset_file = loaded.dataset_file + self._apply_analysis_settings(loaded.analysis_settings) + self._last_result = loaded.result + self.dynamics_plot_panel.set_result(loaded.result.dynamics_result) + self.histogram_panel.set_result(loaded.result) + self.saxs_panel.set_result(loaded.result) + self.run_panel.set_selection_summary( + self._format_preview_text(loaded.result.preview) + ) + self._populate_summary_box(loaded.result) + self._populate_lifetime_table(loaded.result) + analyzed_frames = max( + loaded.result.dynamics_result.analyzed_frames, 1 + ) + self.run_panel.progress_bar.setRange(0, analyzed_frames) + self.run_panel.progress_bar.setValue(analyzed_frames) + self.run_panel.progress_label.setText( + f"Progress: loaded saved result ({analyzed_frames} frames)" + ) + finally: + self._restoring_project_dataset = False + self._refresh_project_history_view(select_dataset=loaded.dataset_file) + if announce: + self.run_panel.append_log( + f"{action_label} clusterdynamicsml dataset from " + f"{loaded.dataset_file}\n" + f"Predicted candidates: {len(loaded.result.predictions)}" + ) + self.statusBar().showMessage( + f"{action_label} clusterdynamicsml dataset from {loaded.dataset_file}" + ) + + def _powerpoint_export_settings(self) -> PowerPointExportSettings: + project_dir = self.dataset_panel.project_dir() + if project_dir is None: + return PowerPointExportSettings() + project_file = build_project_paths(project_dir).project_file + if not project_file.is_file(): + return PowerPointExportSettings() + try: + settings = self._project_manager.load_project(project_dir) + except Exception: + return PowerPointExportSettings() + return PowerPointExportSettings.from_dict( + settings.powerpoint_export_settings.to_dict() + ) + + def _on_powerpoint_report_progress( + self, + processed: int, + total: int, + message: str, + ) -> None: + total_steps = max(int(total), 1) + processed_steps = max(0, min(int(processed), total_steps)) + self.run_panel.progress_label.setText( + f"Progress: PowerPoint report {processed_steps}/{total_steps}" + ) + self.run_panel.progress_bar.setRange(0, total_steps) + self.run_panel.progress_bar.setValue(processed_steps) + self.run_panel.progress_bar.setFormat("%v / %m steps") + self.statusBar().showMessage(message) + QApplication.processEvents() + + def _analysis_settings_payload(self) -> dict[str, object]: + frames_dir = self.trajectory_panel.get_frames_dir() + energy_file = self.run_panel.energy_file() + project_dir = self.dataset_panel.project_dir() + return { + "frames_dir": None if frames_dir is None else str(frames_dir), + "clusters_dir": ( + None + if self.prediction_panel.clusters_dir() is None + else str(self.prediction_panel.clusters_dir()) + ), + "experimental_data_file": ( + None + if self.prediction_panel.experimental_data_file() is None + else str(self.prediction_panel.experimental_data_file()) + ), + "energy_file": None if energy_file is None else str(energy_file), + "project_dir": None if project_dir is None else str(project_dir), + "atom_type_definitions": { + atom_type: [ + [element, residue] for element, residue in criteria + ] + for atom_type, criteria in self.definitions_panel.atom_type_definitions().items() + }, + "pair_cutoff_definitions": [ + { + "atom1": atom1, + "atom2": atom2, + "shell_cutoffs": { + str(level): float(cutoff) + for level, cutoff in shell_cutoffs.items() + }, + } + for (atom1, atom2), shell_cutoffs in sorted( + self.definitions_panel.pair_cutoff_definitions().items() + ) + ], + "box_dimensions": ( + None + if self.definitions_panel.box_dimensions() is None + else list(self.definitions_panel.box_dimensions()) + ), + "use_pbc": self.definitions_panel.use_pbc(), + "default_cutoff": self.definitions_panel.default_cutoff(), + "shell_levels": list(self.definitions_panel.shell_growth_levels()), + "shared_shells": self.definitions_panel.shared_shells(), + "include_shell_atoms_in_stoichiometry": ( + self.definitions_panel.include_shell_atoms_in_stoichiometry() + ), + "search_mode": self.definitions_panel.search_mode(), + "folder_start_time_fs": self.time_panel.folder_start_time_fs(), + "first_frame_time_fs": self.time_panel.first_frame_time_fs(), + "frame_timestep_fs": self.time_panel.frame_timestep_fs(), + "frames_per_colormap_timestep": ( + self.time_panel.frames_per_colormap_timestep() + ), + "analysis_start_fs": self.time_panel.analysis_start_fs(), + "analysis_stop_fs": self.time_panel.analysis_stop_fs(), + "target_node_counts": list( + self.prediction_panel.target_node_counts() + ), + "candidates_per_size": self.prediction_panel.candidates_per_size(), + "prediction_population_share_threshold": ( + self.prediction_panel.prediction_population_share_threshold() + ), + "store_prediction_history": ( + self.prediction_panel.store_prediction_history() + ), + "q_min": self.prediction_panel.q_min(), + "q_max": self.prediction_panel.q_max(), + "q_points": self.prediction_panel.q_points(), + } + + def _apply_analysis_settings(self, payload: dict[str, object]) -> None: + self._suspend_preview_refresh = True + try: + self.trajectory_panel.frames_dir_edit.setText( + "" + if _optional_path(payload.get("frames_dir")) is None + else str(_optional_path(payload.get("frames_dir"))) + ) + self.prediction_panel.set_clusters_dir( + _optional_path(payload.get("clusters_dir")), + emit_signal=False, + ) + self.prediction_panel.set_experimental_data_file( + _optional_path(payload.get("experimental_data_file")), + emit_signal=False, + ) + self.run_panel.energy_path_edit.setText( + "" + if _optional_path(payload.get("energy_file")) is None + else str(_optional_path(payload.get("energy_file"))) + ) + self.dataset_panel.set_project_dir( + _optional_path(payload.get("project_dir")) + ) + atom_type_definitions = { + str(atom_type): [ + ( + str(entry[0]), + ( + None + if len(entry) < 2 or entry[1] in {None, ""} + else str(entry[1]) + ), + ) + for entry in criteria + if isinstance(entry, (list, tuple)) and entry + ] + for atom_type, criteria in dict( + payload.get("atom_type_definitions", {}) + ).items() + } + pair_cutoff_definitions = { + (str(entry.get("atom1", "")), str(entry.get("atom2", ""))): { + int(level): float(cutoff) + for level, cutoff in dict( + entry.get("shell_cutoffs", {}) + ).items() + } + for entry in payload.get("pair_cutoff_definitions", []) + if isinstance(entry, dict) + } + self.definitions_panel.load_atom_type_definitions( + atom_type_definitions, + emit_signal=False, + ) + self.definitions_panel.load_pair_cutoff_definitions( + pair_cutoff_definitions, + emit_signal=False, + ) + self.definitions_panel.set_box_dimensions( + _optional_box_dimensions(payload.get("box_dimensions")), + emit_signal=False, + ) + self.definitions_panel.set_use_pbc( + bool(payload.get("use_pbc", False)), + emit_signal=False, + ) + self.definitions_panel.set_default_cutoff( + _optional_float(payload.get("default_cutoff")), + emit_signal=False, + ) + self.definitions_panel.set_shell_growth_levels( + tuple(int(value) for value in payload.get("shell_levels", [])), + emit_signal=False, + ) + self.definitions_panel.set_shared_shells( + bool(payload.get("shared_shells", False)), + emit_signal=False, + ) + self.definitions_panel.set_include_shell_atoms_in_stoichiometry( + bool( + payload.get("include_shell_atoms_in_stoichiometry", False) + ), + emit_signal=False, + ) + self.definitions_panel.set_search_mode( + str(payload.get("search_mode", "kdtree")), + emit_signal=False, + ) + self.time_panel.set_folder_start_time_fs( + _optional_float(payload.get("folder_start_time_fs")), + emit_signal=False, + ) + self.time_panel.set_first_frame_time_fs( + float(payload.get("first_frame_time_fs", 0.0)), + emit_signal=False, + ) + self.time_panel.set_frame_timestep_fs( + float(payload.get("frame_timestep_fs", 0.5)), + emit_signal=False, + ) + self.time_panel.set_frames_per_colormap_timestep( + int(payload.get("frames_per_colormap_timestep", 1)), + emit_signal=False, + ) + self.time_panel.set_analysis_start_fs( + _optional_float(payload.get("analysis_start_fs")), + emit_signal=False, + ) + self.time_panel.set_analysis_stop_fs( + _optional_float(payload.get("analysis_stop_fs")), + emit_signal=False, + ) + self.prediction_panel.set_target_node_counts( + tuple( + int(value) + for value in payload.get("target_node_counts", []) + ), + emit_signal=False, + ) + self.prediction_panel.set_candidates_per_size( + int(payload.get("candidates_per_size", 3)), + emit_signal=False, + ) + self.prediction_panel.set_prediction_population_share_threshold( + float( + payload.get( + "prediction_population_share_threshold", + 0.02, + ) + ), + emit_signal=False, + ) + self.prediction_panel.set_store_prediction_history( + bool(payload.get("store_prediction_history", True)), + emit_signal=False, + ) + self.prediction_panel.set_q_settings( + q_min=float(payload.get("q_min", 0.02)), + q_max=float(payload.get("q_max", 1.20)), + q_points=int(payload.get("q_points", 250)), + emit_signal=False, + ) + finally: + self._suspend_preview_refresh = False + self._refresh_selection_preview() + + def _detected_box_dimensions(self) -> tuple[float, float, float] | None: + if self._last_summary is None: + return None + value = self._last_summary.get("box_dimensions") + if value is None: + value = self._last_summary.get("estimated_box_dimensions") + if value is None: + return None + return tuple(float(component) for component in value) + + def _sync_box_dimensions_from_summary( + self, + summary: dict[str, object] | None, + ) -> None: + if summary is None: + self.definitions_panel.set_box_dimensions(None, emit_signal=False) + return + if summary.get("box_dimensions_source_kind") == "source_filename": + value = summary.get("box_dimensions") + if value is not None: + self.definitions_panel.set_box_dimensions( + tuple(float(component) for component in value), + emit_signal=False, + ) + return + self.definitions_panel.set_box_dimensions(None, emit_signal=False) + + def _set_frame_format(self, frame_format: object | None) -> None: + normalized = None if frame_format is None else str(frame_format) + self._frame_format = normalized + self.trajectory_panel.set_frame_mode(normalized) + self.definitions_panel.set_frame_mode(normalized) + + def _detect_frame_format( + self, + frames_dir: Path | None, + ) -> tuple[str | None, str | None]: + if frames_dir is None: + return None, None + try: + frame_format, _frame_paths = detect_frame_folder_mode(frames_dir) + except ValueError as exc: + return None, str(exc) + return frame_format, None + + def _handle_error(self, title: str, message: str) -> None: + self.run_panel.append_log(f"{title}: {message}") + QMessageBox.critical(self, title, message) + + def _show_error(self, message: str) -> None: + QMessageBox.critical(self, "Error", message) + + @staticmethod + def _wrap_scroll_area(widget: QWidget) -> QScrollArea: + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setWidget(widget) + return scroll_area + + +def _optional_float(value: object) -> float | None: + return None if value is None else float(value) + + +def _format_optional_float(value: float | None) -> str: + return "n/a" if value is None else f"{value:.3f}" + + +def _format_optional_percent(value: object) -> str: + if value is None: + return "n/a" + return f"{float(value) * 100.0:.2f}" + + +def _format_optional_int(value: object) -> str: + return "n/a" if value is None else str(int(value)) + + +def _optional_path(value: object) -> Path | None: + if value is None: + return None + text = str(value).strip() + return None if not text else Path(text) + + +def _optional_box_dimensions( + value: object, +) -> tuple[float, float, float] | None: + if value is None: + return None + components = tuple(float(component) for component in value) + if len(components) != 3: + return None + return components + + +def _safe_filename_stem(value: str) -> str: + cleaned = "".join( + ( + character + if character.isalnum() or character in {".", "_", "-"} + else "_" + ) + for character in str(value).strip() + ) + return cleaned.strip("._") or "clusterdynamicsml" + + +def _format_int_sequence(values: tuple[int, ...] | list[int]) -> str: + sequence = tuple(int(value) for value in values) + if not sequence: + return "n/a" + if len(sequence) == 1: + return str(sequence[0]) + expected = tuple(range(sequence[0], sequence[-1] + 1)) + if sequence == expected: + return f"{sequence[0]}-{sequence[-1]}" + return ",".join(str(value) for value in sequence) + + +def _read_project_history_entry( + dataset_file: Path, +) -> _ProjectPredictionHistoryEntry | None: + try: + payload = json.loads(dataset_file.read_text(encoding="utf-8")) + except Exception: + return None + analysis_settings = dict(payload.get("analysis_settings", {})) + preview_payload = dict(payload.get("preview", {})) + target_node_counts = tuple( + int(value) + for value in ( + analysis_settings.get( + "target_node_counts", + preview_payload.get("target_node_counts", []), + ) + or [] + ) + ) + candidates_per_size = int( + analysis_settings.get("candidates_per_size", 0) or 0 + ) + prediction_count = len(payload.get("predictions", [])) + max_predicted_node_count = payload.get("max_predicted_node_count") + saxs_payload = payload.get("saxs_comparison") + rmse = None + if isinstance(saxs_payload, dict) and saxs_payload.get("rmse") is not None: + rmse = float(saxs_payload["rmse"]) + modified_time = float(dataset_file.stat().st_mtime) + return _ProjectPredictionHistoryEntry( + dataset_file=dataset_file.expanduser().resolve(), + modified_time=modified_time, + saved_label=datetime.fromtimestamp(modified_time).strftime( + "%Y-%m-%d %H:%M:%S" + ), + target_node_counts=target_node_counts, + candidates_per_size=candidates_per_size, + prediction_population_share_threshold=float( + analysis_settings.get( + "prediction_population_share_threshold", 0.02 + ) + ), + prediction_count=prediction_count, + max_predicted_node_count=( + None + if max_predicted_node_count is None + else int(max_predicted_node_count) + ), + rmse=rmse, + ) + + +def _write_table_widget_csv(table: QTableWidget, output_path: Path) -> Path: + resolved_path = output_path.expanduser().resolve() + if resolved_path.suffix.lower() != ".csv": + resolved_path = resolved_path.with_suffix(".csv") + resolved_path.parent.mkdir(parents=True, exist_ok=True) + headers = [ + ( + table.horizontalHeaderItem(column).text() + if table.horizontalHeaderItem(column) is not None + else f"column_{column}" + ) + for column in range(table.columnCount()) + ] + with resolved_path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.writer(handle) + writer.writerow(headers) + for row in range(table.rowCount()): + writer.writerow( + [ + ( + "" + if table.item(row, column) is None + else table.item(row, column).text() + ) + for column in range(table.columnCount()) + ] + ) + return resolved_path + + +def _size_rank_map(node_counts) -> dict[int, int]: + unique_counts = sorted( + {int(value) for value in node_counts if int(value) > 0}, + reverse=True, + ) + return { + int(node_count): index + 1 + for index, node_count in enumerate(unique_counts) + } + + +def _combined_model_weight_rows( + result: ClusterDynamicsMLResult, +) -> list[dict[str, object]]: + observed_weights, _predicted_weights = _resolved_population_weights( + result.training_observations, + result.predictions, + frame_timestep_fs=float( + result.dynamics_result.preview.frame_timestep_fs + ), + ) + observed_total_weight = float(np.sum(observed_weights)) + observed_only_weight_by_label = { + entry.label: ( + float(weight) / observed_total_weight + if observed_total_weight > 0.0 + else 0.0 + ) + for entry, weight in zip( + result.training_observations, + observed_weights, + strict=False, + ) + } + weight_by_observed_label: dict[str, tuple[float, str]] = {} + weight_by_prediction_label: dict[str, tuple[float, str]] = {} + if result.saxs_comparison is not None: + for entry in result.saxs_comparison.component_weights: + if str(entry.source).startswith("observed"): + weight_by_observed_label[entry.label] = ( + float(entry.weight), + str(entry.source), + ) + elif str(entry.source) == "predicted": + weight_by_prediction_label[entry.label] = ( + float(entry.weight), + str(entry.source), + ) + + size_ranks = _size_rank_map( + [ + *(entry.node_count for entry in result.training_observations), + *(entry.target_node_count for entry in result.predictions), + ] + ) + rows: list[dict[str, object]] = [] + for entry in result.training_observations: + normalized_weight, model_source = weight_by_observed_label.get( + entry.label, + (0.0, "not_in_model"), + ) + rows.append( + { + "type": "Observed", + "nodes": int(entry.node_count), + "size_rank": int(size_ranks.get(int(entry.node_count), 0)), + "candidate_rank": "", + "label": entry.label, + "observed_only_normalized_weight": float( + observed_only_weight_by_label.get(entry.label, 0.0) + ), + "normalized_weight": float(normalized_weight), + "predicted_population_share": None, + "mean_count_per_frame": float(entry.mean_count_per_frame), + "occupancy_fraction": float(entry.occupancy_fraction), + "mean_lifetime_fs": entry.mean_lifetime_fs, + "std_lifetime_fs": entry.std_lifetime_fs, + "completed_lifetime_count": int( + entry.completed_lifetime_count + ), + "window_truncated_lifetime_count": int( + entry.window_truncated_lifetime_count + ), + "association_rate_per_ps": float( + entry.association_rate_per_ps + ), + "dissociation_rate_per_ps": float( + entry.dissociation_rate_per_ps + ), + "model_source": model_source, + "reference": "", + "notes": "", + } + ) + for entry in result.predictions: + normalized_weight, model_source = weight_by_prediction_label.get( + entry.label, + (0.0, "not_in_model"), + ) + rows.append( + { + "type": "Predicted", + "nodes": int(entry.target_node_count), + "size_rank": int( + size_ranks.get(int(entry.target_node_count), 0) + ), + "candidate_rank": int(entry.rank), + "label": entry.label, + "observed_only_normalized_weight": None, + "normalized_weight": float(normalized_weight), + "predicted_population_share": float( + entry.predicted_population_share + ), + "mean_count_per_frame": float( + entry.predicted_mean_count_per_frame + ), + "occupancy_fraction": float( + entry.predicted_occupancy_fraction + ), + "mean_lifetime_fs": float(entry.predicted_mean_lifetime_fs), + "std_lifetime_fs": None, + "completed_lifetime_count": None, + "window_truncated_lifetime_count": None, + "association_rate_per_ps": float( + entry.predicted_association_rate_per_ps + ), + "dissociation_rate_per_ps": float( + entry.predicted_dissociation_rate_per_ps + ), + "model_source": model_source, + "reference": ( + "" + if entry.source_label is None + else str(entry.source_label) + ), + "notes": str(entry.notes), + } + ) + + rows.sort( + key=lambda row: ( + -float(row["normalized_weight"]), + str(row["type"]), + -int(row["nodes"]), + str(row["label"]), + ) + ) + for index, row in enumerate(rows, start=1): + row["weight_rank"] = int(index) + return rows + + +def launch_clusterdynamicsml_ui( + initial_frames_dir: str | Path | None = None, + *, + energy_file: str | Path | None = None, + project_dir: str | Path | None = None, + clusters_dir: str | Path | None = None, + experimental_data_file: str | Path | None = None, +) -> int: + prepare_saxshell_application_identity() + app = QApplication.instance() + should_exec = app is None + if app is None: + app = QApplication(sys.argv) + configure_saxshell_application(app) + window = ClusterDynamicsMLMainWindow( + initial_frames_dir=( + None + if initial_frames_dir is None + else Path(initial_frames_dir).expanduser().resolve() + ), + initial_energy_file=( + None + if energy_file is None + else Path(energy_file).expanduser().resolve() + ), + initial_project_dir=( + None + if project_dir is None + else Path(project_dir).expanduser().resolve() + ), + initial_clusters_dir=( + None + if clusters_dir is None + else Path(clusters_dir).expanduser().resolve() + ), + initial_experimental_data_file=( + None + if experimental_data_file is None + else Path(experimental_data_file).expanduser().resolve() + ), + ) + window.show() + window.raise_() + _OPEN_WINDOWS.append(window) + window.destroyed.connect(lambda _obj=None: _OPEN_WINDOWS.remove(window)) + if not should_exec: + return 0 + return app.exec() + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="clusterdynamicsml", + description=( + "Predict larger-cluster surrogate stoichiometries, representative " + "structures, and cluster-only SAXS traces from smaller-cluster " + "cluster-dynamics and structure data." + ), + ) + parser.add_argument( + "frames_dir", + nargs="?", + help="Optional extracted frames directory to prefill in the UI.", + ) + parser.add_argument( + "--energy-file", + help="Optional CP2K .ener file to prefill in the UI.", + ) + parser.add_argument( + "--project-dir", + help="Optional SAXSShell project directory to prefill in the UI.", + ) + parser.add_argument( + "--clusters-dir", + help="Optional smaller-cluster structure directory to prefill in the UI.", + ) + parser.add_argument( + "--experimental-data", + help="Optional experimental SAXS data file to prefill in the UI.", + ) + return parser + + +def main(argv: list[str] | None = None) -> int: + parser = _build_arg_parser() + args = parser.parse_args(argv) + return launch_clusterdynamicsml_ui( + getattr(args, "frames_dir", None), + energy_file=getattr(args, "energy_file", None), + project_dir=getattr(args, "project_dir", None), + clusters_dir=getattr(args, "clusters_dir", None), + experimental_data_file=getattr(args, "experimental_data", None), + ) diff --git a/src/saxshell/clusterdynamicsml/ui/plot_panel.py b/src/saxshell/clusterdynamicsml/ui/plot_panel.py new file mode 100644 index 0000000..c294335 --- /dev/null +++ b/src/saxshell/clusterdynamicsml/ui/plot_panel.py @@ -0,0 +1,1128 @@ +from __future__ import annotations + +from collections import Counter +from pathlib import Path + +import numpy as np +from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.backends.backend_qtagg import ( + NavigationToolbar2QT as NavigationToolbar, +) +from matplotlib.figure import Figure +from PySide6.QtWidgets import ( + QCheckBox, + QComboBox, + QHBoxLayout, + QLabel, + QPushButton, + QVBoxLayout, + QWidget, +) + +from saxshell.cluster.clusternetwork import stoichiometry_label +from saxshell.clusterdynamicsml.workflow import ( + ClusterDynamicsMLResult, + _resolved_population_weights, +) +from saxshell.saxs.debye.profiles import scan_structure_element_counts +from saxshell.saxs.project_manager.prior_plot import ( + list_secondary_filter_elements, + plot_md_prior_histogram, +) +from saxshell.saxs.stoichiometry import parse_stoich_label + +_EXPERIMENTAL_COLOR = "#111111" +_OBSERVED_MODEL_COLOR = "#1f77b4" +_COMBINED_MODEL_COLOR = "#ff7f0e" +_HISTOGRAM_CMAP = "summer" +_HISTOGRAM_MODES = ( + ("Structure Fraction", "structure_fraction"), + ("Atom Fraction", "atom_fraction"), + ("Solvent Sort - Structure Fraction", "solvent_sort_structure_fraction"), + ("Solvent Sort - Atom Fraction", "solvent_sort_atom_fraction"), +) +_STRUCTURE_FILE_SUFFIXES = {".xyz", ".pdb"} + + +class ClusterDynamicsMLHistogramPanel(QWidget): + """Plot Project Setup-style stoichiometry histograms for AI + results.""" + + def __init__( + self, + *, + include_predictions: bool | None = None, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._fixed_include_predictions = include_predictions + self._include_predictions = ( + False if include_predictions is None else bool(include_predictions) + ) + self._result: ClusterDynamicsMLResult | None = None + self._histogram_payloads: dict[bool, dict[str, object] | None] = { + False: None, + True: None, + } + self._build_ui() + self.refresh_plot() + + def _build_ui(self) -> None: + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(8) + + controls_widget = QWidget() + controls = QHBoxLayout(controls_widget) + controls.setContentsMargins(0, 0, 0, 0) + controls.setSpacing(8) + + self.population_label = QLabel("Population set") + controls.addWidget(self.population_label) + self.population_combo = QComboBox() + self.population_combo.addItem("Observed", False) + self.population_combo.addItem("Observed + Surrogate", True) + self.population_combo.currentIndexChanged.connect( + self._on_population_changed + ) + controls.addWidget(self.population_combo) + + controls.addWidget(QLabel("Mode")) + self.mode_combo = QComboBox() + for label, mode in _HISTOGRAM_MODES: + self.mode_combo.addItem(label, mode) + self.mode_combo.currentIndexChanged.connect(self._on_mode_changed) + controls.addWidget(self.mode_combo) + + self.secondary_label = QLabel("Secondary atom") + controls.addWidget(self.secondary_label) + self.secondary_combo = QComboBox() + self.secondary_combo.currentIndexChanged.connect( + lambda _index: self.refresh_plot() + ) + controls.addWidget(self.secondary_combo) + controls.addStretch(1) + layout.addWidget(controls_widget) + + self.figure = Figure(figsize=(9.2, 7.0)) + self.canvas = FigureCanvas(self.figure) + layout.addWidget(NavigationToolbar(self.canvas, self)) + layout.addWidget(self.canvas, stretch=1) + self._update_population_control_state() + + def set_result(self, result: ClusterDynamicsMLResult | None) -> None: + self._result = result + if result is None: + self._histogram_payloads = {False: None, True: None} + else: + self._histogram_payloads = { + False: _build_population_histogram_payload( + result, + include_predictions=False, + ), + True: _build_population_histogram_payload( + result, + include_predictions=True, + ), + } + self._refresh_secondary_elements() + self.refresh_plot() + + def refresh_plot(self) -> None: + self.figure.clear() + if self._result is None: + axis = self.figure.add_subplot(111) + axis.text( + 0.5, + 0.5, + "Run the surrogate workflow to plot cluster stoichiometry\n" + "histograms for the current result.", + ha="center", + va="center", + transform=axis.transAxes, + ) + axis.set_axis_off() + self.canvas.draw_idle() + return + + include_predictions = self._current_include_predictions() + histogram_payload = self._histogram_payloads.get(include_predictions) + if histogram_payload is None: + axis = self.figure.add_subplot(111) + axis.text( + 0.5, + 0.5, + "No stoichiometry histogram data are available for the current result.", + ha="center", + va="center", + transform=axis.transAxes, + ) + axis.set_axis_off() + self.canvas.draw_idle() + return + + axis = self.figure.add_subplot(111) + mode = self._mode() + secondary_element = self._secondary_element() + if mode.startswith("solvent_sort") and secondary_element is None: + axis.text( + 0.5, + 0.5, + "No secondary atom counts are available for solvent-sort histograms.", + ha="center", + va="center", + transform=axis.transAxes, + ) + axis.set_axis_off() + self.canvas.draw_idle() + return + + try: + plot_md_prior_histogram( + histogram_payload, + mode=mode, + secondary_element=secondary_element, + cmap=_HISTOGRAM_CMAP, + ax=axis, + ) + except Exception as exc: + axis.text( + 0.5, + 0.5, + str(exc), + ha="center", + va="center", + wrap=True, + transform=axis.transAxes, + ) + axis.set_axis_off() + self.canvas.draw_idle() + return + + summary_label = ( + "Observed + surrogate populations" + if include_predictions + else "Observed populations" + ) + self.figure.suptitle(summary_label, y=0.995) + self.figure.tight_layout(rect=(0.0, 0.0, 1.0, 0.965)) + self.canvas.draw_idle() + + def _mode(self) -> str: + current = self.mode_combo.currentData() + return "structure_fraction" if current is None else str(current) + + def _secondary_element(self) -> str | None: + text = self.secondary_combo.currentText().strip() + return text or None + + def _current_include_predictions(self) -> bool: + if self._fixed_include_predictions is not None: + return bool(self._fixed_include_predictions) + current = self.population_combo.currentData() + return bool(self._include_predictions if current is None else current) + + def _on_mode_changed(self) -> None: + self._update_secondary_control_state() + self.refresh_plot() + + def _on_population_changed(self) -> None: + self._include_predictions = self._current_include_predictions() + self._refresh_secondary_elements() + self.refresh_plot() + + def _refresh_secondary_elements(self) -> None: + current = self._secondary_element() + elements = ( + [] + if self._histogram_payloads.get( + self._current_include_predictions() + ) + is None + else list_secondary_filter_elements( + self._histogram_payloads[self._current_include_predictions()] + ) + ) + self.secondary_combo.blockSignals(True) + self.secondary_combo.clear() + self.secondary_combo.addItems(elements) + if current and current in elements: + self.secondary_combo.setCurrentText(current) + elif elements: + self.secondary_combo.setCurrentIndex(0) + self.secondary_combo.blockSignals(False) + self._update_secondary_control_state() + + def _update_secondary_control_state(self) -> None: + needs_secondary = self._mode().startswith("solvent_sort") + has_options = self.secondary_combo.count() > 0 + enabled = needs_secondary and has_options + self.secondary_label.setEnabled(enabled) + self.secondary_combo.setEnabled(enabled) + + def _update_population_control_state(self) -> None: + visible = self._fixed_include_predictions is None + self.population_label.setVisible(visible) + self.population_combo.setVisible(visible) + + +class ClusterDynamicsMLPlotPanel(QWidget): + """Plot observed-only, surrogate, and component SAXS traces.""" + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self._result: ClusterDynamicsMLResult | None = None + self._legend_line_map: dict[object, object] = {} + self._legend_handle_lookup: dict[str, object] = {} + self._trace_line_lookup: dict[str, object] = {} + self._trace_visibility: dict[str, bool] = {} + self._component_trace_keys: list[str] = [] + self._build_ui() + self.refresh_plot() + + def _build_ui(self) -> None: + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(8) + + controls_widget = QWidget() + controls = QHBoxLayout(controls_widget) + controls.setContentsMargins(0, 0, 0, 0) + controls.setSpacing(8) + self.log_x_checkbox = QCheckBox("Log X") + self.log_x_checkbox.setChecked(True) + self.log_x_checkbox.toggled.connect(self.refresh_plot) + controls.addWidget(self.log_x_checkbox) + self.log_y_checkbox = QCheckBox("Log Y") + self.log_y_checkbox.setChecked(True) + self.log_y_checkbox.toggled.connect(self.refresh_plot) + controls.addWidget(self.log_y_checkbox) + self.legend_toggle_button = QPushButton("Legend") + self.legend_toggle_button.setCheckable(True) + self.legend_toggle_button.setChecked(True) + self.legend_toggle_button.toggled.connect(self.refresh_plot) + controls.addWidget(self.legend_toggle_button) + self.component_traces_button = QPushButton("Hide Component Traces") + self.component_traces_button.clicked.connect( + self._toggle_all_component_traces + ) + controls.addWidget(self.component_traces_button) + controls.addStretch(1) + layout.addWidget(controls_widget) + + self.figure = Figure(figsize=(9.2, 7.0)) + self.canvas = FigureCanvas(self.figure) + self.canvas.mpl_connect("pick_event", self._handle_legend_pick) + layout.addWidget(NavigationToolbar(self.canvas, self)) + layout.addWidget(self.canvas, stretch=1) + self._update_component_trace_control_state() + + def set_result(self, result: ClusterDynamicsMLResult | None) -> None: + self._result = result + self.refresh_plot() + + def refresh_plot(self) -> None: + self.figure.clear() + self._legend_line_map.clear() + self._legend_handle_lookup.clear() + self._trace_line_lookup.clear() + self._component_trace_keys = [] + if self._result is None: + axis = self.figure.add_subplot(111) + axis.text( + 0.5, + 0.5, + "Run the surrogate workflow to plot the SAXS\n" + "form-factor comparison.", + ha="center", + va="center", + transform=axis.transAxes, + ) + axis.set_axis_off() + self._update_component_trace_control_state() + self.canvas.draw_idle() + return + + axis = self.figure.add_subplot(111) + plotted_lines: list[object] = [] + + observed_model = _build_saxs_model( + self._result, + include_predictions=False, + ) + combined_model = _build_saxs_model( + self._result, + include_predictions=True, + ) + if observed_model is None and combined_model is None: + axis.text( + 0.5, + 0.5, + "No SAXS form-factor model is available for the current result.", + ha="center", + va="center", + transform=axis.transAxes, + ) + axis.set_axis_off() + else: + experimental_plotted = False + plotted_any = False + if observed_model is not None: + line = _plot_saxs_trace_line( + axis, + q_values=observed_model["q_values"], + intensity=observed_model["model_intensity"], + color=_OBSERVED_MODEL_COLOR, + linewidth=1.4, + linestyle="--", + label="observed-only model", + gid="observed-only model", + visible=self._trace_visibility.get( + "observed-only model", True + ), + ) + if line is not None: + plotted_any = True + plotted_lines.append(line) + self._trace_line_lookup["observed-only model"] = line + self._trace_visibility.setdefault( + "observed-only model", + True, + ) + if observed_model["experimental_intensity"] is not None: + line = _plot_saxs_trace_line( + axis, + q_values=observed_model["q_values"], + intensity=observed_model["experimental_intensity"], + color=_EXPERIMENTAL_COLOR, + linewidth=1.2, + alpha=0.75, + label="experimental", + gid="experimental", + visible=self._trace_visibility.get( + "experimental", True + ), + ) + if line is not None: + experimental_plotted = True + plotted_lines.append(line) + self._trace_line_lookup["experimental"] = line + self._trace_visibility.setdefault("experimental", True) + if combined_model is not None: + line = _plot_saxs_trace_line( + axis, + q_values=combined_model["q_values"], + intensity=combined_model["model_intensity"], + color=_COMBINED_MODEL_COLOR, + linewidth=1.8, + label="observed + surrogate model", + gid="observed + surrogate model", + visible=self._trace_visibility.get( + "observed + surrogate model", + True, + ), + ) + if line is not None: + plotted_any = True + plotted_lines.append(line) + self._trace_line_lookup["observed + surrogate model"] = ( + line + ) + self._trace_visibility.setdefault( + "observed + surrogate model", + True, + ) + if ( + combined_model["experimental_intensity"] is not None + and not experimental_plotted + ): + line = _plot_saxs_trace_line( + axis, + q_values=combined_model["q_values"], + intensity=combined_model["experimental_intensity"], + color=_EXPERIMENTAL_COLOR, + linewidth=1.2, + alpha=0.75, + label="experimental", + gid="experimental", + visible=self._trace_visibility.get( + "experimental", True + ), + ) + if line is not None: + experimental_plotted = True + plotted_lines.append(line) + self._trace_line_lookup["experimental"] = line + self._trace_visibility.setdefault("experimental", True) + + for component_trace in _build_saxs_component_traces(self._result): + line = _plot_saxs_trace_line( + axis, + q_values=component_trace["q_values"], + intensity=component_trace["intensity"], + color=str(component_trace["color"]), + linewidth=float(component_trace["linewidth"]), + linestyle=str(component_trace["linestyle"]), + alpha=float(component_trace["alpha"]), + label=str(component_trace["label"]), + gid=str(component_trace["key"]), + visible=self._trace_visibility.get( + str(component_trace["key"]), + True, + ), + ) + if line is None: + continue + plotted_any = True + component_key = str(component_trace["key"]) + self._trace_line_lookup[component_key] = line + self._trace_visibility.setdefault( + component_key, line.get_visible() + ) + self._component_trace_keys.append(component_key) + plotted_lines.append(line) + + if not plotted_any: + axis.text( + 0.5, + 0.5, + "No positive SAXS values are available for the current result.", + ha="center", + va="center", + transform=axis.transAxes, + ) + axis.set_axis_off() + else: + axis.set_xscale( + "log" if self.log_x_checkbox.isChecked() else "linear" + ) + axis.set_yscale( + "log" if self.log_y_checkbox.isChecked() else "linear" + ) + axis.set_xlabel("q (Å⁻¹)") + axis.set_ylabel("Intensity (arb. units)") + title = "SAXS Form Factor Models" + title_bits: list[str] = [] + if ( + observed_model is not None + and observed_model["rmse"] is not None + ): + title_bits.append( + f"observed-only RMSE {observed_model['rmse']:.4g}" + ) + if ( + combined_model is not None + and combined_model["rmse"] is not None + ): + title_bits.append( + f"with surrogate RMSE {combined_model['rmse']:.4g}" + ) + if title_bits: + title += " (" + "; ".join(title_bits) + ")" + axis.set_title(title) + if self.legend_toggle_button.isChecked(): + self._build_interactive_legend(axis, plotted_lines) + self._refresh_axes(axis) + + self._update_component_trace_control_state() + self.figure.tight_layout() + self.canvas.draw_idle() + + def _build_interactive_legend(self, axis, lines: list[object]) -> None: + legend_columns = max(1, int(np.ceil(len(lines) / 5.0))) + legend = axis.legend( + lines, + [line.get_label() for line in lines], + fontsize="small", + loc="upper right", + bbox_to_anchor=(0.985, 0.985), + borderaxespad=0.3, + framealpha=0.9, + ncols=legend_columns, + columnspacing=0.9, + handlelength=1.5, + ) + if legend is None: + return + legend_handles = getattr(legend, "legend_handles", None) + if legend_handles is None: + legend_handles = getattr(legend, "legendHandles", []) + for legend_handle, original_line in zip(legend_handles, lines): + if hasattr(legend_handle, "set_picker"): + legend_handle.set_picker(True) + legend_handle.set_pickradius(6) + legend_handle.set_alpha( + 1.0 if original_line.get_visible() else 0.25 + ) + self._legend_line_map[legend_handle] = original_line + line_key = str(original_line.get_gid() or "").strip() + if line_key: + self._legend_handle_lookup[line_key] = legend_handle + + def _handle_legend_pick(self, event) -> None: + original_line = self._legend_line_map.get(event.artist) + if original_line is None: + return + is_visible = not original_line.get_visible() + original_line.set_visible(is_visible) + line_key = str(original_line.get_gid() or "").strip() + if line_key: + self._trace_visibility[line_key] = is_visible + if hasattr(event.artist, "set_alpha"): + event.artist.set_alpha(1.0 if is_visible else 0.25) + self._update_component_trace_control_state() + self._refresh_axes() + self.canvas.draw_idle() + + def _toggle_all_component_traces(self) -> None: + if not self._component_trace_keys: + return + any_visible = any( + self._trace_visibility.get(component_key, True) + for component_key in self._component_trace_keys + ) + target_visible = not any_visible + for component_key in self._component_trace_keys: + self._trace_visibility[component_key] = target_visible + line = self._trace_line_lookup.get(component_key) + if line is not None: + line.set_visible(target_visible) + legend_line = self._legend_handle_lookup.get(component_key) + if legend_line is not None and hasattr(legend_line, "set_alpha"): + legend_line.set_alpha(1.0 if target_visible else 0.25) + self._update_component_trace_control_state() + self._refresh_axes() + self.canvas.draw_idle() + + def _update_component_trace_control_state(self) -> None: + has_components = bool(self._component_trace_keys) + any_visible = any( + self._trace_visibility.get(component_key, True) + for component_key in self._component_trace_keys + ) + self.component_traces_button.setEnabled(has_components) + self.component_traces_button.setText( + "Hide Component Traces" if any_visible else "Show Component Traces" + ) + + def _refresh_axes(self, axis=None) -> None: + active_axis = axis + if active_axis is None: + if not self.figure.axes: + return + active_axis = self.figure.axes[0] + if not hasattr(active_axis, "relim"): + return + active_axis.relim(visible_only=True) + active_axis.autoscale_view() + + +def _plot_saxs_trace_line( + axis, + *, + q_values: np.ndarray, + intensity: np.ndarray, + color: str, + linewidth: float, + label: str, + linestyle: str = "-", + alpha: float = 1.0, + gid: str | None = None, + visible: bool = True, +): + q_array = np.asarray(q_values, dtype=float) + intensity_array = np.asarray(intensity, dtype=float) + mask = np.isfinite(q_array) & np.isfinite(intensity_array) + mask &= q_array > 0.0 + mask &= intensity_array > 0.0 + if not np.any(mask): + return None + (line,) = axis.plot( + q_array[mask], + intensity_array[mask], + color=color, + linewidth=linewidth, + linestyle=linestyle, + alpha=alpha, + label=label, + visible=visible, + ) + if gid: + line.set_gid(gid) + return line + + +def _build_saxs_component_traces( + result: ClusterDynamicsMLResult, +) -> list[dict[str, object]]: + comparison = result.saxs_comparison + if comparison is None: + return [] + component_entries = [ + entry + for entry in comparison.component_weights + if entry.profile_path is not None + and Path(entry.profile_path).is_file() + ] + if not component_entries: + return [] + + predicted_count = sum( + 1 for entry in component_entries if str(entry.source) == "predicted" + ) + observed_count = sum( + 1 + for entry in component_entries + if str(entry.source).startswith("observed") + ) + predicted_index = 0 + observed_index = 0 + scale_factor = ( + float(comparison.scale_factor) + if float(comparison.scale_factor) > 0.0 + else 1.0 + ) + traces: list[dict[str, object]] = [] + for index, entry in enumerate(component_entries): + try: + raw_data = np.loadtxt(entry.profile_path, comments="#") + except Exception: + continue + if raw_data.size == 0: + continue + if raw_data.ndim == 1: + raw_data = raw_data.reshape(1, -1) + q_values = np.asarray(raw_data[:, 0], dtype=float) + intensity = np.asarray(raw_data[:, 1], dtype=float) + weighted_intensity = ( + intensity * max(float(entry.weight), 0.0) * scale_factor + ) + source = str(entry.source) + if source == "predicted": + color = _predicted_component_color( + predicted_index, predicted_count + ) + predicted_index += 1 + label = f"surrogate component: {entry.label}" + linestyle = "-" + else: + color = _observed_component_color(observed_index, observed_count) + observed_index += 1 + label = f"observed component: {entry.label}" + linestyle = ":" + traces.append( + { + "key": f"component:{index}:{source}:{entry.label}", + "label": label, + "q_values": q_values, + "intensity": weighted_intensity, + "color": color, + "linestyle": linestyle, + "linewidth": 1.0, + "alpha": 0.85 if source == "predicted" else 0.65, + } + ) + return traces + + +def _predicted_component_color(index: int, total: int) -> str: + palette = ( + "#c95d38", + "#e07a5f", + "#f2a65a", + "#d1495b", + "#edae49", + "#c1121f", + ) + if total <= 0: + return palette[0] + return palette[index % len(palette)] + + +def _observed_component_color(index: int, total: int) -> str: + palette = ( + "#7f8c8d", + "#95a5a6", + "#5d6d7e", + "#85929e", + ) + if total <= 0: + return palette[0] + return palette[index % len(palette)] + + +def _distribution_entries( + result: ClusterDynamicsMLResult, + *, + include_predictions: bool, +) -> list[dict[str, float | None]]: + entries: list[dict[str, float | None]] = [] + observed_weights, predicted_weights = _resolved_population_weights( + result.training_observations, + result.predictions, + frame_timestep_fs=float( + result.dynamics_result.preview.frame_timestep_fs + ), + ) + for row, weight in zip( + result.training_observations, + observed_weights, + strict=False, + ): + if weight <= 0.0: + continue + entries.append( + { + "node_count": float(row.node_count), + "mean_lifetime_fs": row.mean_lifetime_fs, + "mean_count_per_frame": float(row.mean_count_per_frame), + "mean_max_radius": float(row.mean_max_radius), + "weight": weight, + } + ) + if include_predictions: + for item, weight in zip( + result.predictions, + predicted_weights, + strict=False, + ): + if weight <= 0.0: + continue + entries.append( + { + "node_count": float(item.target_node_count), + "mean_lifetime_fs": float(item.predicted_mean_lifetime_fs), + "mean_count_per_frame": float( + item.predicted_mean_count_per_frame + ), + "mean_max_radius": float(item.predicted_mean_max_radius), + "weight": weight, + } + ) + total_weight = sum(float(entry["weight"]) for entry in entries) + if total_weight <= 0.0: + return [] + for entry in entries: + entry["normalized_weight"] = float(entry["weight"]) / total_weight + return entries + + +def _build_population_histogram_payload( + result: ClusterDynamicsMLResult, + *, + include_predictions: bool, +) -> dict[str, object] | None: + structures: dict[str, dict[str, dict[str, object]]] = {} + secondary_elements: set[str] = set() + total_population = 0.0 + observed_label_elements = _observed_label_elements(result) + observed_weights, predicted_weights = _resolved_population_weights( + result.training_observations, + result.predictions, + frame_timestep_fs=float( + result.dynamics_result.preview.frame_timestep_fs + ), + ) + + for observation, base_count in zip( + result.training_observations, + observed_weights, + strict=False, + ): + base_count = max(float(base_count), 0.0) + if base_count <= 0.0: + continue + raw_payloads = _observed_structure_payloads(observation) + if not raw_payloads: + continue + scaled_payloads = _scale_motif_payloads( + raw_payloads, base_count=base_count + ) + if not scaled_payloads: + continue + structures.setdefault(observation.label, {}).update(scaled_payloads) + total_population += sum( + float(payload["count"]) for payload in scaled_payloads.values() + ) + secondary_elements.update( + _secondary_elements_from_payloads(scaled_payloads) + ) + + if include_predictions: + for prediction, base_count in zip( + result.predictions, + predicted_weights, + strict=False, + ): + base_count = max(float(base_count), 0.0) + if base_count <= 0.0: + continue + motif_name = f"predicted_rank_{int(prediction.rank):02d}" + structure_label = _predicted_structure_label( + prediction, + observed_label_elements=observed_label_elements, + ) + motif_payload = _predicted_structure_payload( + prediction, + count=base_count, + structure_label=structure_label, + ) + structures.setdefault(structure_label, {})[ + motif_name + ] = motif_payload + total_population += float(motif_payload["count"]) + secondary_elements.update( + _secondary_elements_from_payloads({motif_name: motif_payload}) + ) + + if total_population <= 0.0 or not structures: + return None + + if secondary_elements: + for motif_payloads in structures.values(): + for payload in motif_payloads.values(): + distributions = dict( + payload.get("secondary_atom_distributions", {}) + ) + count = float(payload.get("count", 0.0) or 0.0) + for element in secondary_elements: + distributions.setdefault(element, {"0": count}) + payload["secondary_atom_distributions"] = distributions + + for motif_payloads in structures.values(): + for payload in motif_payloads.values(): + payload["weight"] = float(payload["count"]) / total_population + + return { + "origin": "clusterdynamicsml", + "total_files": float(total_population), + "available_elements": sorted(secondary_elements), + "structures": structures, + } + + +def _observed_structure_payloads( + observation, +) -> dict[str, dict[str, object]]: + structure_dir = Path(observation.structure_dir).expanduser() + raw_payloads: dict[str, dict[str, object]] = {} + + if structure_dir.is_dir(): + motif_dirs = sorted( + path + for path in structure_dir.iterdir() + if path.is_dir() and path.name.startswith("motif_") + ) + if motif_dirs: + for motif_dir in motif_dirs: + payload = _structure_payload_from_files( + _structure_files_in_dir(motif_dir), + label=observation.label, + ) + if payload is not None: + raw_payloads[motif_dir.name] = payload + else: + payload = _structure_payload_from_files( + _structure_files_in_dir(structure_dir), + label=observation.label, + ) + if payload is not None: + raw_payloads["no_motif"] = payload + + representative_path = ( + None + if observation.representative_path is None + else Path(observation.representative_path).expanduser() + ) + if ( + not raw_payloads + and representative_path is not None + and representative_path.is_file() + ): + payload = _structure_payload_from_files( + [representative_path], + label=observation.label, + ) + if payload is not None: + raw_payloads["no_motif"] = payload + + return raw_payloads + + +def _structure_files_in_dir(directory: Path) -> list[Path]: + return sorted( + path + for path in directory.iterdir() + if path.is_file() and path.suffix.lower() in _STRUCTURE_FILE_SUFFIXES + ) + + +def _structure_payload_from_files( + file_paths: list[Path], + *, + label: str, +) -> dict[str, object] | None: + element_counts = [] + for file_path in file_paths: + try: + element_counts.append(scan_structure_element_counts(file_path)) + except Exception: + continue + if not element_counts: + return None + + label_elements = set(parse_stoich_label(label).keys()) + secondary_elements = sorted( + { + element + for counts in element_counts + for element in counts + if element not in label_elements + } + ) + secondary_distributions: dict[str, dict[str, float]] = {} + for element in secondary_elements: + buckets: Counter[str] = Counter() + for counts in element_counts: + buckets[str(int(counts.get(element, 0)))] += 1 + secondary_distributions[element] = { + segment: float(buckets[segment]) + for segment in sorted(buckets, key=lambda value: int(value)) + } + + return { + "count": float(len(element_counts)), + "weight": 0.0, + "secondary_atom_distributions": secondary_distributions, + } + + +def _scale_motif_payloads( + raw_payloads: dict[str, dict[str, object]], + *, + base_count: float, +) -> dict[str, dict[str, object]]: + total_raw_count = sum( + float(payload.get("count", 0.0) or 0.0) + for payload in raw_payloads.values() + ) + if total_raw_count <= 0.0 or base_count <= 0.0: + return {} + + scale = float(base_count) / total_raw_count + scaled_payloads: dict[str, dict[str, object]] = {} + for motif_name, payload in raw_payloads.items(): + scaled_payloads[motif_name] = { + "count": float(payload.get("count", 0.0) or 0.0) * scale, + "weight": 0.0, + "secondary_atom_distributions": { + element: { + segment: float(value) * scale + for segment, value in dict(distribution).items() + } + for element, distribution in dict( + payload.get("secondary_atom_distributions", {}) + ).items() + }, + } + return scaled_payloads + + +def _predicted_structure_payload( + prediction, + *, + count: float, + structure_label: str, +) -> dict[str, object]: + element_counts = Counter( + str(element) for element in prediction.generated_elements + ) + label_elements = set(parse_stoich_label(structure_label).keys()) + secondary_distributions = { + element: {str(int(element_counts[element])): float(count)} + for element in sorted(element_counts) + if element not in label_elements + } + return { + "count": float(count), + "weight": 0.0, + "secondary_atom_distributions": secondary_distributions, + } + + +def _secondary_elements_from_payloads( + payloads: dict[str, dict[str, object]], +) -> set[str]: + return { + str(element) + for payload in payloads.values() + for element in dict(payload.get("secondary_atom_distributions", {})) + } + + +def _observed_label_elements(result: ClusterDynamicsMLResult) -> set[str]: + return { + str(element) + for observation in result.training_observations + for element in parse_stoich_label(observation.label) + } + + +def _predicted_structure_label( + prediction, + *, + observed_label_elements: set[str], +) -> str: + primary_counts = { + str(element): int(count) + for element, count in dict(prediction.element_counts).items() + if int(count) > 0 + and ( + not observed_label_elements + or str(element) in observed_label_elements + ) + } + if not primary_counts: + primary_counts = { + str(element): int(count) + for element, count in dict(prediction.element_counts).items() + if int(count) > 0 + } + return stoichiometry_label(primary_counts) + + +def _build_saxs_model( + result: ClusterDynamicsMLResult, + *, + include_predictions: bool, +) -> dict[str, np.ndarray | float | None] | None: + comparison = result.saxs_comparison + if comparison is None: + return None + q_values = np.asarray(comparison.q_values, dtype=float) + if q_values.size == 0: + return None + experimental_intensity = ( + None + if comparison.experimental_intensity is None + else np.asarray(comparison.experimental_intensity, dtype=float) + ) + if include_predictions: + fitted_model = np.asarray( + comparison.fitted_model_intensity, dtype=float + ) + rmse = comparison.rmse + else: + if comparison.observed_fitted_model_intensity is None: + return None + fitted_model = np.asarray( + comparison.observed_fitted_model_intensity, + dtype=float, + ) + rmse = comparison.observed_rmse + return { + "q_values": q_values, + "model_intensity": fitted_model, + "experimental_intensity": experimental_intensity, + "rmse": rmse, + } diff --git a/src/saxshell/clusterdynamicsml/workflow.py b/src/saxshell/clusterdynamicsml/workflow.py new file mode 100644 index 0000000..77812bf --- /dev/null +++ b/src/saxshell/clusterdynamicsml/workflow.py @@ -0,0 +1,4353 @@ +from __future__ import annotations + +import json +import math +import re +from collections import Counter, defaultdict +from dataclasses import dataclass +from itertools import combinations +from pathlib import Path +from typing import Callable + +import numpy as np + +from saxshell.cluster import PairCutoffDefinitions +from saxshell.cluster.clusternetwork import stoichiometry_label +from saxshell.clusterdynamics import ( + ClusterDynamicsResult, + ClusterDynamicsSelectionPreview, + ClusterDynamicsWorkflow, + ClusterLifetimeSummary, +) +from saxshell.saxs.debye.profiles import ( + compute_debye_intensity, + load_structure_file, +) +from saxshell.saxs.project_manager import ( + ExperimentalDataSummary, + SAXSProjectManager, + build_project_paths, + load_built_component_q_range, + load_experimental_data_file, +) +from saxshell.structure import AtomTypeDefinitions + +PredictionProgressCallback = Callable[[str], None] +_STOICHIOMETRY_TOKEN_PATTERN = re.compile(r"([A-Z][a-z]*)(\d*)") +_DEFAULT_Q_MIN = 0.02 +_DEFAULT_Q_MAX = 1.20 +_DEFAULT_Q_POINTS = 250 +_DEFAULT_SHARE_THRESHOLD = 0.02 +_RIDGE_REGULARIZATION = 1e-6 + + +@dataclass(slots=True) +class ClusterStructureObservation: + """Aggregated structure information for one observed + stoichiometry.""" + + label: str + node_count: int + element_counts: dict[str, int] + file_count: int + representative_path: Path | None + structure_dir: Path + motifs: tuple[str, ...] + mean_atom_count: float + mean_radius_of_gyration: float + mean_max_radius: float + mean_semiaxis_a: float + mean_semiaxis_b: float + mean_semiaxis_c: float + + +@dataclass(slots=True) +class ClusterDynamicsMLTrainingObservation: + """Joined kinetics and structure descriptors for one label.""" + + label: str + node_count: int + cluster_size: int + element_counts: dict[str, int] + file_count: int + representative_path: Path | None + structure_dir: Path + motifs: tuple[str, ...] + mean_atom_count: float + mean_radius_of_gyration: float + mean_max_radius: float + mean_semiaxis_a: float + mean_semiaxis_b: float + mean_semiaxis_c: float + total_observations: int + occupied_frames: int + mean_count_per_frame: float + occupancy_fraction: float + association_events: int + dissociation_events: int + association_rate_per_ps: float + dissociation_rate_per_ps: float + completed_lifetime_count: int + window_truncated_lifetime_count: int + mean_lifetime_fs: float | None + std_lifetime_fs: float | None + + @property + def stability_weight(self) -> float: + finite_lifetime = ( + 0.0 if self.mean_lifetime_fs is None else self.mean_lifetime_fs + ) + return float( + 1.0 + + max(self.file_count, 0) / 10.0 + + max(self.completed_lifetime_count, 0) / 5.0 + + max(self.mean_count_per_frame, 0.0) * 5.0 + + max(self.occupancy_fraction, 0.0) * 5.0 + + finite_lifetime / 100.0 + ) + + +@dataclass(slots=True) +class PredictedClusterCandidate: + """Predicted larger-cluster surrogate candidate.""" + + target_node_count: int + rank: int + label: str + element_counts: dict[str, int] + predicted_mean_count_per_frame: float + predicted_occupancy_fraction: float + predicted_mean_lifetime_fs: float + predicted_association_rate_per_ps: float + predicted_dissociation_rate_per_ps: float + predicted_mean_radius_of_gyration: float + predicted_mean_max_radius: float + predicted_mean_semiaxis_a: float + predicted_mean_semiaxis_b: float + predicted_mean_semiaxis_c: float + predicted_population_share: float + predicted_stability_score: float + source_label: str | None + notes: str + generated_elements: tuple[str, ...] + generated_coordinates: np.ndarray + + +@dataclass(slots=True) +class SAXSComponentWeight: + label: str + weight: float + source: str + profile_path: Path | None = None + structure_path: Path | None = None + + +@dataclass(slots=True) +class ClusterDynamicsMLSAXSComparison: + """Cluster-only surrogate SAXS comparison trace.""" + + q_values: np.ndarray + observed_raw_model_intensity: np.ndarray | None + observed_fitted_model_intensity: np.ndarray | None + observed_rmse: float | None + raw_model_intensity: np.ndarray + fitted_model_intensity: np.ndarray + experimental_intensity: np.ndarray | None + residuals: np.ndarray | None + scale_factor: float + offset: float + rmse: float | None + component_weights: tuple[SAXSComponentWeight, ...] + experimental_data_path: Path | None + component_output_dir: Path | None = None + surrogate_structure_dir: Path | None = None + + +@dataclass(slots=True) +class _ResolvedSAXSComponent: + label: str + weight: float + source: str + trace: np.ndarray + profile_path: Path | None = None + structure_path: Path | None = None + + +@dataclass(slots=True) +class ClusterDynamicsMLPreview: + """Preview metadata for the separate prediction workflow.""" + + dynamics_preview: ClusterDynamicsSelectionPreview + clusters_dir: Path | None + project_dir: Path | None + experimental_data_path: Path | None + structure_label_count: int + total_structure_files: int + observed_node_counts: tuple[int, ...] + target_node_counts: tuple[int, ...] + warnings: tuple[str, ...] = () + + +@dataclass(slots=True) +class ClusterDynamicsMLResult: + """End-to-end result for the experimental prediction workflow.""" + + dynamics_result: ClusterDynamicsResult + preview: ClusterDynamicsMLPreview + structure_observations: tuple[ClusterStructureObservation, ...] + training_observations: tuple[ClusterDynamicsMLTrainingObservation, ...] + predictions: tuple[PredictedClusterCandidate, ...] + saxs_comparison: ClusterDynamicsMLSAXSComparison | None + max_observed_node_count: int + max_predicted_node_count: int | None + prediction_population_share_threshold: float + + +@dataclass(slots=True) +class _PropertyModel: + coefficients: np.ndarray | None + constant_value: float | None + transform: str + default_value: float + lower_bound: float | None = None + upper_bound: float | None = None + + def predict(self, features: np.ndarray) -> float: + if self.constant_value is not None: + transformed = float(self.constant_value) + elif self.coefficients is None: + transformed = float(self.default_value) + else: + transformed = float( + np.asarray(features, dtype=float) @ self.coefficients + ) + if self.transform == "log1p": + value = float(np.expm1(transformed)) + else: + value = transformed + if self.lower_bound is not None: + value = max(self.lower_bound, value) + if self.upper_bound is not None: + value = min(self.upper_bound, value) + return float(value) + + +@dataclass(slots=True) +class _TrainingGeometryStatistics: + atom_type_by_element: dict[str, str] + node_elements: tuple[str, ...] + tracked_atom_types: tuple[str, ...] + node_bond_length: float + bond_length_medians: dict[tuple[str, str], float] + contact_distance_medians: dict[tuple[str, str], float] + geometry_contact_distance_medians: dict[tuple[str, str], float] + node_angle_medians: dict[tuple[str, str], float] + node_coordination_medians: dict[str, float] + non_node_node_coordination_medians: dict[str, float] + atom_coordination_medians: dict[tuple[str, str], float] + + +class ClusterDynamicsMLWorkflow: + """Predict larger-cluster surrogate states from smaller-cluster + data.""" + + def __init__( + self, + frames_dir: str | Path, + *, + atom_type_definitions: AtomTypeDefinitions, + pair_cutoff_definitions: PairCutoffDefinitions, + clusters_dir: str | Path | None = None, + project_dir: str | Path | None = None, + experimental_data_file: str | Path | None = None, + box_dimensions: tuple[float, float, float] | None = None, + use_pbc: bool = False, + default_cutoff: float | None = None, + shell_levels: tuple[int, ...] = (), + shared_shells: bool = False, + include_shell_atoms_in_stoichiometry: bool = False, + search_mode: str = "kdtree", + folder_start_time_fs: float | None = None, + first_frame_time_fs: float = 0.0, + frame_timestep_fs: float = 0.5, + frames_per_colormap_timestep: int = 1, + analysis_start_fs: float | None = None, + analysis_stop_fs: float | None = None, + energy_file: str | Path | None = None, + target_node_counts: tuple[int, ...] | None = None, + max_target_node_count: int | None = None, + candidates_per_size: int = 3, + prediction_population_share_threshold: float = _DEFAULT_SHARE_THRESHOLD, + q_min: float | None = None, + q_max: float | None = None, + q_points: int = _DEFAULT_Q_POINTS, + ) -> None: + self.frames_dir = Path(frames_dir).expanduser().resolve() + self.atom_type_definitions = atom_type_definitions + self.pair_cutoff_definitions = pair_cutoff_definitions + self.clusters_dir = ( + None + if clusters_dir is None + else Path(clusters_dir).expanduser().resolve() + ) + self.project_dir = ( + None + if project_dir is None + else Path(project_dir).expanduser().resolve() + ) + self.experimental_data_file = ( + None + if experimental_data_file is None + else Path(experimental_data_file).expanduser().resolve() + ) + self.box_dimensions = box_dimensions + self.use_pbc = bool(use_pbc) + self.default_cutoff = default_cutoff + self.shell_levels = tuple(int(level) for level in shell_levels) + self.shared_shells = bool(shared_shells) + self.include_shell_atoms_in_stoichiometry = bool( + include_shell_atoms_in_stoichiometry + ) + self.search_mode = str(search_mode) + self.folder_start_time_fs = folder_start_time_fs + self.first_frame_time_fs = float(first_frame_time_fs) + self.frame_timestep_fs = float(frame_timestep_fs) + self.frames_per_colormap_timestep = max( + int(frames_per_colormap_timestep), + 1, + ) + self.analysis_start_fs = analysis_start_fs + self.analysis_stop_fs = analysis_stop_fs + self.energy_file = ( + None + if energy_file is None + else Path(energy_file).expanduser().resolve() + ) + self.target_node_counts = ( + None + if target_node_counts is None + else tuple( + sorted( + { + int(value) + for value in target_node_counts + if int(value) > 0 + } + ) + ) + ) + self.max_target_node_count = ( + None + if max_target_node_count is None + else int(max_target_node_count) + ) + self.candidates_per_size = max(int(candidates_per_size), 1) + self.prediction_population_share_threshold = max( + float(prediction_population_share_threshold), + 0.0, + ) + self.q_min = None if q_min is None else float(q_min) + self.q_max = None if q_max is None else float(q_max) + self.q_points = max(int(q_points), 2) + self._project_manager = SAXSProjectManager() + + def preview_selection(self) -> ClusterDynamicsMLPreview: + dynamics_preview = ( + self._build_cluster_dynamics_workflow().preview_selection() + ) + resolved_clusters_dir, resolved_experimental, warnings = ( + self._resolve_project_inputs() + ) + structure_label_count = 0 + total_structure_files = 0 + observed_node_counts: tuple[int, ...] = () + if ( + resolved_clusters_dir is not None + and resolved_clusters_dir.is_dir() + ): + label_map = self._list_structure_labels(resolved_clusters_dir) + structure_label_count = len(label_map) + total_structure_files = sum( + len(paths) for paths in label_map.values() + ) + observed_node_counts = tuple( + sorted( + { + self._node_count_from_counts( + _parse_stoichiometry_label(label) + ) + for label in label_map + if self._node_count_from_counts( + _parse_stoichiometry_label(label) + ) + > 0 + } + ) + ) + target_counts = self._resolve_target_node_counts(observed_node_counts) + return ClusterDynamicsMLPreview( + dynamics_preview=dynamics_preview, + clusters_dir=resolved_clusters_dir, + project_dir=self.project_dir, + experimental_data_path=resolved_experimental, + structure_label_count=structure_label_count, + total_structure_files=total_structure_files, + observed_node_counts=observed_node_counts, + target_node_counts=target_counts, + warnings=warnings, + ) + + def analyze( + self, + *, + progress_callback: PredictionProgressCallback | None = None, + ) -> ClusterDynamicsMLResult: + preview = self.preview_selection() + resolved_clusters_dir = preview.clusters_dir + if resolved_clusters_dir is None or not resolved_clusters_dir.is_dir(): + raise ValueError( + "Select a cluster-structures directory, or provide a SAXSShell " + "project with a saved clusters directory before running the " + "prediction workflow." + ) + self._emit(progress_callback, "Running time-binned cluster analysis.") + dynamics_result = self._build_cluster_dynamics_workflow().analyze() + self._emit(progress_callback, "Loading structure ensembles.") + structure_observations = self._build_structure_observations( + resolved_clusters_dir + ) + self._emit( + progress_callback, "Joining lifetime and structure summaries." + ) + training_observations = self._build_training_observations( + dynamics_result, + structure_observations, + ) + if not training_observations: + raise ValueError( + "No overlapping structure labels could be matched to the " + "observed cluster labels. The prediction workflow requires " + "smaller-cluster structure files named by stoichiometry label." + ) + observed_node_counts = tuple( + sorted( + { + entry.node_count + for entry in training_observations + if entry.node_count > 0 + } + ) + ) + if not observed_node_counts: + raise ValueError( + "Could not infer any node-count observations from the structure labels." + ) + if len(observed_node_counts) < 2: + raise ValueError( + "At least two observed node counts are required for extrapolation." + ) + self._emit( + progress_callback, + "Learning bond, angle, and coordination statistics from reference structures.", + ) + self._emit(progress_callback, "Training surrogate trend models.") + predictions = self._predict_candidates( + training_observations, + self._resolve_target_node_counts(observed_node_counts), + ) + self._update_prediction_population_shares( + training_observations, predictions + ) + self._prune_prediction_population_tail(predictions) + self._emit(progress_callback, "Building surrogate SAXS comparison.") + saxs_comparison = self._build_saxs_comparison( + training_observations, + predictions, + preview.experimental_data_path, + ) + max_predicted_node_count = self._resolve_max_predicted_node_count( + predictions + ) + return ClusterDynamicsMLResult( + dynamics_result=dynamics_result, + preview=ClusterDynamicsMLPreview( + dynamics_preview=preview.dynamics_preview, + clusters_dir=preview.clusters_dir, + project_dir=preview.project_dir, + experimental_data_path=preview.experimental_data_path, + structure_label_count=preview.structure_label_count, + total_structure_files=preview.total_structure_files, + observed_node_counts=observed_node_counts, + target_node_counts=self._resolve_target_node_counts( + observed_node_counts + ), + warnings=preview.warnings, + ), + structure_observations=tuple( + sorted( + structure_observations, + key=lambda item: (item.node_count, item.label), + ) + ), + training_observations=tuple( + sorted( + training_observations, + key=lambda item: (item.node_count, item.label), + ) + ), + predictions=tuple(predictions), + saxs_comparison=saxs_comparison, + max_observed_node_count=max(observed_node_counts), + max_predicted_node_count=max_predicted_node_count, + prediction_population_share_threshold=( + self.prediction_population_share_threshold + ), + ) + + def _build_cluster_dynamics_workflow(self) -> ClusterDynamicsWorkflow: + return ClusterDynamicsWorkflow( + self.frames_dir, + atom_type_definitions=self.atom_type_definitions, + pair_cutoff_definitions=self.pair_cutoff_definitions, + box_dimensions=self.box_dimensions, + use_pbc=self.use_pbc, + default_cutoff=self.default_cutoff, + shell_levels=self.shell_levels, + shared_shells=self.shared_shells, + include_shell_atoms_in_stoichiometry=( + self.include_shell_atoms_in_stoichiometry + ), + search_mode=self.search_mode, + folder_start_time_fs=self.folder_start_time_fs, + first_frame_time_fs=self.first_frame_time_fs, + frame_timestep_fs=self.frame_timestep_fs, + frames_per_colormap_timestep=self.frames_per_colormap_timestep, + analysis_start_fs=self.analysis_start_fs, + analysis_stop_fs=self.analysis_stop_fs, + energy_file=self.energy_file, + ) + + def _resolve_project_inputs( + self, + ) -> tuple[Path | None, Path | None, tuple[str, ...]]: + warnings: list[str] = [] + resolved_clusters_dir = self.clusters_dir + resolved_experimental = self.experimental_data_file + if self.project_dir is None: + return ( + resolved_clusters_dir, + resolved_experimental, + tuple(warnings), + ) + project_file = build_project_paths(self.project_dir).project_file + if not project_file.is_file(): + warnings.append( + "The selected project directory does not contain a " + "saxs_project.json file." + ) + return ( + resolved_clusters_dir, + resolved_experimental, + tuple(warnings), + ) + try: + settings = self._project_manager.load_project(self.project_dir) + except Exception as exc: + warnings.append(f"Could not load SAXS project settings: {exc}") + return ( + resolved_clusters_dir, + resolved_experimental, + tuple(warnings), + ) + if ( + resolved_clusters_dir is None + and settings.resolved_clusters_dir is not None + ): + resolved_clusters_dir = settings.resolved_clusters_dir + if resolved_experimental is None: + try: + experimental = self._project_manager.load_experimental_data( + settings + ) + except Exception: + experimental = None + if experimental is not None: + resolved_experimental = experimental.path + return resolved_clusters_dir, resolved_experimental, tuple(warnings) + + def _list_structure_labels( + self, + clusters_dir: Path, + ) -> dict[str, list[Path]]: + label_map: dict[str, list[Path]] = defaultdict(list) + for structure_dir in sorted( + path for path in clusters_dir.iterdir() if path.is_dir() + ): + if structure_dir.name.startswith("representative_"): + continue + motif_dirs = sorted( + path + for path in structure_dir.iterdir() + if path.is_dir() and path.name.startswith("motif_") + ) + if motif_dirs: + for motif_dir in motif_dirs: + for file_path in sorted(motif_dir.iterdir()): + if file_path.suffix.lower() in {".xyz", ".pdb"}: + label_map[structure_dir.name].append(file_path) + continue + for file_path in sorted(structure_dir.iterdir()): + if file_path.suffix.lower() in {".xyz", ".pdb"}: + label_map[structure_dir.name].append(file_path) + return label_map + + def _build_structure_observations( + self, + clusters_dir: Path, + ) -> list[ClusterStructureObservation]: + grouped_paths = self._list_structure_labels(clusters_dir) + observations: list[ClusterStructureObservation] = [] + tracked_elements = self._tracked_structure_elements() + for label, paths in sorted(grouped_paths.items()): + label_counts = _parse_stoichiometry_label(label) + descriptor_rows: list[tuple[float, float, float, float, float]] = ( + [] + ) + count_rows: list[tuple[Path, dict[str, int]]] = [] + motifs = sorted( + { + path.parent.name + for path in paths + if path.parent.name.startswith("motif_") + } + ) + for path in paths: + try: + coords, elements = load_structure_file(path) + except Exception: + continue + descriptor_rows.append(_structure_descriptor_row(coords)) + filtered_counts = _filtered_structure_counts( + elements, + tracked_elements=tracked_elements, + ) + if filtered_counts: + count_rows.append((path, filtered_counts)) + if not descriptor_rows: + continue + representative_path, counts = ( + _select_representative_structure_counts( + count_rows, + fallback_path=(paths[0] if paths else None), + fallback_counts=label_counts, + ) + ) + node_count = self._node_count_from_counts(counts) + if node_count <= 0: + continue + descriptor_matrix = np.asarray(descriptor_rows, dtype=float) + observations.append( + ClusterStructureObservation( + label=label, + node_count=node_count, + element_counts=counts, + file_count=len(paths), + representative_path=representative_path, + structure_dir=clusters_dir / label, + motifs=tuple(motifs), + mean_atom_count=float(np.mean(descriptor_matrix[:, 0])), + mean_radius_of_gyration=float( + np.mean(descriptor_matrix[:, 1]) + ), + mean_max_radius=float(np.mean(descriptor_matrix[:, 2])), + mean_semiaxis_a=float(np.mean(descriptor_matrix[:, 3])), + mean_semiaxis_b=float(np.mean(descriptor_matrix[:, 4])), + mean_semiaxis_c=float(np.mean(descriptor_matrix[:, 5])), + ) + ) + return observations + + def _tracked_structure_elements(self) -> set[str]: + tracked: set[str] = set() + for definitions in self.atom_type_definitions.values(): + for element, _residue in definitions: + normalized = _normalized_element_symbol(element) + if normalized: + tracked.add(normalized) + return tracked + + def _build_training_observations( + self, + dynamics_result: ClusterDynamicsResult, + structure_observations: list[ClusterStructureObservation], + ) -> list[ClusterDynamicsMLTrainingObservation]: + lifetime_map = { + entry.label: entry for entry in dynamics_result.lifetime_by_label + } + training_rows: list[ClusterDynamicsMLTrainingObservation] = [] + for structure in structure_observations: + lifetime = lifetime_map.get(structure.label) + if lifetime is None: + lifetime = _empty_lifetime_summary( + structure.label, + cluster_size=sum(structure.element_counts.values()), + ) + training_rows.append( + ClusterDynamicsMLTrainingObservation( + label=structure.label, + node_count=structure.node_count, + cluster_size=sum(structure.element_counts.values()), + element_counts=dict(structure.element_counts), + file_count=structure.file_count, + representative_path=structure.representative_path, + structure_dir=structure.structure_dir, + motifs=structure.motifs, + mean_atom_count=structure.mean_atom_count, + mean_radius_of_gyration=structure.mean_radius_of_gyration, + mean_max_radius=structure.mean_max_radius, + mean_semiaxis_a=structure.mean_semiaxis_a, + mean_semiaxis_b=structure.mean_semiaxis_b, + mean_semiaxis_c=structure.mean_semiaxis_c, + total_observations=lifetime.total_observations, + occupied_frames=lifetime.occupied_frames, + mean_count_per_frame=lifetime.mean_count_per_frame, + occupancy_fraction=lifetime.occupancy_fraction, + association_events=lifetime.association_events, + dissociation_events=lifetime.dissociation_events, + association_rate_per_ps=lifetime.association_rate_per_ps, + dissociation_rate_per_ps=lifetime.dissociation_rate_per_ps, + completed_lifetime_count=lifetime.completed_lifetime_count, + window_truncated_lifetime_count=( + lifetime.window_truncated_lifetime_count + ), + mean_lifetime_fs=lifetime.mean_lifetime_fs, + std_lifetime_fs=lifetime.std_lifetime_fs, + ) + ) + return training_rows + + def _resolve_target_node_counts( + self, + observed_node_counts: tuple[int, ...], + ) -> tuple[int, ...]: + if self.target_node_counts is not None: + max_observed = ( + max(observed_node_counts) if observed_node_counts else 0 + ) + return tuple( + value + for value in self.target_node_counts + if value > max_observed + ) + if not observed_node_counts: + return () + max_observed = max(observed_node_counts) + max_target = ( + self.max_target_node_count + if self.max_target_node_count is not None + else max_observed + 2 + ) + max_target = max(max_target, max_observed + 1) + return tuple(range(max_observed + 1, max_target + 1)) + + def _predict_candidates( + self, + training_observations: list[ClusterDynamicsMLTrainingObservation], + target_node_counts: tuple[int, ...], + ) -> list[PredictedClusterCandidate]: + if not target_node_counts: + return [] + geometry_statistics = self._collect_training_geometry_statistics( + training_observations + ) + atom_type_by_element = dict(geometry_statistics.atom_type_by_element) + node_elements = tuple(sorted(self._atom_type_elements("node"))) + non_node_elements = tuple( + sorted( + { + element + for row in training_observations + for element in row.element_counts + if element not in node_elements + } + ) + ) + feature_matrix = np.asarray( + [ + _candidate_feature_vector( + row.element_counts, + node_elements=node_elements, + non_node_elements=non_node_elements, + ) + for row in training_observations + ], + dtype=float, + ) + weights = np.asarray( + [row.stability_weight for row in training_observations], + dtype=float, + ) + node_element_fractions = _weighted_node_element_fractions( + training_observations, + node_elements=node_elements, + ) + element_count_models = { + element: _fit_property_model( + feature_matrix, + np.asarray( + [ + row.element_counts.get(element, 0) + for row in training_observations + ], + dtype=float, + ), + weights=weights, + default_value=float( + np.average( + [ + row.element_counts.get(element, 0) + for row in training_observations + ], + weights=weights, + ) + ), + lower_bound=0.0, + ) + for element in non_node_elements + } + property_models = self._fit_candidate_property_models( + training_observations, + feature_matrix, + weights, + non_node_elements=non_node_elements, + node_elements=node_elements, + ) + predictions: list[PredictedClusterCandidate] = [] + for target_node_count in target_node_counts: + raw_candidates = self._build_raw_candidates( + training_observations, + target_node_count=target_node_count, + node_elements=node_elements, + non_node_elements=non_node_elements, + atom_type_by_element=atom_type_by_element, + node_element_fractions=node_element_fractions, + element_count_models=element_count_models, + ) + ranked_candidates: list[PredictedClusterCandidate] = [] + for counts, source_label, notes in raw_candidates: + feature_vector = _candidate_feature_vector( + counts, + node_elements=node_elements, + non_node_elements=non_node_elements, + ) + predicted_mean_count = property_models[ + "mean_count_per_frame" + ].predict(feature_vector) + predicted_occupancy = property_models[ + "occupancy_fraction" + ].predict(feature_vector) + predicted_lifetime = property_models[ + "mean_lifetime_fs" + ].predict(feature_vector) + predicted_assoc = property_models[ + "association_rate_per_ps" + ].predict(feature_vector) + predicted_dissoc = property_models[ + "dissociation_rate_per_ps" + ].predict(feature_vector) + predicted_rg = property_models[ + "mean_radius_of_gyration" + ].predict(feature_vector) + predicted_max_radius = property_models[ + "mean_max_radius" + ].predict(feature_vector) + predicted_a = property_models["mean_semiaxis_a"].predict( + feature_vector + ) + predicted_b = property_models["mean_semiaxis_b"].predict( + feature_vector + ) + predicted_c = property_models["mean_semiaxis_c"].predict( + feature_vector + ) + composition_distance = _composition_distance( + counts, + raw_candidates[0][0], + node_count=target_node_count, + ) + predicted_score = float( + max(predicted_mean_count, 0.0) + * max(predicted_lifetime, self.frame_timestep_fs) + * max(predicted_occupancy, 0.05) + / (1.0 + composition_distance) + ) + source_observation = self._best_source_observation( + training_observations, + counts=counts, + target_node_count=target_node_count, + preferred_label=source_label, + ) + generated_elements, generated_coordinates = ( + self._generate_predicted_structure( + source_observation, + target_counts=counts, + predicted_max_radius=predicted_max_radius, + geometry_statistics=geometry_statistics, + ) + ) + ranked_candidates.append( + PredictedClusterCandidate( + target_node_count=target_node_count, + rank=0, + label=stoichiometry_label(counts), + element_counts=counts, + predicted_mean_count_per_frame=float( + predicted_mean_count + ), + predicted_occupancy_fraction=float( + predicted_occupancy + ), + predicted_mean_lifetime_fs=float(predicted_lifetime), + predicted_association_rate_per_ps=float( + predicted_assoc + ), + predicted_dissociation_rate_per_ps=float( + predicted_dissoc + ), + predicted_mean_radius_of_gyration=float(predicted_rg), + predicted_mean_max_radius=float(predicted_max_radius), + predicted_mean_semiaxis_a=float(predicted_a), + predicted_mean_semiaxis_b=float(predicted_b), + predicted_mean_semiaxis_c=float(predicted_c), + predicted_population_share=0.0, + predicted_stability_score=float(predicted_score), + source_label=( + None + if source_observation is None + else source_observation.label + ), + notes=notes, + generated_elements=tuple(generated_elements), + generated_coordinates=np.asarray( + generated_coordinates, + dtype=float, + ), + ) + ) + ranked_candidates.sort( + key=lambda item: ( + -item.predicted_stability_score, + item.label, + ) + ) + for rank, candidate in enumerate( + ranked_candidates[: self.candidates_per_size], + start=1, + ): + candidate.rank = rank + predictions.append(candidate) + return predictions + + def _fit_candidate_property_models( + self, + training_observations: list[ClusterDynamicsMLTrainingObservation], + feature_matrix: np.ndarray, + weights: np.ndarray, + *, + non_node_elements: tuple[str, ...], + node_elements: tuple[str, ...], + ) -> dict[str, _PropertyModel]: + del non_node_elements + del node_elements + finite_lifetimes = [ + row.mean_lifetime_fs + for row in training_observations + if row.mean_lifetime_fs is not None + ] + default_lifetime = ( + float(np.mean(finite_lifetimes)) + if finite_lifetimes + else self.frame_timestep_fs + ) + return { + "mean_count_per_frame": _fit_property_model( + feature_matrix, + np.asarray( + [ + row.mean_count_per_frame + for row in training_observations + ], + dtype=float, + ), + weights=weights, + default_value=max( + float( + np.average( + [ + row.mean_count_per_frame + for row in training_observations + ], + weights=weights, + ) + ), + 0.0, + ), + transform="log1p", + lower_bound=0.0, + ), + "occupancy_fraction": _fit_property_model( + feature_matrix, + np.asarray( + [row.occupancy_fraction for row in training_observations], + dtype=float, + ), + weights=weights, + default_value=float( + np.average( + [ + row.occupancy_fraction + for row in training_observations + ], + weights=weights, + ) + ), + lower_bound=0.0, + upper_bound=1.0, + ), + "mean_lifetime_fs": _fit_property_model( + feature_matrix, + np.asarray( + [ + ( + default_lifetime + if row.mean_lifetime_fs is None + else row.mean_lifetime_fs + ) + for row in training_observations + ], + dtype=float, + ), + weights=weights, + default_value=max(default_lifetime, self.frame_timestep_fs), + transform="log1p", + lower_bound=self.frame_timestep_fs, + ), + "association_rate_per_ps": _fit_property_model( + feature_matrix, + np.asarray( + [ + row.association_rate_per_ps + for row in training_observations + ], + dtype=float, + ), + weights=weights, + default_value=float( + np.average( + [ + row.association_rate_per_ps + for row in training_observations + ], + weights=weights, + ) + ), + transform="log1p", + lower_bound=0.0, + ), + "dissociation_rate_per_ps": _fit_property_model( + feature_matrix, + np.asarray( + [ + row.dissociation_rate_per_ps + for row in training_observations + ], + dtype=float, + ), + weights=weights, + default_value=float( + np.average( + [ + row.dissociation_rate_per_ps + for row in training_observations + ], + weights=weights, + ) + ), + transform="log1p", + lower_bound=0.0, + ), + "mean_radius_of_gyration": _fit_property_model( + feature_matrix, + np.asarray( + [ + row.mean_radius_of_gyration + for row in training_observations + ], + dtype=float, + ), + weights=weights, + default_value=float( + np.average( + [ + row.mean_radius_of_gyration + for row in training_observations + ], + weights=weights, + ) + ), + transform="log1p", + lower_bound=0.1, + ), + "mean_max_radius": _fit_property_model( + feature_matrix, + np.asarray( + [row.mean_max_radius for row in training_observations], + dtype=float, + ), + weights=weights, + default_value=float( + np.average( + [row.mean_max_radius for row in training_observations], + weights=weights, + ) + ), + transform="log1p", + lower_bound=0.1, + ), + "mean_semiaxis_a": _fit_property_model( + feature_matrix, + np.asarray( + [row.mean_semiaxis_a for row in training_observations], + dtype=float, + ), + weights=weights, + default_value=float( + np.average( + [row.mean_semiaxis_a for row in training_observations], + weights=weights, + ) + ), + transform="log1p", + lower_bound=0.05, + ), + "mean_semiaxis_b": _fit_property_model( + feature_matrix, + np.asarray( + [row.mean_semiaxis_b for row in training_observations], + dtype=float, + ), + weights=weights, + default_value=float( + np.average( + [row.mean_semiaxis_b for row in training_observations], + weights=weights, + ) + ), + transform="log1p", + lower_bound=0.05, + ), + "mean_semiaxis_c": _fit_property_model( + feature_matrix, + np.asarray( + [row.mean_semiaxis_c for row in training_observations], + dtype=float, + ), + weights=weights, + default_value=float( + np.average( + [row.mean_semiaxis_c for row in training_observations], + weights=weights, + ) + ), + transform="log1p", + lower_bound=0.05, + ), + } + + def _build_raw_candidates( + self, + training_observations: list[ClusterDynamicsMLTrainingObservation], + *, + target_node_count: int, + node_elements: tuple[str, ...], + non_node_elements: tuple[str, ...], + atom_type_by_element: dict[str, str], + node_element_fractions: dict[str, float], + element_count_models: dict[str, _PropertyModel], + ) -> list[tuple[dict[str, int], str | None, str]]: + candidates: list[tuple[dict[str, int], str | None, str]] = [] + required_non_node_counts = _required_non_node_count_floors( + training_observations, + target_node_count=target_node_count, + non_node_elements=non_node_elements, + atom_type_by_element=atom_type_by_element, + ) + base_counts = _allocate_node_counts( + target_node_count, + node_element_fractions, + ) + for element in non_node_elements: + predicted = int( + round( + element_count_models[element].predict( + _candidate_feature_vector( + base_counts, + node_elements=node_elements, + non_node_elements=non_node_elements, + ) + ) + ) + ) + if predicted > 0: + base_counts[element] = predicted + candidates.append((base_counts, None, "Trend extrapolation")) + ranked_observations = sorted( + training_observations, + key=lambda item: ( + -item.stability_weight, + -item.node_count, + item.label, + ), + ) + for row in ranked_observations: + scaled = _allocate_node_counts( + target_node_count, node_element_fractions + ) + if row.node_count <= 0: + continue + scale_factor = float(target_node_count) / float(row.node_count) + for element in non_node_elements: + count = int( + round(row.element_counts.get(element, 0) * scale_factor) + ) + if count > 0: + scaled[element] = count + candidates.append( + ( + scaled, + row.label, + f"Composition scaled from observed {row.label}", + ) + ) + deduplicated: list[tuple[dict[str, int], str | None, str]] = [] + seen_labels: set[str] = set() + for counts, source_label, notes in candidates: + normalized = _normalized_counts(counts) + if any( + normalized.get(element, 0) < minimum_count + for element, minimum_count in required_non_node_counts.items() + ): + continue + if not _candidate_has_support( + training_observations, + counts=normalized, + target_node_count=target_node_count, + node_elements=node_elements, + non_node_elements=non_node_elements, + ): + continue + label = stoichiometry_label(normalized) + if label in seen_labels: + continue + seen_labels.add(label) + deduplicated.append((normalized, source_label, notes)) + return deduplicated + + def _best_source_observation( + self, + training_observations: list[ClusterDynamicsMLTrainingObservation], + *, + counts: dict[str, int], + target_node_count: int, + preferred_label: str | None, + ) -> ClusterDynamicsMLTrainingObservation | None: + if preferred_label is not None: + for row in training_observations: + if row.label == preferred_label: + return row + best_row: ClusterDynamicsMLTrainingObservation | None = None + best_score: tuple[float, float, str] | None = None + for row in training_observations: + if row.representative_path is None: + continue + node_gap = abs(target_node_count - row.node_count) + composition_gap = _composition_distance( + counts, + row.element_counts, + node_count=max(target_node_count, 1), + ) + score = (float(node_gap), float(composition_gap), row.label) + if best_score is None or score < best_score: + best_score = score + best_row = row + return best_row + + def _generate_predicted_structure( + self, + source_observation: ClusterDynamicsMLTrainingObservation | None, + *, + target_counts: dict[str, int], + predicted_max_radius: float, + geometry_statistics: _TrainingGeometryStatistics, + ) -> tuple[list[str], np.ndarray]: + node_elements = self._atom_type_elements("node") + seed_node_elements: list[str] = [] + seed_node_coordinates = np.zeros((0, 3), dtype=float) + if ( + source_observation is not None + and source_observation.representative_path is not None + ): + try: + source_coords, source_elements = load_structure_file( + source_observation.representative_path + ) + except Exception: + source_coords = None + source_elements = None + if source_coords is not None and source_elements is not None: + source_coords_array = np.asarray(source_coords, dtype=float) + normalized_elements = [ + _normalized_element_symbol(element) + for element in source_elements + ] + node_indices = [ + index + for index, element in enumerate(normalized_elements) + if element in node_elements + ] + if node_indices: + seed_node_coordinates = np.asarray( + source_coords_array[node_indices], + dtype=float, + ) + seed_node_elements = [ + normalized_elements[index] for index in node_indices + ] + generated_elements, generated = _build_geometry_guided_structure( + target_counts, + node_elements=tuple(sorted(node_elements)), + pair_cutoff_definitions=self.pair_cutoff_definitions, + geometry_statistics=geometry_statistics, + predicted_max_radius=predicted_max_radius, + seed_node_elements=seed_node_elements, + seed_node_coordinates=seed_node_coordinates, + ) + if generated.size == 0: + return _build_fallback_structure(target_counts) + return generated_elements, generated + + def _build_saxs_comparison( + self, + training_observations: list[ClusterDynamicsMLTrainingObservation], + predictions: list[PredictedClusterCandidate], + experimental_data_path: Path | None, + ) -> ClusterDynamicsMLSAXSComparison | None: + experimental_data = self._load_experimental_data( + experimental_data_path + ) + q_values = self._resolve_q_values(experimental_data) + component_dir, surrogate_structure_dir = ( + self._resolve_saxs_artifact_dirs() + ) + observed_weights, predicted_weights = _resolved_population_weights( + training_observations, + predictions, + frame_timestep_fs=self.frame_timestep_fs, + ) + predicted_structure_paths = [ + _write_predicted_structure_file( + prediction, + surrogate_structure_dir=surrogate_structure_dir, + ) + for prediction in predictions + ] + observed_components = [ + component + for row, weight in zip( + training_observations, + observed_weights, + strict=False, + ) + for component in [ + self._build_observed_saxs_component( + row, + weight=float(weight), + q_values=q_values, + component_dir=component_dir, + ) + ] + if component is not None + ] + predicted_components = [ + component + for prediction, weight, structure_path in zip( + predictions, + predicted_weights, + predicted_structure_paths, + strict=False, + ) + for component in [ + self._build_predicted_saxs_component( + prediction, + weight=float(weight), + structure_path=structure_path, + q_values=q_values, + component_dir=component_dir, + ) + ] + if component is not None + ] + all_components = [*observed_components, *predicted_components] + if not all_components: + return None + + observed_model = self._compose_weighted_model(observed_components) + combined_model = self._compose_weighted_model(all_components) + if combined_model is None: + return None + + observed_raw_model = ( + None if observed_model is None else observed_model[0] + ) + raw_model, normalized_weights = combined_model + experimental_intensity: np.ndarray | None = None + observed_fitted_model = ( + None if observed_raw_model is None else observed_raw_model.copy() + ) + observed_rmse: float | None = None + scale_factor = 1.0 + offset = 0.0 + fitted_model = raw_model.copy() + residuals: np.ndarray | None = None + rmse: float | None = None + if experimental_data is not None: + experimental_intensity = np.asarray( + experimental_data.intensities, + dtype=float, + ) + if observed_raw_model is not None: + ( + _observed_scale_factor, + _observed_offset, + observed_fitted_model, + _observed_residuals, + observed_rmse, + ) = _fit_model_to_experimental( + observed_raw_model, + experimental_intensity, + ) + ( + scale_factor, + offset, + fitted_model, + residuals, + rmse, + ) = _fit_model_to_experimental( + raw_model, + experimental_intensity, + ) + return ClusterDynamicsMLSAXSComparison( + q_values=q_values, + observed_raw_model_intensity=observed_raw_model, + observed_fitted_model_intensity=observed_fitted_model, + observed_rmse=observed_rmse, + raw_model_intensity=raw_model, + fitted_model_intensity=fitted_model, + experimental_intensity=experimental_intensity, + residuals=residuals, + scale_factor=float(scale_factor), + offset=float(offset), + rmse=rmse, + component_weights=tuple( + SAXSComponentWeight( + label=component.label, + weight=float(weight), + source=component.source, + profile_path=component.profile_path, + structure_path=component.structure_path, + ) + for component, weight in zip( + all_components, + normalized_weights, + strict=False, + ) + ), + experimental_data_path=( + None if experimental_data is None else experimental_data.path + ), + component_output_dir=component_dir, + surrogate_structure_dir=surrogate_structure_dir, + ) + + def _resolve_q_values( + self, + experimental_data: ExperimentalDataSummary | None, + ) -> np.ndarray: + if experimental_data is not None: + return np.asarray(experimental_data.q_values, dtype=float) + supported_range = None + if self.project_dir is not None: + try: + supported_range = load_built_component_q_range( + self.project_dir + ) + except Exception: + supported_range = None + q_min = ( + self.q_min + if self.q_min is not None + else ( + float(supported_range[0]) + if supported_range is not None + else _DEFAULT_Q_MIN + ) + ) + q_max = ( + self.q_max + if self.q_max is not None + else ( + float(supported_range[1]) + if supported_range is not None + else _DEFAULT_Q_MAX + ) + ) + if q_min > q_max: + raise ValueError("q min must be less than or equal to q max.") + return np.linspace(float(q_min), float(q_max), int(self.q_points)) + + def _resolve_saxs_artifact_dirs(self) -> tuple[Path, Path]: + if self.project_dir is not None: + base_dir = ( + build_project_paths(self.project_dir).exported_data_dir + / "clusterdynamicsml" + ) + else: + base_dir = ( + self.frames_dir.parent + / f"{self.frames_dir.name}_clusterdynamicsml" + ) + component_dir = base_dir / "saxs_components" + surrogate_structure_dir = base_dir / "surrogate_structures" + component_dir.mkdir(parents=True, exist_ok=True) + surrogate_structure_dir.mkdir(parents=True, exist_ok=True) + return component_dir, surrogate_structure_dir + + def _build_observed_saxs_component( + self, + observation: ClusterDynamicsMLTrainingObservation, + *, + weight: float, + q_values: np.ndarray, + component_dir: Path, + ) -> _ResolvedSAXSComponent | None: + if weight <= 0.0: + return None + + profile_path = ( + component_dir + / f"observed_{_safe_component_stem(observation.label)}.txt" + ) + project_trace = self._load_project_component_trace( + observation.label, + q_values=q_values, + ) + if project_trace is not None: + _write_component_profile( + profile_path, + q_values=q_values, + intensity=project_trace, + source=f"project:{observation.label}", + ) + return _ResolvedSAXSComponent( + label=observation.label, + weight=weight, + source="observed_project", + trace=project_trace, + profile_path=profile_path, + structure_path=observation.representative_path, + ) + + if observation.representative_path is None: + return None + try: + coords, elements = load_structure_file( + observation.representative_path + ) + trace = np.asarray( + compute_debye_intensity(coords, elements, q_values), + dtype=float, + ) + except Exception: + return None + _write_component_profile( + profile_path, + q_values=q_values, + intensity=trace, + source=f"direct:{observation.representative_path.name}", + ) + return _ResolvedSAXSComponent( + label=observation.label, + weight=weight, + source="observed_direct", + trace=trace, + profile_path=profile_path, + structure_path=observation.representative_path, + ) + + def _build_predicted_saxs_component( + self, + prediction: PredictedClusterCandidate, + *, + weight: float, + structure_path: Path, + q_values: np.ndarray, + component_dir: Path, + ) -> _ResolvedSAXSComponent | None: + if weight <= 0.0: + return None + try: + trace = np.asarray( + compute_debye_intensity( + prediction.generated_coordinates, + list(prediction.generated_elements), + q_values, + ), + dtype=float, + ) + except Exception: + return None + + file_stem = _predicted_structure_file_stem(prediction) + profile_path = component_dir / f"predicted_{file_stem}.txt" + _write_component_profile( + profile_path, + q_values=q_values, + intensity=trace, + source="predicted_surrogate", + ) + return _ResolvedSAXSComponent( + label=prediction.label, + weight=weight, + source="predicted", + trace=trace, + profile_path=profile_path, + structure_path=structure_path, + ) + + def _load_project_component_trace( + self, + label: str, + *, + q_values: np.ndarray, + ) -> np.ndarray | None: + if self.project_dir is None: + return None + project_dir = Path(self.project_dir).expanduser().resolve() + map_path = project_dir / "md_saxs_map.json" + if not map_path.is_file(): + return None + + try: + map_payload = json.loads(map_path.read_text(encoding="utf-8")) + except Exception: + return None + saxs_map = map_payload.get("saxs_map", {}) + motif_map = saxs_map.get(label) + if not isinstance(motif_map, dict) or not motif_map: + return None + + prior_payload: dict[str, object] = {} + prior_path = project_dir / "md_prior_weights.json" + if prior_path.is_file(): + try: + prior_payload = json.loads( + prior_path.read_text(encoding="utf-8") + ) + except Exception: + prior_payload = {} + structures_payload = prior_payload.get("structures", {}) + structure_weights = ( + structures_payload.get(label, {}) + if isinstance(structures_payload, dict) + else {} + ) + + component_dir = build_project_paths( + project_dir + ).scattering_components_dir + motif_traces: list[np.ndarray] = [] + motif_weights: list[float] = [] + for motif in sorted(motif_map): + profile_file = str(motif_map.get(motif, "")).strip() + if not profile_file: + continue + profile_path = component_dir / profile_file + if not profile_path.is_file(): + continue + try: + raw_data = np.loadtxt(profile_path, comments="#") + except Exception: + continue + if raw_data.size == 0: + continue + if raw_data.ndim == 1: + raw_data = raw_data.reshape(1, -1) + component_q = np.asarray(raw_data[:, 0], dtype=float) + component_i = np.asarray(raw_data[:, 1], dtype=float) + if component_q.size == 0 or component_i.size == 0: + continue + trace = np.interp(q_values, component_q, component_i) + motif_traces.append(np.asarray(trace, dtype=float)) + motif_info = ( + structure_weights.get(motif, {}) + if isinstance(structure_weights, dict) + else {} + ) + motif_weight = 0.0 + if isinstance(motif_info, dict): + count_value = motif_info.get("count") + weight_value = motif_info.get("weight") + if count_value is not None: + motif_weight = max(float(count_value), 0.0) + elif weight_value is not None: + motif_weight = max(float(weight_value), 0.0) + motif_weights.append(motif_weight) + if not motif_traces: + return None + weights = np.asarray(motif_weights, dtype=float) + if np.sum(weights) <= 0.0: + weights = np.ones(len(motif_traces), dtype=float) + weights = weights / np.sum(weights) + stacked = np.asarray(motif_traces, dtype=float) + return np.asarray(np.einsum("i,ij->j", weights, stacked), dtype=float) + + @staticmethod + def _compose_weighted_model( + components: list[_ResolvedSAXSComponent], + ) -> tuple[np.ndarray, np.ndarray] | None: + if not components: + return None + weights = np.asarray( + [component.weight for component in components], dtype=float + ) + if np.sum(weights) <= 0.0: + return None + normalized_weights = weights / np.sum(weights) + stacked = np.asarray( + [component.trace for component in components], dtype=float + ) + raw_model = np.einsum("i,ij->j", normalized_weights, stacked) + return np.asarray(raw_model, dtype=float), normalized_weights + + def _load_experimental_data( + self, + experimental_data_path: Path | None, + ) -> ExperimentalDataSummary | None: + if experimental_data_path is None: + return None + try: + return load_experimental_data_file(experimental_data_path) + except Exception: + return None + + def _update_prediction_population_shares( + self, + training_observations: list[ClusterDynamicsMLTrainingObservation], + predictions: list[PredictedClusterCandidate], + ) -> None: + del training_observations + total_weight = sum( + _prediction_share_weight( + item, + frame_timestep_fs=self.frame_timestep_fs, + ) + for item in predictions + ) + total_weight = max(float(total_weight), 1e-12) + for item in predictions: + weight = _prediction_share_weight( + item, + frame_timestep_fs=self.frame_timestep_fs, + ) + item.predicted_population_share = float(weight / total_weight) + + def _prune_prediction_population_tail( + self, + predictions: list[PredictedClusterCandidate], + ) -> None: + if not predictions: + return + threshold = max(float(self.prediction_population_share_threshold), 0.0) + grouped: defaultdict[int, list[PredictedClusterCandidate]] = ( + defaultdict(list) + ) + for item in predictions: + grouped[int(item.target_node_count)].append(item) + + retained: list[PredictedClusterCandidate] = [] + removed_any = False + for target_node_count in sorted(grouped): + ordered = sorted( + grouped[target_node_count], + key=lambda item: ( + -float(item.predicted_population_share), + -float(item.predicted_stability_score), + item.label, + ), + ) + for index, item in enumerate(ordered): + if index == 0 or threshold <= 0.0: + retained.append(item) + continue + if float(item.predicted_population_share) >= threshold: + retained.append(item) + continue + removed_any = True + if removed_any: + predictions[:] = retained + self._update_prediction_population_shares([], predictions) + self._reassign_prediction_ranks(predictions) + + @staticmethod + def _reassign_prediction_ranks( + predictions: list[PredictedClusterCandidate], + ) -> None: + grouped: defaultdict[int, list[PredictedClusterCandidate]] = ( + defaultdict(list) + ) + for item in predictions: + grouped[int(item.target_node_count)].append(item) + reordered: list[PredictedClusterCandidate] = [] + for target_node_count in sorted(grouped, reverse=True): + ordered = sorted( + grouped[target_node_count], + key=lambda item: ( + -float(item.predicted_population_share), + -float(item.predicted_stability_score), + item.label, + ), + ) + for rank, item in enumerate(ordered, start=1): + item.rank = rank + reordered.append(item) + predictions[:] = reordered + + def _resolve_max_predicted_node_count( + self, + predictions: list[PredictedClusterCandidate], + ) -> int | None: + qualifying = [ + item.target_node_count + for item in predictions + if item.predicted_population_share + >= self.prediction_population_share_threshold + and item.predicted_mean_lifetime_fs >= self.frame_timestep_fs + ] + return None if not qualifying else max(qualifying) + + def _node_count_from_counts(self, counts: dict[str, int]) -> int: + node_elements = self._atom_type_elements("node") + if node_elements: + return int( + sum(counts.get(element, 0) for element in node_elements) + ) + return int(sum(counts.values())) + + def _atom_type_elements(self, atom_type: str) -> set[str]: + return { + _normalized_element_symbol(element) + for element, _residue in self.atom_type_definitions.get( + atom_type, [] + ) + if _normalized_element_symbol(element) + } + + def _atom_type_by_element(self) -> dict[str, str]: + mapping: dict[str, str] = {} + for atom_type, definitions in self.atom_type_definitions.items(): + geometry_type = _geometry_atom_type_label(atom_type) + for element, _residue in definitions: + normalized = _normalized_element_symbol(element) + if normalized and normalized not in mapping: + mapping[normalized] = geometry_type + return mapping + + def _structure_files_for_observation( + self, + observation: ClusterDynamicsMLTrainingObservation, + ) -> list[Path]: + structure_dir = Path(observation.structure_dir).expanduser() + if structure_dir.is_dir(): + motif_dirs = sorted( + path + for path in structure_dir.iterdir() + if path.is_dir() and path.name.startswith("motif_") + ) + if motif_dirs: + file_paths = [ + file_path + for motif_dir in motif_dirs + for file_path in _structure_files_in_directory(motif_dir) + ] + if file_paths: + return file_paths + file_paths = _structure_files_in_directory(structure_dir) + if file_paths: + return file_paths + representative_path = ( + None + if observation.representative_path is None + else Path(observation.representative_path).expanduser() + ) + if representative_path is not None and representative_path.is_file(): + return [representative_path] + return [] + + def _collect_training_geometry_statistics( + self, + training_observations: list[ClusterDynamicsMLTrainingObservation], + ) -> _TrainingGeometryStatistics: + atom_type_by_element = self._atom_type_by_element() + node_elements = tuple(sorted(self._atom_type_elements("node"))) + tracked_atom_types = tuple( + sorted( + { + "node", + *atom_type_by_element.values(), + } + ) + ) + node_bond_lengths: list[float] = [] + bond_lengths: defaultdict[tuple[str, str], list[float]] = defaultdict( + list + ) + seed_contact_distances: defaultdict[tuple[str, str], list[float]] = ( + defaultdict(list) + ) + node_angles: defaultdict[tuple[str, str], list[float]] = defaultdict( + list + ) + node_coordination: defaultdict[str, list[int]] = defaultdict(list) + non_node_node_coordination: defaultdict[str, list[int]] = defaultdict( + list + ) + parsed_structures: list[tuple[np.ndarray, list[str], list[str]]] = [] + + for observation in training_observations: + for file_path in self._structure_files_for_observation( + observation + ): + try: + coords, raw_elements = load_structure_file(file_path) + except Exception: + continue + coordinates = np.asarray(coords, dtype=float) + elements = [ + _normalized_element_symbol(element) + for element in raw_elements + ] + geometry_types = [ + atom_type_by_element.get(element, "shell") + for element in elements + ] + node_indices = [ + index + for index, element in enumerate(elements) + if element in node_elements + ] + if not node_indices: + continue + parsed_structures.append( + (coordinates, elements, geometry_types) + ) + nearest_pair_distances: dict[tuple[int, str, str], float] = {} + for atom_index, element in enumerate(elements): + for other_index, other_element in enumerate(elements): + if atom_index == other_index: + continue + pair_key = _sorted_pair_key(element, other_element) + distance = float( + np.linalg.norm( + coordinates[atom_index] + - coordinates[other_index] + ) + ) + previous = nearest_pair_distances.get( + (atom_index, *pair_key) + ) + if previous is None or distance < previous: + nearest_pair_distances[(atom_index, *pair_key)] = ( + distance + ) + for key, distance in nearest_pair_distances.items(): + _atom_index, pair_a, pair_b = key + seed_contact_distances[(pair_a, pair_b)].append( + float(distance) + ) + node_coordinates = coordinates[node_indices] + node_edge_pairs = _node_scaffold_edges( + node_coordinates, + [elements[index] for index in node_indices], + pair_cutoff_definitions=self.pair_cutoff_definitions, + ) + node_neighbors: dict[int, list[int]] = defaultdict(list) + for local_index_a, local_index_b in node_edge_pairs: + global_index_a = node_indices[local_index_a] + global_index_b = node_indices[local_index_b] + node_neighbors[global_index_a].append(global_index_b) + node_neighbors[global_index_b].append(global_index_a) + distance = float( + np.linalg.norm( + coordinates[global_index_a] + - coordinates[global_index_b] + ) + ) + if distance > 0.0: + node_bond_lengths.append(distance) + bond_lengths[ + _sorted_pair_key( + elements[global_index_a], + elements[global_index_b], + ) + ].append(distance) + + attached_nodes_by_atom = _associate_non_node_atoms_to_nodes( + coordinates, + elements=elements, + node_indices=node_indices, + pair_cutoff_definitions=self.pair_cutoff_definitions, + ) + attached_atoms_by_node: dict[int, list[int]] = defaultdict( + list + ) + for ( + atom_index, + attached_nodes, + ) in attached_nodes_by_atom.items(): + if not attached_nodes: + continue + non_node_node_coordination[elements[atom_index]].append( + len(attached_nodes) + ) + for node_index in attached_nodes: + attached_atoms_by_node[node_index].append(atom_index) + distance = float( + np.linalg.norm( + coordinates[atom_index] + - coordinates[node_index] + ) + ) + if distance > 0.0: + bond_lengths[ + _sorted_pair_key( + elements[atom_index], elements[node_index] + ) + ].append(distance) + + for node_index in node_indices: + neighbor_vectors: list[tuple[str, np.ndarray]] = [] + coordination_counts = { + atom_type: 0 for atom_type in tracked_atom_types + } + coordination_counts["node"] = len( + node_neighbors[node_index] + ) + for neighbor_index in node_neighbors[node_index]: + neighbor_vectors.append( + ( + "node", + coordinates[neighbor_index] + - coordinates[node_index], + ) + ) + for atom_index in attached_atoms_by_node[node_index]: + geometry_type = atom_type_by_element.get( + elements[atom_index], + "shell", + ) + coordination_counts[geometry_type] = ( + coordination_counts.get(geometry_type, 0) + 1 + ) + neighbor_vectors.append( + ( + geometry_type, + coordinates[atom_index] + - coordinates[node_index], + ) + ) + for atom_type, count in coordination_counts.items(): + node_coordination[atom_type].append(int(count)) + for ( + (neighbor_type_a, vector_a), + (neighbor_type_b, vector_b), + ) in combinations(neighbor_vectors, 2): + angle = _angle_between_vectors(vector_a, vector_b) + if angle is None: + continue + node_angles[ + _sorted_pair_key(neighbor_type_a, neighbor_type_b) + ].append(angle) + + default_node_bond_length = _fallback_node_bond_length( + node_elements=node_elements, + pair_cutoff_definitions=self.pair_cutoff_definitions, + ) + preliminary_contact_medians = { + pair: _median_or_default( + values, + default=_fallback_pair_distance( + pair[0], + pair[1], + pair_cutoff_definitions=self.pair_cutoff_definitions, + default=default_node_bond_length, + ), + ) + for pair, values in seed_contact_distances.items() + if values + } + contact_distances: defaultdict[tuple[str, str], list[float]] = ( + defaultdict(list) + ) + geometry_contact_distances: defaultdict[ + tuple[str, str], list[float] + ] = defaultdict(list) + atom_coordination: defaultdict[tuple[str, str], list[int]] = ( + defaultdict(list) + ) + for coordinates, elements, geometry_types in parsed_structures: + neighbor_counts = [Counter() for _ in range(len(elements))] + for atom_index_a, atom_index_b in combinations( + range(len(elements)), 2 + ): + pair_key = _sorted_pair_key( + elements[atom_index_a], + elements[atom_index_b], + ) + distance = float( + np.linalg.norm( + coordinates[atom_index_a] - coordinates[atom_index_b] + ) + ) + if distance > _contact_distance_cutoff( + elements[atom_index_a], + elements[atom_index_b], + preliminary_contact_medians=preliminary_contact_medians, + default_distance=default_node_bond_length, + pair_cutoff_definitions=self.pair_cutoff_definitions, + ): + continue + contact_distances[pair_key].append(distance) + geometry_pair_key = _sorted_pair_key( + geometry_types[atom_index_a], + geometry_types[atom_index_b], + ) + geometry_contact_distances[geometry_pair_key].append(distance) + neighbor_counts[atom_index_a][ + geometry_types[atom_index_b] + ] += 1 + neighbor_counts[atom_index_b][ + geometry_types[atom_index_a] + ] += 1 + for atom_index, center_type in enumerate(geometry_types): + for neighbor_type in tracked_atom_types: + atom_coordination[(center_type, neighbor_type)].append( + int(neighbor_counts[atom_index].get(neighbor_type, 0)) + ) + return _TrainingGeometryStatistics( + atom_type_by_element=atom_type_by_element, + node_elements=node_elements, + tracked_atom_types=tracked_atom_types, + node_bond_length=_median_or_default( + node_bond_lengths, + default=default_node_bond_length, + ), + bond_length_medians={ + pair: _median_or_default( + values, + default=_fallback_pair_distance( + pair[0], + pair[1], + pair_cutoff_definitions=self.pair_cutoff_definitions, + default=default_node_bond_length, + ), + ) + for pair, values in bond_lengths.items() + if values + }, + contact_distance_medians={ + pair: _median_or_default( + values, + default=_fallback_pair_distance( + pair[0], + pair[1], + pair_cutoff_definitions=self.pair_cutoff_definitions, + default=default_node_bond_length, + ), + ) + for pair, values in contact_distances.items() + if values + }, + geometry_contact_distance_medians={ + pair: _median_or_default( + values, + default=default_node_bond_length, + ) + for pair, values in geometry_contact_distances.items() + if values + }, + node_angle_medians={ + pair: _median_or_default( + values, + default=180.0 if pair == ("node", "node") else 120.0, + ) + for pair, values in node_angles.items() + if values + }, + node_coordination_medians={ + atom_type: _median_or_default(values, default=0.0) + for atom_type, values in node_coordination.items() + }, + non_node_node_coordination_medians={ + element: _median_or_default(values, default=1.0) + for element, values in non_node_node_coordination.items() + }, + atom_coordination_medians={ + pair: _median_or_default(values, default=0.0) + for pair, values in atom_coordination.items() + }, + ) + + @staticmethod + def _emit( + callback: PredictionProgressCallback | None, + message: str, + ) -> None: + if callback is not None: + callback(str(message)) + + +def _fit_model_to_experimental( + raw_model: np.ndarray, + experimental_intensity: np.ndarray, +) -> tuple[float, float, np.ndarray, np.ndarray, float]: + design = np.column_stack([raw_model, np.ones_like(raw_model)]) + try: + scale_factor, offset = np.linalg.lstsq( + design, + experimental_intensity, + rcond=None, + )[0] + except np.linalg.LinAlgError: + scale_factor, offset = (1.0, 0.0) + fitted_model = raw_model * float(scale_factor) + float(offset) + residuals = experimental_intensity - fitted_model + rmse = float(np.sqrt(np.mean(residuals**2))) + return ( + float(scale_factor), + float(offset), + np.asarray(fitted_model, dtype=float), + np.asarray(residuals, dtype=float), + rmse, + ) + + +def _saxs_component_weight( + mean_count_per_frame: float, occupancy_fraction: float +) -> float: + return max(float(mean_count_per_frame), 0.0) * max( + float(occupancy_fraction), 0.05 + ) + + +def _resolved_population_weights( + training_observations: ( + list[ClusterDynamicsMLTrainingObservation] + | tuple[ClusterDynamicsMLTrainingObservation, ...] + ), + predictions: ( + list[PredictedClusterCandidate] | tuple[PredictedClusterCandidate, ...] + ), + *, + frame_timestep_fs: float, +) -> tuple[np.ndarray, np.ndarray]: + observed_weights = np.asarray( + [ + _saxs_component_weight( + observation.mean_count_per_frame, + observation.occupancy_fraction, + ) + for observation in training_observations + ], + dtype=float, + ) + if not predictions: + return observed_weights, np.zeros(0, dtype=float) + + prediction_base_weights = np.asarray( + [ + _saxs_component_weight( + prediction.predicted_mean_count_per_frame, + prediction.predicted_occupancy_fraction, + ) + for prediction in predictions + ], + dtype=float, + ) + prediction_shares = np.asarray( + [ + max(float(prediction.predicted_population_share), 0.0) + for prediction in predictions + ], + dtype=float, + ) + if np.sum(prediction_shares) <= 0.0: + prediction_shares = np.asarray( + [ + _prediction_share_weight( + prediction, + frame_timestep_fs=frame_timestep_fs, + ) + for prediction in predictions + ], + dtype=float, + ) + total_prediction_share = float(np.sum(prediction_shares)) + if total_prediction_share <= 0.0: + return observed_weights, prediction_base_weights + normalized_prediction_shares = prediction_shares / total_prediction_share + + predicted_total_weight = _predicted_total_weight_from_observed_tail( + training_observations, + observed_weights, + predictions, + ) + if predicted_total_weight <= 0.0: + positive_observed = observed_weights[observed_weights > 0.0] + positive_prediction_total = float(np.sum(prediction_base_weights)) + if positive_observed.size > 0: + predicted_total_weight = float(np.min(positive_observed)) + elif positive_prediction_total > 0.0: + predicted_total_weight = positive_prediction_total + else: + predicted_total_weight = 1.0 + if predicted_total_weight <= 0.0: + return observed_weights, prediction_base_weights + return ( + observed_weights, + np.asarray( + normalized_prediction_shares * predicted_total_weight, + dtype=float, + ), + ) + + +def _predicted_total_weight_from_observed_tail( + training_observations: ( + list[ClusterDynamicsMLTrainingObservation] + | tuple[ClusterDynamicsMLTrainingObservation, ...] + ), + observed_weights: np.ndarray, + predictions: ( + list[PredictedClusterCandidate] | tuple[PredictedClusterCandidate, ...] + ), +) -> float: + positive_observed = observed_weights[observed_weights > 0.0] + if positive_observed.size == 0: + return 0.0 + + observed_size_totals: defaultdict[int, float] = defaultdict(float) + for observation, weight in zip( + training_observations, + observed_weights, + strict=False, + ): + weight_value = max(float(weight), 0.0) + if weight_value <= 0.0: + continue + observed_size_totals[int(observation.node_count)] += weight_value + + positive_sizes = sorted( + size for size, total in observed_size_totals.items() if total > 0.0 + ) + if not positive_sizes: + return float(np.min(positive_observed)) + + last_observed_size = positive_sizes[-1] + last_observed_total = float(observed_size_totals[last_observed_size]) + if last_observed_total <= 0.0: + return float(np.min(positive_observed)) + + step_ratios: list[float] = [] + for previous_size, current_size in zip( + positive_sizes, + positive_sizes[1:], + strict=False, + ): + previous_total = float(observed_size_totals[previous_size]) + current_total = float(observed_size_totals[current_size]) + if previous_total <= 0.0 or current_total <= 0.0: + continue + gap = max(int(current_size - previous_size), 1) + step_ratios.append((current_total / previous_total) ** (1.0 / gap)) + decay_per_node = ( + float(np.median(np.asarray(step_ratios, dtype=float))) + if step_ratios + else 0.5 + ) + decay_per_node = float(np.clip(decay_per_node, 0.05, 0.9)) + + target_sizes = sorted( + { + int(prediction.target_node_count) + for prediction in predictions + if int(prediction.target_node_count) > 0 + } + ) + if not target_sizes: + return min(last_observed_total, float(np.min(positive_observed))) + + extrapolated_total = 0.0 + for target_size in target_sizes: + if target_size <= last_observed_size: + extrapolated_total += float( + observed_size_totals.get(target_size, last_observed_total) + ) + continue + extrapolated_total += float( + last_observed_total + * (decay_per_node ** (target_size - last_observed_size)) + ) + return float( + min( + extrapolated_total, + last_observed_total, + ) + ) + + +def _prediction_share_weight( + prediction: PredictedClusterCandidate, + *, + frame_timestep_fs: float, +) -> float: + saxs_weight = _saxs_component_weight( + prediction.predicted_mean_count_per_frame, + prediction.predicted_occupancy_fraction, + ) + if saxs_weight > 0.0: + return float(saxs_weight) + + occupancy_weight = max(float(prediction.predicted_occupancy_fraction), 0.0) + if occupancy_weight > 0.0: + return occupancy_weight + + timestep = max(float(frame_timestep_fs), 1e-12) + return max(float(prediction.predicted_mean_lifetime_fs), 0.0) / timestep + + +def _safe_component_stem(label: str) -> str: + sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", str(label).strip()) + return sanitized.strip("._") or "component" + + +def _write_component_profile( + output_path: Path, + *, + q_values: np.ndarray, + intensity: np.ndarray, + source: str, +) -> None: + q_values = np.asarray(q_values, dtype=float) + intensity = np.asarray(intensity, dtype=float) + if q_values.shape != intensity.shape: + raise ValueError("SAXS component q and intensity arrays must match.") + output_path.parent.mkdir(parents=True, exist_ok=True) + header = ( + f"# Source: {source}\n" "# Columns: q, S(q)_avg, S(q)_std, S(q)_se\n" + ) + data = np.column_stack( + [ + q_values, + intensity, + np.zeros_like(intensity, dtype=float), + np.zeros_like(intensity, dtype=float), + ] + ) + np.savetxt( + output_path, + data, + comments="", + header=header, + fmt=["%.8f", "%.8f", "%.8f", "%.8f"], + ) + + +def _write_xyz_structure( + output_path: Path, + *, + label: str, + elements: tuple[str, ...] | list[str], + coordinates: np.ndarray, + comment: str, +) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + coords = np.asarray(coordinates, dtype=float) + lines = [str(len(elements)), f"label={label} {comment}".strip()] + for element, position in zip(elements, coords, strict=False): + lines.append( + f"{element} {float(position[0]):.8f} {float(position[1]):.8f} " + f"{float(position[2]):.8f}" + ) + output_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def _predicted_structure_file_stem( + prediction: PredictedClusterCandidate, +) -> str: + return ( + f"{prediction.target_node_count:02d}_rank{prediction.rank:02d}_" + f"{_safe_component_stem(prediction.label)}" + ) + + +def _write_predicted_structure_file( + prediction: PredictedClusterCandidate, + *, + surrogate_structure_dir: Path, +) -> Path: + structure_path = ( + surrogate_structure_dir + / f"{_predicted_structure_file_stem(prediction)}.xyz" + ) + _write_xyz_structure( + structure_path, + label=prediction.label, + elements=prediction.generated_elements, + coordinates=prediction.generated_coordinates, + comment=( + f"target_node_count={prediction.target_node_count} " + f"rank={prediction.rank} source_label=" + f"{'' if prediction.source_label is None else prediction.source_label}" + ), + ) + return structure_path + + +def _fit_property_model( + feature_matrix: np.ndarray, + targets: np.ndarray, + *, + weights: np.ndarray, + default_value: float, + transform: str = "identity", + lower_bound: float | None = None, + upper_bound: float | None = None, +) -> _PropertyModel: + x_values = np.asarray(feature_matrix, dtype=float) + y_values = np.asarray(targets, dtype=float) + weight_values = np.asarray(weights, dtype=float) + if y_values.size == 0: + return _PropertyModel( + coefficients=None, + constant_value=None, + transform=transform, + default_value=default_value, + lower_bound=lower_bound, + upper_bound=upper_bound, + ) + transformed_targets = ( + np.log1p(np.clip(y_values, 0.0, None)) + if transform == "log1p" + else y_values + ) + if transformed_targets.size == 1: + return _PropertyModel( + coefficients=None, + constant_value=float(transformed_targets[0]), + transform=transform, + default_value=default_value, + lower_bound=lower_bound, + upper_bound=upper_bound, + ) + clipped_weights = np.clip(weight_values, 1e-9, None) + sqrt_weights = np.sqrt(clipped_weights) + weighted_x = x_values * sqrt_weights[:, None] + weighted_y = transformed_targets * sqrt_weights + gram = weighted_x.T @ weighted_x + ridge = np.eye(gram.shape[0], dtype=float) * _RIDGE_REGULARIZATION + try: + coefficients = np.linalg.solve( + gram + ridge, + weighted_x.T @ weighted_y, + ) + except np.linalg.LinAlgError: + coefficients = np.linalg.pinv(gram + ridge) @ ( + weighted_x.T @ weighted_y + ) + return _PropertyModel( + coefficients=np.asarray(coefficients, dtype=float), + constant_value=None, + transform=transform, + default_value=default_value, + lower_bound=lower_bound, + upper_bound=upper_bound, + ) + + +def _filtered_structure_counts( + elements: list[str] | tuple[str, ...], + *, + tracked_elements: set[str], +) -> dict[str, int]: + counts: Counter[str] = Counter() + for raw_element in elements: + normalized = _normalized_element_symbol(raw_element) + if not normalized: + continue + if tracked_elements and normalized not in tracked_elements: + continue + counts[normalized] += 1 + return _normalized_counts(dict(counts)) + + +def _select_representative_structure_counts( + count_rows: list[tuple[Path, dict[str, int]]], + *, + fallback_path: Path | None, + fallback_counts: dict[str, int], +) -> tuple[Path | None, dict[str, int]]: + if not count_rows: + return fallback_path, _normalized_counts(fallback_counts) + + signature_counts: Counter[tuple[tuple[str, int], ...]] = Counter( + tuple(sorted(counts.items())) for _path, counts in count_rows if counts + ) + if not signature_counts: + return fallback_path, _normalized_counts(fallback_counts) + + best_signature = min( + signature_counts, + key=lambda signature: ( + -signature_counts[signature], + -sum(count for _element, count in signature), + signature, + ), + ) + matching_rows = [ + (path, counts) + for path, counts in count_rows + if tuple(sorted(counts.items())) == best_signature + ] + representative_path, representative_counts = min( + matching_rows, + key=lambda item: ( + -sum(item[1].values()), + str(item[0]), + ), + ) + return representative_path, _normalized_counts(representative_counts) + + +def _parse_stoichiometry_label(label: str) -> dict[str, int]: + counts: dict[str, int] = {} + for element, count_text in _STOICHIOMETRY_TOKEN_PATTERN.findall( + str(label) + ): + normalized = _normalized_element_symbol(element) + if not normalized: + continue + counts[normalized] = counts.get(normalized, 0) + ( + int(count_text) if count_text else 1 + ) + return counts + + +def _normalized_counts(counts: dict[str, int]) -> dict[str, int]: + return { + _normalized_element_symbol(element): int(count) + for element, count in counts.items() + if int(count) > 0 and _normalized_element_symbol(element) + } + + +def _normalized_element_symbol(raw_value: str) -> str: + text = str(raw_value).strip() + if not text: + return "" + return text[:1].upper() + text[1:].lower() + + +def _candidate_feature_vector( + counts: dict[str, int], + *, + node_elements: tuple[str, ...], + non_node_elements: tuple[str, ...], +) -> np.ndarray: + normalized = _normalized_counts(counts) + node_count = max( + float(sum(normalized.get(element, 0) for element in node_elements)), + 1.0, + ) + total_atoms = float(sum(normalized.values())) + ratios = [ + float(normalized.get(element, 0)) / node_count + for element in non_node_elements + ] + return np.asarray([1.0, node_count, total_atoms, *ratios], dtype=float) + + +def _weighted_node_element_fractions( + observations: list[ClusterDynamicsMLTrainingObservation], + *, + node_elements: tuple[str, ...], +) -> dict[str, float]: + if not node_elements: + return {} + totals = {element: 0.0 for element in node_elements} + total_weight = 0.0 + for row in observations: + node_count = max(row.node_count, 0) + if node_count <= 0: + continue + total_weight += row.stability_weight + for element in node_elements: + totals[element] += ( + row.stability_weight + * row.element_counts.get(element, 0) + / node_count + ) + if total_weight <= 0.0: + uniform = 1.0 / float(len(node_elements)) + return {element: uniform for element in node_elements} + fractions = { + element: totals[element] / total_weight for element in node_elements + } + fraction_sum = max(sum(fractions.values()), 1e-12) + return { + element: value / fraction_sum for element, value in fractions.items() + } + + +def _allocate_node_counts( + target_node_count: int, + node_element_fractions: dict[str, float], +) -> dict[str, int]: + if not node_element_fractions: + return {} + exact = { + element: float(target_node_count) * float(fraction) + for element, fraction in node_element_fractions.items() + } + allocated = { + element: int(math.floor(value)) for element, value in exact.items() + } + remainder = int(target_node_count - sum(allocated.values())) + ranked_remainders = sorted( + exact.items(), + key=lambda item: (-(item[1] - math.floor(item[1])), item[0]), + ) + for index in range(remainder): + element = ranked_remainders[index % len(ranked_remainders)][0] + allocated[element] += 1 + return { + element: count for element, count in allocated.items() if count > 0 + } + + +def _candidate_has_support( + observations: list[ClusterDynamicsMLTrainingObservation], + *, + counts: dict[str, int], + target_node_count: int, + node_elements: tuple[str, ...], + non_node_elements: tuple[str, ...], +) -> bool: + del node_elements + normalized = _normalized_counts(counts) + if any(normalized.get(element, 0) > 0 for element in non_node_elements): + return True + + supported_pure_node_sizes = sorted( + { + int(row.node_count) + for row in observations + if row.node_count > 0 + and not any( + int(row.element_counts.get(element, 0)) > 0 + for element in non_node_elements + ) + } + ) + if not supported_pure_node_sizes: + return False + return int(target_node_count) <= max(supported_pure_node_sizes) + 1 + + +def _required_non_node_count_floors( + observations: list[ClusterDynamicsMLTrainingObservation], + *, + target_node_count: int, + non_node_elements: tuple[str, ...], + atom_type_by_element: dict[str, str], +) -> dict[str, int]: + if int(target_node_count) <= 1: + return {} + multi_node_observations = [ + row + for row in observations + if int(row.node_count) >= 2 and int(row.node_count) > 0 + ] + if not multi_node_observations: + return {} + + required_counts: dict[str, int] = {} + for element in non_node_elements: + if atom_type_by_element.get(element) != "linker": + continue + ratios: list[float] = [] + for row in multi_node_observations: + count = int(row.element_counts.get(element, 0)) + if count <= 0: + ratios = [] + break + ratios.append(float(count) / float(row.node_count)) + if not ratios: + continue + minimum_ratio = min(ratios) + required_counts[element] = max( + 1, + int(math.ceil(float(target_node_count) * minimum_ratio - 1e-9)), + ) + return required_counts + + +def _composition_distance( + counts_a: dict[str, int], + counts_b: dict[str, int], + *, + node_count: int, +) -> float: + elements = sorted(set(counts_a) | set(counts_b)) + denominator = max(float(node_count), 1.0) + return float( + sum( + abs(counts_a.get(element, 0) - counts_b.get(element, 0)) + for element in elements + ) + / denominator + ) + + +def _structure_descriptor_row( + coordinates: np.ndarray, +) -> tuple[float, float, float, float, float, float]: + coords = np.asarray(coordinates, dtype=float) + if coords.size == 0: + return (0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + centered = coords - np.mean(coords, axis=0, keepdims=True) + radial = np.linalg.norm(centered, axis=1) + rg = float(np.sqrt(np.mean(np.sum(centered**2, axis=1)))) + max_radius = float(np.max(radial)) + covariance = np.cov(centered.T, bias=True) + eigvals = np.linalg.eigvalsh(covariance) + semiaxes = np.sqrt(np.clip(np.sort(eigvals)[::-1], 0.0, None)) + if semiaxes.size < 3: + semiaxes = np.pad(semiaxes, (0, 3 - semiaxes.size)) + return ( + float(len(coords)), + rg, + max_radius, + float(semiaxes[0]), + float(semiaxes[1]), + float(semiaxes[2]), + ) + + +def _empty_lifetime_summary( + label: str, + *, + cluster_size: int, +) -> ClusterLifetimeSummary: + return ClusterLifetimeSummary( + label=label, + cluster_size=int(cluster_size), + total_observations=0, + occupied_frames=0, + mean_count_per_frame=0.0, + occupancy_fraction=0.0, + association_events=0, + dissociation_events=0, + association_rate_per_ps=0.0, + dissociation_rate_per_ps=0.0, + completed_lifetime_count=0, + window_truncated_lifetime_count=0, + mean_lifetime_fs=None, + std_lifetime_fs=None, + ) + + +def _geometry_atom_type_label(atom_type: str) -> str: + normalized = str(atom_type).strip().lower() + if normalized == "node": + return "node" + if normalized == "linker": + return "linker" + return "shell" + + +def _structure_files_in_directory(directory: Path) -> list[Path]: + return sorted( + path + for path in directory.iterdir() + if path.is_file() and path.suffix.lower() in {".xyz", ".pdb"} + ) + + +def _sorted_pair_key(value_a: str, value_b: str) -> tuple[str, str]: + return tuple(sorted((str(value_a), str(value_b)))) + + +def _median_or_default( + values: list[float] | np.ndarray | tuple[float, ...], + *, + default: float, +) -> float: + if len(values) == 0: + return float(default) + return float(np.median(np.asarray(values, dtype=float))) + + +def _pair_cutoff_distance( + element_a: str, + element_b: str, + *, + pair_cutoff_definitions: PairCutoffDefinitions, +) -> float | None: + normalized_a = _normalized_element_symbol(element_a) + normalized_b = _normalized_element_symbol(element_b) + for (atom1, atom2), level_map in pair_cutoff_definitions.items(): + if { + _normalized_element_symbol(atom1), + _normalized_element_symbol(atom2), + } != {normalized_a, normalized_b}: + continue + if 0 in level_map: + return float(level_map[0]) + if level_map: + return float(min(level_map.values())) + return None + + +def _fallback_pair_distance( + element_a: str, + element_b: str, + *, + pair_cutoff_definitions: PairCutoffDefinitions, + default: float, +) -> float: + pair_cutoff = _pair_cutoff_distance( + element_a, + element_b, + pair_cutoff_definitions=pair_cutoff_definitions, + ) + if pair_cutoff is not None: + return float(pair_cutoff) + return float(default) + + +def _pair_contact_distance( + element_a: str, + element_b: str, + *, + geometry_statistics: _TrainingGeometryStatistics, + pair_cutoff_definitions: PairCutoffDefinitions, +) -> float: + pair_key = _sorted_pair_key(element_a, element_b) + if pair_key in geometry_statistics.contact_distance_medians: + return float(geometry_statistics.contact_distance_medians[pair_key]) + if pair_key in geometry_statistics.bond_length_medians: + return float(geometry_statistics.bond_length_medians[pair_key]) + geometry_pair_key = _sorted_pair_key( + geometry_statistics.atom_type_by_element.get(element_a, "shell"), + geometry_statistics.atom_type_by_element.get(element_b, "shell"), + ) + if ( + geometry_pair_key + in geometry_statistics.geometry_contact_distance_medians + ): + return float( + geometry_statistics.geometry_contact_distance_medians[ + geometry_pair_key + ] + ) + return float( + _fallback_pair_distance( + element_a, + element_b, + pair_cutoff_definitions=pair_cutoff_definitions, + default=geometry_statistics.node_bond_length, + ) + ) + + +def _contact_distance_cutoff( + element_a: str, + element_b: str, + *, + preliminary_contact_medians: dict[tuple[str, str], float] | None = None, + geometry_statistics: _TrainingGeometryStatistics | None = None, + default_distance: float, + pair_cutoff_definitions: PairCutoffDefinitions, +) -> float: + pair_key = _sorted_pair_key(element_a, element_b) + if geometry_statistics is not None: + target_distance = _pair_contact_distance( + element_a, + element_b, + geometry_statistics=geometry_statistics, + pair_cutoff_definitions=pair_cutoff_definitions, + ) + elif ( + preliminary_contact_medians is not None + and pair_key in preliminary_contact_medians + ): + target_distance = float(preliminary_contact_medians[pair_key]) + else: + target_distance = float( + _fallback_pair_distance( + element_a, + element_b, + pair_cutoff_definitions=pair_cutoff_definitions, + default=default_distance, + ) + ) + fallback_limit = float( + _fallback_pair_distance( + element_a, + element_b, + pair_cutoff_definitions=pair_cutoff_definitions, + default=default_distance, + ) + ) + return float( + min(target_distance * 1.25 + 0.15, fallback_limit * 1.40 + 0.20) + ) + + +def _fallback_node_bond_length( + *, + node_elements: tuple[str, ...], + pair_cutoff_definitions: PairCutoffDefinitions, +) -> float: + explicit_node_lengths = [ + _pair_cutoff_distance( + element_a, + element_b, + pair_cutoff_definitions=pair_cutoff_definitions, + ) + for element_a, element_b in combinations(node_elements, 2) + ] + explicit_node_lengths.extend( + _pair_cutoff_distance( + element, + element, + pair_cutoff_definitions=pair_cutoff_definitions, + ) + for element in node_elements + ) + explicit_node_lengths = [ + float(value) for value in explicit_node_lengths if value is not None + ] + if explicit_node_lengths: + return float(np.median(np.asarray(explicit_node_lengths, dtype=float))) + bridging_distances = [ + float(level_value) * 2.0 + for (atom1, atom2), level_map in pair_cutoff_definitions.items() + if ( + _normalized_element_symbol(atom1) in node_elements + and _normalized_element_symbol(atom2) not in node_elements + ) + or ( + _normalized_element_symbol(atom2) in node_elements + and _normalized_element_symbol(atom1) not in node_elements + ) + for level_value in level_map.values() + ] + if bridging_distances: + return float(np.median(np.asarray(bridging_distances, dtype=float))) + return 3.0 + + +def _angle_between_vectors( + vector_a: np.ndarray, + vector_b: np.ndarray, +) -> float | None: + norm_a = float(np.linalg.norm(vector_a)) + norm_b = float(np.linalg.norm(vector_b)) + if norm_a <= 1e-12 or norm_b <= 1e-12: + return None + cosine = float(np.dot(vector_a, vector_b) / (norm_a * norm_b)) + return float(math.degrees(math.acos(np.clip(cosine, -1.0, 1.0)))) + + +def _safe_unit_vector( + vector: np.ndarray, + *, + fallback: np.ndarray | None = None, +) -> np.ndarray: + candidate = np.asarray(vector, dtype=float) + norm = float(np.linalg.norm(candidate)) + if norm > 1e-12: + return candidate / norm + if fallback is None: + return np.asarray([1.0, 0.0, 0.0], dtype=float) + return _safe_unit_vector(np.asarray(fallback, dtype=float)) + + +def _orthonormal_basis(axis: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + unit_axis = _safe_unit_vector(axis) + if abs(float(unit_axis[0])) < 0.9: + reference = np.asarray([1.0, 0.0, 0.0], dtype=float) + else: + reference = np.asarray([0.0, 1.0, 0.0], dtype=float) + basis_u = np.cross(unit_axis, reference) + basis_u = _safe_unit_vector(basis_u, fallback=np.asarray([0.0, 1.0, 0.0])) + basis_v = np.cross(unit_axis, basis_u) + basis_v = _safe_unit_vector(basis_v, fallback=np.asarray([0.0, 0.0, 1.0])) + return basis_u, basis_v + + +def _cone_directions( + axis: np.ndarray, + *, + angle_degrees: float, + samples: int = 8, +) -> list[np.ndarray]: + unit_axis = _safe_unit_vector(axis) + theta = math.radians(float(angle_degrees)) + if theta <= 1e-6: + return [unit_axis] + if abs(theta - math.pi) <= 1e-6: + return [unit_axis * -1.0] + basis_u, basis_v = _orthonormal_basis(unit_axis) + directions: list[np.ndarray] = [] + for sample_index in range(max(int(samples), 1)): + phi = (2.0 * math.pi * float(sample_index)) / float(max(samples, 1)) + radial = math.cos(phi) * basis_u + math.sin(phi) * basis_v + directions.append( + _safe_unit_vector( + math.cos(theta) * unit_axis + math.sin(theta) * radial + ) + ) + return directions + + +def _direction_basis(reference: np.ndarray) -> list[np.ndarray]: + axis = _safe_unit_vector(reference) + basis_u, basis_v = _orthonormal_basis(axis) + candidates = [ + axis, + axis * -1.0, + basis_u, + basis_u * -1.0, + basis_v, + basis_v * -1.0, + axis + basis_u, + axis - basis_u, + axis + basis_v, + axis - basis_v, + basis_u + basis_v, + basis_u - basis_v, + ] + unique: list[np.ndarray] = [] + for candidate in candidates: + direction = _safe_unit_vector(candidate) + if any(np.allclose(direction, other) for other in unique): + continue + unique.append(direction) + return unique + + +def _minimum_spanning_tree_edges( + coordinates: np.ndarray, +) -> list[tuple[int, int]]: + coords = np.asarray(coordinates, dtype=float) + point_count = len(coords) + if point_count < 2: + return [] + parent = list(range(point_count)) + + def find(index: int) -> int: + while parent[index] != index: + parent[index] = parent[parent[index]] + index = parent[index] + return index + + def union(index_a: int, index_b: int) -> bool: + root_a = find(index_a) + root_b = find(index_b) + if root_a == root_b: + return False + parent[root_b] = root_a + return True + + all_edges = sorted( + ( + float(np.linalg.norm(coords[index_a] - coords[index_b])), + index_a, + index_b, + ) + for index_a, index_b in combinations(range(point_count), 2) + ) + selected: list[tuple[int, int]] = [] + for _distance, index_a, index_b in all_edges: + if not union(index_a, index_b): + continue + selected.append((index_a, index_b)) + if len(selected) == point_count - 1: + break + return selected + + +def _node_scaffold_edges( + coordinates: np.ndarray, + node_elements: list[str] | tuple[str, ...], + *, + pair_cutoff_definitions: PairCutoffDefinitions, +) -> list[tuple[int, int]]: + coords = np.asarray(coordinates, dtype=float) + point_count = len(coords) + if point_count < 2: + return [] + explicit_edges: set[tuple[int, int]] = set() + all_edges = sorted( + ( + float(np.linalg.norm(coords[index_a] - coords[index_b])), + index_a, + index_b, + ) + for index_a, index_b in combinations(range(point_count), 2) + ) + for distance, index_a, index_b in all_edges: + cutoff = _pair_cutoff_distance( + node_elements[index_a], + node_elements[index_b], + pair_cutoff_definitions=pair_cutoff_definitions, + ) + if cutoff is None: + continue + if distance <= float(cutoff) * 1.15: + explicit_edges.add((min(index_a, index_b), max(index_a, index_b))) + if not explicit_edges: + return _minimum_spanning_tree_edges(coords) + + parent = list(range(point_count)) + + def find(index: int) -> int: + while parent[index] != index: + parent[index] = parent[parent[index]] + index = parent[index] + return index + + def union(index_a: int, index_b: int) -> bool: + root_a = find(index_a) + root_b = find(index_b) + if root_a == root_b: + return False + parent[root_b] = root_a + return True + + for index_a, index_b in explicit_edges: + union(index_a, index_b) + for _distance, index_a, index_b in all_edges: + if len({find(index) for index in range(point_count)}) == 1: + break + if union(index_a, index_b): + explicit_edges.add((min(index_a, index_b), max(index_a, index_b))) + return sorted(explicit_edges) + + +def _associate_non_node_atoms_to_nodes( + coordinates: np.ndarray, + *, + elements: list[str], + node_indices: list[int], + pair_cutoff_definitions: PairCutoffDefinitions, +) -> dict[int, tuple[int, ...]]: + coords = np.asarray(coordinates, dtype=float) + if not node_indices: + return {} + node_index_set = set(node_indices) + associations: dict[int, tuple[int, ...]] = {} + for atom_index, element in enumerate(elements): + if atom_index in node_index_set: + continue + distances = [ + ( + float(np.linalg.norm(coords[atom_index] - coords[node_index])), + node_index, + ) + for node_index in node_indices + ] + explicit_matches = [ + (distance, node_index) + for distance, node_index in distances + if ( + cutoff := _pair_cutoff_distance( + element, + elements[node_index], + pair_cutoff_definitions=pair_cutoff_definitions, + ) + ) + is not None + and distance <= float(cutoff) * 1.15 + ] + if explicit_matches: + associations[atom_index] = tuple( + node_index + for _distance, node_index in sorted( + explicit_matches, + key=lambda item: (item[0], item[1]), + ) + ) + continue + nearest_node = min(distances, key=lambda item: (item[0], item[1]))[1] + associations[atom_index] = (nearest_node,) + return associations + + +def _adjacency_from_edges( + point_count: int, + edges: list[tuple[int, int]], +) -> dict[int, set[int]]: + adjacency = {index: set() for index in range(int(point_count))} + for index_a, index_b in edges: + adjacency.setdefault(index_a, set()).add(index_b) + adjacency.setdefault(index_b, set()).add(index_a) + return adjacency + + +def _unique_edges_from_adjacency( + adjacency: dict[int, set[int]], +) -> list[tuple[int, int]]: + return sorted( + { + (min(index_a, index_b), max(index_a, index_b)) + for index_a, neighbors in adjacency.items() + for index_b in neighbors + } + ) + + +def _next_pending_node_element( + remaining_node_counts: Counter[str], +) -> str | None: + for element in sorted(remaining_node_counts): + if int(remaining_node_counts[element]) > 0: + return element + return None + + +def _node_bond_length_for_element( + element: str, + *, + node_elements: tuple[str, ...], + geometry_statistics: _TrainingGeometryStatistics, + pair_cutoff_definitions: PairCutoffDefinitions, +) -> float: + if element in node_elements: + return float(geometry_statistics.node_bond_length) + pair_distances = [ + geometry_statistics.bond_length_medians[ + _sorted_pair_key(element, node_element) + ] + for node_element in node_elements + if _sorted_pair_key(element, node_element) + in geometry_statistics.bond_length_medians + ] + if pair_distances: + return float(np.median(np.asarray(pair_distances, dtype=float))) + return float( + _default_element_distance( + element, + pair_cutoff_definitions=pair_cutoff_definitions, + node_elements=set(node_elements), + ) + ) + + +def _node_growth_position( + node_positions: list[np.ndarray], + adjacency: dict[int, set[int]], + *, + geometry_statistics: _TrainingGeometryStatistics, +) -> tuple[int, np.ndarray]: + if not node_positions: + return 0, np.zeros(3, dtype=float) + coords = np.asarray(node_positions, dtype=float) + if len(coords) == 1: + bond_length = max(float(geometry_statistics.node_bond_length), 0.1) + return 0, np.asarray(coords[0] + np.asarray([bond_length, 0.0, 0.0])) + centroid = np.mean(coords, axis=0) + desired_angle = float( + geometry_statistics.node_angle_medians.get(("node", "node"), 180.0) + ) + bond_length = max(float(geometry_statistics.node_bond_length), 0.1) + anchor_indices = sorted( + range(len(node_positions)), + key=lambda index: ( + len(adjacency.get(index, set())), + -float(np.linalg.norm(coords[index] - centroid)), + index, + ), + ) + best_anchor = 0 + best_position = np.asarray(coords[0] + np.asarray([bond_length, 0.0, 0.0])) + best_score: float | None = None + for anchor_index in anchor_indices: + anchor = coords[anchor_index] + neighbor_indices = sorted(adjacency.get(anchor_index, set())) + outward = anchor - centroid + if neighbor_indices: + outward = outward + np.sum( + anchor - coords[neighbor_indices], + axis=0, + ) + candidate_directions = _direction_basis(outward) + if len(neighbor_indices) == 1: + inward = coords[neighbor_indices[0]] - anchor + candidate_directions.extend( + _cone_directions( + inward, + angle_degrees=desired_angle, + samples=12, + ) + ) + unique_directions: list[np.ndarray] = [] + for direction in candidate_directions: + unit_direction = _safe_unit_vector(direction, fallback=outward) + if any( + np.allclose(unit_direction, other) + for other in unique_directions + ): + continue + unique_directions.append(unit_direction) + for direction in unique_directions: + candidate = anchor + direction * bond_length + anchor_distance = float(np.linalg.norm(candidate - anchor)) + collision_penalty = 0.0 + for existing_index, point in enumerate(coords): + distance = float(np.linalg.norm(candidate - point)) + if existing_index == anchor_index: + collision_penalty += abs(distance - bond_length) * 5.0 + continue + if distance < bond_length * 0.75: + collision_penalty += ( + (bond_length * 0.75) - distance + ) ** 2 * 500.0 + elif distance < bond_length * 0.95: + collision_penalty += ( + (bond_length * 0.95) - distance + ) ** 2 * 60.0 + angle_penalty = 0.0 + for neighbor_index in neighbor_indices: + angle = _angle_between_vectors( + candidate - anchor, + coords[neighbor_index] - anchor, + ) + if angle is not None: + angle_penalty += abs(angle - desired_angle) + outward_alignment = float( + np.dot( + _safe_unit_vector(candidate - anchor), + _safe_unit_vector(outward, fallback=candidate - anchor), + ) + ) + score = ( + collision_penalty + + angle_penalty + + abs(anchor_distance - bond_length) * 5.0 + - outward_alignment * 8.0 + ) + if best_score is None or score < best_score: + best_score = float(score) + best_anchor = int(anchor_index) + best_position = np.asarray(candidate, dtype=float) + return best_anchor, best_position + + +def _count_neighbor_type( + neighbor_entries: list[tuple[str, np.ndarray]], + geometry_type: str, +) -> int: + return sum( + 1 + for entry_type, _vector in neighbor_entries + if entry_type == geometry_type + ) + + +def _placement_sequence( + target_counts: dict[str, int], + *, + node_elements: tuple[str, ...], + geometry_statistics: _TrainingGeometryStatistics, +) -> list[str]: + queued: list[tuple[float, int, str]] = [] + for element, count in sorted(target_counts.items()): + if element in node_elements or int(count) <= 0: + continue + geometry_type = geometry_statistics.atom_type_by_element.get( + element, "shell" + ) + bridge_degree = float( + geometry_statistics.non_node_node_coordination_medians.get( + element, 1.0 + ) + ) + geometry_priority = 0 if geometry_type == "linker" else 1 + for _ in range(int(count)): + queued.append((-bridge_degree, geometry_priority, element)) + queued.sort() + return [element for _bridge_degree, _geometry_priority, element in queued] + + +def _select_bridge_edge( + edges: list[tuple[int, int]], + *, + node_positions: list[np.ndarray], + node_neighbor_entries: dict[int, list[tuple[str, np.ndarray]]], + geometry_type: str, + edge_assignments: Counter[tuple[int, int]], + geometry_statistics: _TrainingGeometryStatistics, +) -> tuple[int, int]: + coords = np.asarray(node_positions, dtype=float) + centroid = ( + np.mean(coords, axis=0) + if len(coords) > 0 + else np.zeros(3, dtype=float) + ) + target_coordination = float( + geometry_statistics.node_coordination_medians.get(geometry_type, 0.0) + ) + return max( + edges, + key=lambda edge: ( + max( + target_coordination + - _count_neighbor_type( + node_neighbor_entries.get(edge[0], []), geometry_type + ), + 0.0, + ) + + max( + target_coordination + - _count_neighbor_type( + node_neighbor_entries.get(edge[1], []), geometry_type + ), + 0.0, + ), + -float(edge_assignments[edge]), + float( + np.linalg.norm( + ((coords[edge[0]] + coords[edge[1]]) * 0.5) - centroid + ) + ), + ), + ) + + +def _bridge_atom_position( + edge: tuple[int, int], + *, + node_positions: list[np.ndarray], + existing_positions: list[np.ndarray], + existing_elements: list[str], + element: str, + bond_length: float, + geometry_statistics: _TrainingGeometryStatistics, + pair_cutoff_definitions: PairCutoffDefinitions, +) -> np.ndarray: + coords = np.asarray(node_positions, dtype=float) + point_a = coords[edge[0]] + point_b = coords[edge[1]] + midpoint = (point_a + point_b) * 0.5 + edge_vector = point_b - point_a + edge_length = float(np.linalg.norm(edge_vector)) + if edge_length <= 1e-12: + return np.asarray(midpoint, dtype=float) + desired_distance = max(float(bond_length), edge_length * 0.5) + height = math.sqrt( + max((desired_distance**2) - ((edge_length * 0.5) ** 2), 0.0) + ) + if height <= 1e-8: + return np.asarray(midpoint, dtype=float) + centroid = ( + np.mean(np.asarray(node_positions, dtype=float), axis=0) + if node_positions + else np.zeros(3, dtype=float) + ) + radial = midpoint - centroid + edge_unit = _safe_unit_vector(edge_vector) + radial_perpendicular = radial - np.dot(radial, edge_unit) * edge_unit + basis_u, basis_v = _orthonormal_basis(edge_unit) + preferred_normal = _safe_unit_vector( + radial_perpendicular, fallback=basis_u + ) + candidates = [ + midpoint + preferred_normal * height, + midpoint - preferred_normal * height, + midpoint + basis_u * height, + midpoint - basis_u * height, + midpoint + basis_v * height, + midpoint - basis_v * height, + ] + existing = np.asarray(existing_positions, dtype=float) + best_position = np.asarray(candidates[0], dtype=float) + best_score: float | None = None + for candidate in candidates: + distances = [ + float(np.linalg.norm(candidate - position)) + for position in existing + if float(np.linalg.norm(candidate - position)) > 1e-8 + ] + minimum_distance = min(distances) if distances else math.inf + interaction_penalty = _non_node_interaction_penalty( + element, + candidate, + existing_elements=existing_elements, + existing_positions=existing_positions, + geometry_statistics=geometry_statistics, + pair_cutoff_definitions=pair_cutoff_definitions, + ) + node_distance_penalty = ( + abs(float(np.linalg.norm(candidate - point_a)) - bond_length) * 4.0 + + abs(float(np.linalg.norm(candidate - point_b)) - bond_length) + * 4.0 + ) + score = ( + node_distance_penalty + + interaction_penalty + - minimum_distance * 3.0 + ) + if best_score is None or score < best_score: + best_score = float(score) + best_position = np.asarray(candidate, dtype=float) + return best_position + + +def _select_attachment_node( + *, + node_positions: list[np.ndarray], + node_neighbor_entries: dict[int, list[tuple[str, np.ndarray]]], + geometry_type: str, + geometry_statistics: _TrainingGeometryStatistics, +) -> int: + coords = np.asarray(node_positions, dtype=float) + centroid = ( + np.mean(coords, axis=0) + if len(coords) > 0 + else np.zeros(3, dtype=float) + ) + target_coordination = float( + geometry_statistics.node_coordination_medians.get(geometry_type, 1.0) + ) + return max( + range(len(node_positions)), + key=lambda index: ( + max( + target_coordination + - _count_neighbor_type( + node_neighbor_entries.get(index, []), geometry_type + ), + 0.0, + ), + float(np.linalg.norm(coords[index] - centroid)), + -len( + [ + 1 + for entry_type, _vector in node_neighbor_entries.get( + index, [] + ) + if entry_type != "node" + ] + ), + -len( + [ + 1 + for entry_type, _vector in node_neighbor_entries.get( + index, [] + ) + if entry_type == "node" + ] + ), + ), + ) + + +def _current_contact_counts( + *, + existing_elements: list[str], + existing_positions: list[np.ndarray], + geometry_statistics: _TrainingGeometryStatistics, + pair_cutoff_definitions: PairCutoffDefinitions, +) -> list[Counter[str]]: + counts = [Counter() for _ in existing_elements] + if len(existing_elements) < 2: + return counts + positions = np.asarray(existing_positions, dtype=float) + for atom_index_a, atom_index_b in combinations( + range(len(existing_elements)), 2 + ): + distance = float( + np.linalg.norm(positions[atom_index_a] - positions[atom_index_b]) + ) + if distance > _contact_distance_cutoff( + existing_elements[atom_index_a], + existing_elements[atom_index_b], + geometry_statistics=geometry_statistics, + default_distance=geometry_statistics.node_bond_length, + pair_cutoff_definitions=pair_cutoff_definitions, + ): + continue + geometry_type_a = geometry_statistics.atom_type_by_element.get( + existing_elements[atom_index_a], + "shell", + ) + geometry_type_b = geometry_statistics.atom_type_by_element.get( + existing_elements[atom_index_b], + "shell", + ) + counts[atom_index_a][geometry_type_b] += 1 + counts[atom_index_b][geometry_type_a] += 1 + return counts + + +def _non_node_interaction_penalty( + element: str, + candidate_position: np.ndarray, + *, + existing_elements: list[str], + existing_positions: list[np.ndarray], + geometry_statistics: _TrainingGeometryStatistics, + pair_cutoff_definitions: PairCutoffDefinitions, +) -> float: + geometry_type = geometry_statistics.atom_type_by_element.get( + element, "shell" + ) + tracked_non_node_types = [ + atom_type + for atom_type in geometry_statistics.tracked_atom_types + if atom_type != "node" + ] + if not tracked_non_node_types or not existing_elements: + return 0.0 + current_contact_counts = _current_contact_counts( + existing_elements=existing_elements, + existing_positions=existing_positions, + geometry_statistics=geometry_statistics, + pair_cutoff_definitions=pair_cutoff_definitions, + ) + candidate_counts: Counter[str] = Counter() + candidate_distance_errors: defaultdict[str, list[float]] = defaultdict( + list + ) + reciprocity_bonus = 0.0 + penalty = 0.0 + positions = np.asarray(existing_positions, dtype=float) + for atom_index, existing_element in enumerate(existing_elements): + existing_type = geometry_statistics.atom_type_by_element.get( + existing_element, "shell" + ) + if existing_type == "node": + continue + target_distance = _pair_contact_distance( + element, + existing_element, + geometry_statistics=geometry_statistics, + pair_cutoff_definitions=pair_cutoff_definitions, + ) + distance = float( + np.linalg.norm(candidate_position - positions[atom_index]) + ) + minimum_distance = max(target_distance * 0.65, 0.45) + if distance < minimum_distance: + penalty += ((minimum_distance - distance) ** 2) * 220.0 + cutoff = _contact_distance_cutoff( + element, + existing_element, + geometry_statistics=geometry_statistics, + default_distance=geometry_statistics.node_bond_length, + pair_cutoff_definitions=pair_cutoff_definitions, + ) + if distance > cutoff: + continue + candidate_counts[existing_type] += 1 + candidate_distance_errors[existing_type].append( + abs(distance - target_distance) + ) + desired_existing = int( + round( + geometry_statistics.atom_coordination_medians.get( + (existing_type, geometry_type), + 0.0, + ) + ) + ) + if desired_existing <= 0: + continue + current_count = int( + current_contact_counts[atom_index].get(geometry_type, 0) + ) + reciprocity_bonus += ( + float(max(desired_existing - current_count, 0)) * 6.0 + ) + if current_count >= desired_existing + 1: + penalty += float(current_count - desired_existing) * 6.0 + for neighbor_type in tracked_non_node_types: + available = sum( + 1 + for existing_element in existing_elements + if geometry_statistics.atom_type_by_element.get( + existing_element, "shell" + ) + == neighbor_type + ) + desired_count = int( + round( + geometry_statistics.atom_coordination_medians.get( + (geometry_type, neighbor_type), + 0.0, + ) + ) + ) + if available <= 0 and desired_count <= 0: + continue + target_count = min(max(desired_count, 0), available) + observed_count = int(candidate_counts.get(neighbor_type, 0)) + if target_count > observed_count: + penalty += float(target_count - observed_count) * 12.0 + elif observed_count > max(target_count, 0): + penalty += float(observed_count - max(target_count, 0)) * 10.0 + if target_count > 0 and candidate_distance_errors.get(neighbor_type): + penalty += float( + sum( + sorted(candidate_distance_errors[neighbor_type])[ + :target_count + ] + ) + * 8.0 + ) + elif observed_count > 0 and candidate_distance_errors.get( + neighbor_type + ): + penalty += ( + float(min(candidate_distance_errors[neighbor_type])) * 4.0 + ) + return float(penalty - reciprocity_bonus) + + +def _terminal_atom_position( + anchor_index: int, + *, + node_positions: list[np.ndarray], + node_neighbor_entries: dict[int, list[tuple[str, np.ndarray]]], + existing_positions: list[np.ndarray], + existing_elements: list[str], + element: str, + geometry_type: str, + bond_length: float, + geometry_statistics: _TrainingGeometryStatistics, + pair_cutoff_definitions: PairCutoffDefinitions, +) -> np.ndarray: + coords = np.asarray(node_positions, dtype=float) + anchor = coords[anchor_index] + centroid = ( + np.mean(coords, axis=0) + if len(coords) > 0 + else np.zeros(3, dtype=float) + ) + outward = anchor - centroid + for neighbor_type, vector in node_neighbor_entries.get(anchor_index, []): + if neighbor_type == "node": + outward = outward + (anchor - (anchor + vector)) + candidate_directions = _direction_basis(outward) + desired_node_angle = float( + geometry_statistics.node_angle_medians.get( + _sorted_pair_key("node", geometry_type), + 120.0, + ) + ) + for neighbor_type, vector in node_neighbor_entries.get(anchor_index, []): + neighbor_angle = float( + geometry_statistics.node_angle_medians.get( + _sorted_pair_key(geometry_type, neighbor_type), + desired_node_angle if neighbor_type == "node" else 109.5, + ) + ) + candidate_directions.extend( + _cone_directions( + vector, + angle_degrees=neighbor_angle, + samples=10, + ) + ) + unique_directions: list[np.ndarray] = [] + for direction in candidate_directions: + unit_direction = _safe_unit_vector(direction, fallback=outward) + if any( + np.allclose(unit_direction, other) for other in unique_directions + ): + continue + unique_directions.append(unit_direction) + existing = np.asarray(existing_positions, dtype=float) + best_position = np.asarray( + anchor + unique_directions[0] * bond_length, dtype=float + ) + best_score: float | None = None + for direction in unique_directions: + candidate = anchor + direction * float(bond_length) + collision_penalty = 0.0 + for position in existing: + distance = float(np.linalg.norm(candidate - position)) + if distance <= 1e-8: + continue + if distance < bond_length * 0.7: + collision_penalty += ( + (bond_length * 0.7) - distance + ) ** 2 * 500.0 + elif distance < bond_length * 0.95: + collision_penalty += ( + (bond_length * 0.95) - distance + ) ** 2 * 40.0 + angle_penalty = 0.0 + for neighbor_type, vector in node_neighbor_entries.get( + anchor_index, [] + ): + desired_angle = float( + geometry_statistics.node_angle_medians.get( + _sorted_pair_key(geometry_type, neighbor_type), + 109.5 if geometry_type == neighbor_type else 120.0, + ) + ) + angle = _angle_between_vectors(direction, vector) + if angle is not None: + angle_penalty += abs(angle - desired_angle) + outward_alignment = float( + np.dot( + _safe_unit_vector(direction), + _safe_unit_vector(outward, fallback=direction), + ) + ) + interaction_penalty = _non_node_interaction_penalty( + element, + candidate, + existing_elements=existing_elements, + existing_positions=existing_positions, + geometry_statistics=geometry_statistics, + pair_cutoff_definitions=pair_cutoff_definitions, + ) + score = ( + collision_penalty + + angle_penalty + + interaction_penalty + - outward_alignment * 8.0 + ) + if best_score is None or score < best_score: + best_score = float(score) + best_position = np.asarray(candidate, dtype=float) + return best_position + + +def _build_geometry_guided_structure( + target_counts: dict[str, int], + *, + node_elements: tuple[str, ...], + pair_cutoff_definitions: PairCutoffDefinitions, + geometry_statistics: _TrainingGeometryStatistics, + predicted_max_radius: float, + seed_node_elements: list[str], + seed_node_coordinates: np.ndarray, +) -> tuple[list[str], np.ndarray]: + normalized_target_counts = _normalized_counts(target_counts) + target_node_total = int( + sum( + normalized_target_counts.get(element, 0) + for element in node_elements + ) + ) + if target_node_total <= 0: + return _build_fallback_structure(normalized_target_counts) + + remaining_node_counts: Counter[str] = Counter( + { + element: int(normalized_target_counts.get(element, 0)) + for element in node_elements + } + ) + seed_coords = np.asarray(seed_node_coordinates, dtype=float) + kept_node_elements: list[str] = [] + kept_node_positions: list[np.ndarray] = [] + if seed_coords.ndim == 2 and len(seed_coords) == len(seed_node_elements): + for element, coordinate in zip( + seed_node_elements, seed_coords, strict=False + ): + if len(kept_node_elements) >= target_node_total: + break + if int(remaining_node_counts.get(element, 0)) <= 0: + continue + kept_node_elements.append(str(element)) + kept_node_positions.append(np.asarray(coordinate, dtype=float)) + remaining_node_counts[element] -= 1 + if not kept_node_positions: + first_element = _next_pending_node_element(remaining_node_counts) + if first_element is None: + return _build_fallback_structure(normalized_target_counts) + kept_node_elements.append(first_element) + kept_node_positions.append(np.zeros(3, dtype=float)) + remaining_node_counts[first_element] -= 1 + + adjacency = _adjacency_from_edges( + len(kept_node_positions), + _node_scaffold_edges( + np.asarray(kept_node_positions, dtype=float), + kept_node_elements, + pair_cutoff_definitions=pair_cutoff_definitions, + ), + ) + + while sum(int(value) for value in remaining_node_counts.values()) > 0: + next_element = _next_pending_node_element(remaining_node_counts) + if next_element is None: + break + anchor_index, new_position = _node_growth_position( + kept_node_positions, + adjacency, + geometry_statistics=geometry_statistics, + ) + new_index = len(kept_node_positions) + kept_node_elements.append(next_element) + kept_node_positions.append(np.asarray(new_position, dtype=float)) + adjacency.setdefault(anchor_index, set()).add(new_index) + adjacency.setdefault(new_index, set()).add(anchor_index) + remaining_node_counts[next_element] -= 1 + + all_elements = list(kept_node_elements) + all_positions = [ + np.asarray(position, dtype=float) for position in kept_node_positions + ] + node_neighbor_entries = { + index: [ + ( + "node", + kept_node_positions[neighbor] - kept_node_positions[index], + ) + for neighbor in sorted(adjacency.get(index, set())) + ] + for index in range(len(kept_node_positions)) + } + edge_assignments: Counter[tuple[int, int]] = Counter() + scaffold_edges = _unique_edges_from_adjacency(adjacency) + for element in _placement_sequence( + normalized_target_counts, + node_elements=node_elements, + geometry_statistics=geometry_statistics, + ): + geometry_type = geometry_statistics.atom_type_by_element.get( + element, "shell" + ) + bridge_degree = int( + max( + 1, + min( + round( + geometry_statistics.non_node_node_coordination_medians.get( + element, + 1.0, + ) + ), + 2, + ), + ) + ) + bond_length = max( + _node_bond_length_for_element( + element, + node_elements=node_elements, + geometry_statistics=geometry_statistics, + pair_cutoff_definitions=pair_cutoff_definitions, + ), + 0.1, + ) + if bridge_degree >= 2 and scaffold_edges: + selected_edge = _select_bridge_edge( + scaffold_edges, + node_positions=kept_node_positions, + node_neighbor_entries=node_neighbor_entries, + geometry_type=geometry_type, + edge_assignments=edge_assignments, + geometry_statistics=geometry_statistics, + ) + placed_position = _bridge_atom_position( + selected_edge, + node_positions=kept_node_positions, + existing_positions=all_positions, + existing_elements=all_elements, + element=element, + bond_length=bond_length, + geometry_statistics=geometry_statistics, + pair_cutoff_definitions=pair_cutoff_definitions, + ) + attached_nodes = selected_edge + edge_assignments[selected_edge] += 1 + else: + anchor_index = _select_attachment_node( + node_positions=kept_node_positions, + node_neighbor_entries=node_neighbor_entries, + geometry_type=geometry_type, + geometry_statistics=geometry_statistics, + ) + placed_position = _terminal_atom_position( + anchor_index, + node_positions=kept_node_positions, + node_neighbor_entries=node_neighbor_entries, + existing_positions=all_positions, + existing_elements=all_elements, + element=element, + geometry_type=geometry_type, + bond_length=bond_length, + geometry_statistics=geometry_statistics, + pair_cutoff_definitions=pair_cutoff_definitions, + ) + attached_nodes = (anchor_index,) + all_elements.append(element) + all_positions.append(np.asarray(placed_position, dtype=float)) + for node_index in attached_nodes: + node_neighbor_entries.setdefault(node_index, []).append( + ( + geometry_type, + placed_position - kept_node_positions[node_index], + ) + ) + + coordinates = np.asarray(all_positions, dtype=float) + coordinates = _scale_structure_to_radius( + coordinates, + target_max_radius=max(predicted_max_radius, 0.1), + ) + return all_elements, coordinates + + +def _principal_axis(coordinates: np.ndarray) -> np.ndarray: + coords = np.asarray(coordinates, dtype=float) + if coords.size == 0 or len(coords) == 1: + return np.asarray([1.0, 0.0, 0.0], dtype=float) + centered = coords - np.mean(coords, axis=0, keepdims=True) + covariance = np.cov(centered.T, bias=True) + eigvals, eigvecs = np.linalg.eigh(covariance) + axis = np.asarray(eigvecs[:, np.argmax(eigvals)], dtype=float) + norm = np.linalg.norm(axis) + if norm <= 0.0: + return np.asarray([1.0, 0.0, 0.0], dtype=float) + return axis / norm + + +def _estimate_node_spacing( + coordinates: np.ndarray, + *, + elements: list[str], + node_mask: np.ndarray, + pair_cutoff_definitions: PairCutoffDefinitions, + node_elements: set[str], +) -> float: + del elements + node_coords = np.asarray(coordinates, dtype=float)[ + np.asarray(node_mask, dtype=bool) + ] + if len(node_coords) >= 2: + projections = node_coords @ _principal_axis(node_coords) + ordered = np.sort(projections) + diffs = np.diff(ordered) + positive = diffs[diffs > 1e-6] + if positive.size > 0: + return float(np.median(positive)) + for element_a, element_b in pair_cutoff_definitions: + if element_a in node_elements and element_b in node_elements: + cutoffs = pair_cutoff_definitions[(element_a, element_b)] + if cutoffs: + return float(np.median(list(cutoffs.values()))) + return 3.0 + + +def _terminal_anchor( + coordinates: np.ndarray, + *, + node_mask: np.ndarray, + axis: np.ndarray, +) -> np.ndarray: + coords = np.asarray(coordinates, dtype=float) + mask = np.asarray(node_mask, dtype=bool) + if np.any(mask): + anchors = coords[mask] + else: + anchors = coords + if anchors.size == 0: + return np.zeros(3, dtype=float) + projections = anchors @ np.asarray(axis, dtype=float) + return np.asarray(anchors[int(np.argmax(projections))], dtype=float) + + +def _remove_excess_atoms( + coordinates: np.ndarray, + elements: list[str], + *, + source_counts: Counter[str], + target_counts: dict[str, int], +) -> tuple[np.ndarray, list[str]]: + coords = np.asarray(coordinates, dtype=float) + elem_list = list(elements) + if coords.size == 0: + return coords, elem_list + centered = coords - np.mean(coords, axis=0, keepdims=True) + projections = centered @ _principal_axis(centered) + kept = np.ones(len(elem_list), dtype=bool) + for element, current_count in source_counts.items(): + target_count = int(target_counts.get(element, 0)) + excess = int(current_count - target_count) + if excess <= 0: + continue + indices = [ + index for index, value in enumerate(elem_list) if value == element + ] + ranked = sorted( + indices, key=lambda index: float(projections[index]), reverse=True + ) + for index in ranked[:excess]: + kept[index] = False + return coords[kept], [ + element for element, keep in zip(elem_list, kept) if keep + ] + + +def _element_template_vectors( + coordinates: np.ndarray, + *, + elements: list[str], + element: str, + node_mask: np.ndarray, + axis: np.ndarray, + pair_cutoff_definitions: PairCutoffDefinitions, + node_elements: set[str], +) -> list[np.ndarray]: + coords = np.asarray(coordinates, dtype=float) + nodes = coords[np.asarray(node_mask, dtype=bool)] + if nodes.size == 0: + return [] + source_vectors: list[np.ndarray] = [] + default_distance = _default_element_distance( + element, + pair_cutoff_definitions=pair_cutoff_definitions, + node_elements=node_elements, + ) + for index, source_element in enumerate(elements): + if source_element != element: + continue + atom = coords[index] + deltas = nodes - atom + distances = np.linalg.norm(deltas, axis=1) + if distances.size == 0: + continue + nearest_node = nodes[int(np.argmin(distances))] + vector = atom - nearest_node + norm = np.linalg.norm(vector) + if norm <= 1e-9: + continue + normalized = vector / norm + if float(normalized @ axis) < 0.0: + normalized = normalized * -1.0 + source_vectors.append(normalized * default_distance) + unique_vectors: list[np.ndarray] = [] + for vector in source_vectors: + if not any(np.allclose(vector, other) for other in unique_vectors): + unique_vectors.append(vector) + return unique_vectors + + +def _generic_element_vectors( + *, + count: int, + distance: float, +) -> list[np.ndarray]: + base = [ + np.asarray([0.0, 1.0, 0.0], dtype=float), + np.asarray([0.0, -1.0, 0.0], dtype=float), + np.asarray([0.0, 0.0, 1.0], dtype=float), + np.asarray([0.0, 0.0, -1.0], dtype=float), + np.asarray([0.0, 1.0, 1.0], dtype=float), + np.asarray([0.0, -1.0, 1.0], dtype=float), + np.asarray([0.0, 1.0, -1.0], dtype=float), + np.asarray([0.0, -1.0, -1.0], dtype=float), + ] + vectors: list[np.ndarray] = [] + for index in range(max(count, 1)): + direction = np.asarray(base[index % len(base)], dtype=float) + direction_norm = np.linalg.norm(direction) + if direction_norm <= 0.0: + direction = np.asarray([0.0, 1.0, 0.0], dtype=float) + else: + direction = direction / direction_norm + vectors.append(direction * float(distance)) + return vectors + + +def _default_element_distance( + element: str, + *, + pair_cutoff_definitions: PairCutoffDefinitions, + node_elements: set[str], +) -> float: + distances: list[float] = [] + normalized_element = _normalized_element_symbol(element) + for (atom1, atom2), level_map in pair_cutoff_definitions.items(): + if atom1 in node_elements and atom2 == normalized_element: + distances.extend(float(value) for value in level_map.values()) + if atom2 in node_elements and atom1 == normalized_element: + distances.extend(float(value) for value in level_map.values()) + if distances: + return float(np.median(distances)) + return 2.5 + + +def _scale_structure_to_radius( + coordinates: np.ndarray, + *, + target_max_radius: float, +) -> np.ndarray: + del target_max_radius + return np.asarray(coordinates, dtype=float) + + +def _build_fallback_structure( + target_counts: dict[str, int], +) -> tuple[list[str], np.ndarray]: + elements: list[str] = [] + coords: list[np.ndarray] = [] + index = 0 + for element, count in sorted(target_counts.items()): + for repeat in range(int(count)): + shell = max(index, 1) + angle = float(repeat) * (math.pi / 3.0) + radius = 1.5 + 0.4 * shell + coords.append( + np.asarray( + [ + float(shell), + radius * math.cos(angle), + radius * math.sin(angle), + ], + dtype=float, + ) + ) + elements.append(element) + index += 1 + return elements, np.asarray(coords, dtype=float) diff --git a/tests/test_clusterdynamics.py b/tests/test_clusterdynamics.py new file mode 100644 index 0000000..af4df5f --- /dev/null +++ b/tests/test_clusterdynamics.py @@ -0,0 +1,468 @@ +from __future__ import annotations + +import csv +import os +from pathlib import Path + +import numpy as np +import pytest +from PySide6.QtWidgets import QApplication + +import saxshell.clusterdynamics.cli as clusterdynamics_cli_module +from saxshell import saxshell as saxshell_module +from saxshell.clusterdynamics import ( + ClusterDynamicsWorkflow, + load_cluster_dynamics_dataset, + save_cluster_dynamics_dataset, +) +from saxshell.clusterdynamics.ui.main_window import ClusterDynamicsMainWindow +from saxshell.saxs.project_manager import ( + SAXSProjectManager, + build_project_paths, +) + +ATOM_TYPE_DEFINITIONS = { + "node": [("Pb", None)], + "linker": [("I", None)], + "shell": [("O", None)], +} +PAIR_CUTOFFS = { + ("Pb", "I"): {0: 1.7}, + ("Pb", "O"): {1: 1.3}, +} + + +@pytest.fixture(scope="module") +def qapp(): + os.environ.setdefault("QT_QPA_PLATFORM", "offscreen") + app = QApplication.instance() + if app is None: + app = QApplication([]) + yield app + + +def _connected_xyz_lines() -> str: + return ( + "5\n" + "frame_connected\n" + "Pb 0.0 0.0 0.0\n" + "I 1.0 0.0 0.0\n" + "Pb 2.0 0.0 0.0\n" + "O 0.2 1.0 0.0\n" + "H 0.2 1.7 0.0\n" + ) + + +def _disconnected_xyz_lines() -> str: + return ( + "5\n" + "frame_disconnected\n" + "Pb 0.0 0.0 0.0\n" + "I 5.0 0.0 0.0\n" + "Pb 10.0 0.0 0.0\n" + "O 0.2 1.0 0.0\n" + "H 0.2 1.7 0.0\n" + ) + + +def _build_frames_dir(tmp_path: Path) -> Path: + frames_dir = tmp_path / "splitxyz0001" + frames_dir.mkdir() + sequence = ( + _disconnected_xyz_lines(), + _connected_xyz_lines(), + _connected_xyz_lines(), + _disconnected_xyz_lines(), + _connected_xyz_lines(), + _disconnected_xyz_lines(), + ) + for index, content in enumerate(sequence): + (frames_dir / f"frame_{index:04d}.xyz").write_text(content) + return frames_dir + + +def _build_offset_frames_dir(tmp_path: Path) -> Path: + frames_dir = tmp_path / "splitxyz_f847fs" + frames_dir.mkdir() + sequence = ( + _connected_xyz_lines(), + _disconnected_xyz_lines(), + ) + for index, content in enumerate(sequence, start=1866): + (frames_dir / f"frame_{index:04d}.xyz").write_text(content) + return frames_dir + + +def _write_energy_file(tmp_path: Path) -> Path: + energy_path = tmp_path / "traj.ener" + lines = [ + f"{step} {step * 10.0:.1f} {1.0 + step:.3f} {300.0 + step:.3f} {-20.0 - step:.3f}\n" + for step in range(6) + ] + energy_path.write_text("".join(lines)) + return energy_path + + +def _presentation_text(presentation) -> str: + texts: list[str] = [] + for slide in presentation.slides: + for shape in slide.shapes: + if getattr(shape, "has_text_frame", False): + texts.append(shape.text) + return "\n".join(texts) + + +def test_cluster_dynamics_workflow_bins_clusters_and_lifetimes(tmp_path): + frames_dir = _build_frames_dir(tmp_path) + energy_path = _write_energy_file(tmp_path) + workflow = ClusterDynamicsWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + shell_levels=(1,), + frame_timestep_fs=10.0, + frames_per_colormap_timestep=3, + energy_file=energy_path, + ) + + preview = workflow.preview_selection() + result = workflow.analyze() + + assert preview.selected_frames == 6 + assert preview.frames_per_colormap_timestep == 3 + assert preview.colormap_timestep_fs == pytest.approx(30.0) + assert preview.bin_count == 2 + assert preview.analysis_start_fs == 0.0 + assert preview.analysis_stop_fs == 60.0 + assert result.bin_count == 2 + assert result.analyzed_frames == 6 + assert result.cluster_labels == ("I", "Pb", "Pb2I") + + label_index = { + label: index for index, label in enumerate(result.cluster_labels) + } + assert result.raw_count_matrix[label_index["Pb2I"], :].tolist() == [ + 2.0, + 1.0, + ] + assert result.raw_count_matrix[label_index["Pb"], :].tolist() == [2.0, 4.0] + assert result.raw_count_matrix[label_index["I"], :].tolist() == [1.0, 2.0] + + pb2i_summary = next( + entry for entry in result.lifetime_by_label if entry.label == "Pb2I" + ) + assert pb2i_summary.cluster_size == 3 + assert pb2i_summary.completed_lifetime_count == 2 + assert pb2i_summary.window_truncated_lifetime_count == 0 + assert pb2i_summary.mean_lifetime_fs == pytest.approx(15.0) + assert pb2i_summary.std_lifetime_fs == pytest.approx(5.0) + assert pb2i_summary.association_events == 2 + assert pb2i_summary.dissociation_events == 2 + + size1_summary = next( + entry for entry in result.lifetime_by_size if entry.cluster_size == 1 + ) + assert size1_summary.completed_lifetime_count == 3 + assert size1_summary.window_truncated_lifetime_count == 6 + assert size1_summary.mean_lifetime_fs == pytest.approx(10.0) + assert size1_summary.std_lifetime_fs == pytest.approx(0.0) + + assert result.energy_data is not None + x_values, y_values, label = result.energy_series("temperature") + assert len(x_values) == 6 + assert y_values.tolist() == pytest.approx( + [300.0, 301.0, 302.0, 303.0, 304.0, 305.0] + ) + assert label == "Temperature (K)" + + +def test_cluster_dynamics_preview_honors_explicit_time_window(tmp_path): + frames_dir = _build_frames_dir(tmp_path) + workflow = ClusterDynamicsWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=2, + analysis_start_fs=10.0, + analysis_stop_fs=40.0, + ) + + preview = workflow.preview_selection() + + assert preview.selected_frames == 4 + assert preview.first_selected_frame == "frame_0001.xyz" + assert preview.last_selected_frame == "frame_0004.xyz" + assert preview.first_selected_time_fs == pytest.approx(10.0) + assert preview.last_selected_time_fs == pytest.approx(40.0) + assert preview.colormap_timestep_fs == pytest.approx(20.0) + assert preview.bin_count == 2 + + +def test_cluster_dynamics_uses_frame_filename_indices_for_time_axis(tmp_path): + frames_dir = _build_offset_frames_dir(tmp_path) + workflow = ClusterDynamicsWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + frame_timestep_fs=0.5, + frames_per_colormap_timestep=1, + ) + + preview = workflow.preview_selection() + + assert preview.folder_start_time_fs == pytest.approx(847.0) + assert preview.first_frame_time_fs == pytest.approx(933.0) + assert preview.first_selected_source_frame_index == 1866 + assert preview.last_selected_source_frame_index == 1867 + assert preview.time_source_label == "Frame filenames x timestep" + assert any( + "847.000 fs" in message and "933.000 fs" in message + for message in preview.time_warnings + ) + + +def test_cluster_dynamics_dataset_round_trip(tmp_path): + frames_dir = _build_frames_dir(tmp_path) + energy_path = _write_energy_file(tmp_path) + result = ClusterDynamicsWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + shell_levels=(1,), + frame_timestep_fs=10.0, + frames_per_colormap_timestep=3, + energy_file=energy_path, + ).analyze() + + saved = save_cluster_dynamics_dataset( + result, + tmp_path / "cluster_dynamics_saved.json", + analysis_settings={ + "frame_timestep_fs": 10.0, + "frames_per_colormap_timestep": 3, + "project_dir": str(tmp_path), + }, + ) + loaded = load_cluster_dynamics_dataset(saved.dataset_file) + + assert saved.dataset_file.exists() + assert any( + path.name.endswith("_cluster_distribution.csv") + for path in saved.written_files + ) + assert loaded.result.cluster_labels == result.cluster_labels + assert ( + loaded.result.raw_count_matrix.shape == result.raw_count_matrix.shape + ) + assert np.array_equal( + loaded.result.raw_count_matrix, result.raw_count_matrix + ) + assert loaded.analysis_settings["frame_timestep_fs"] == 10.0 + assert loaded.analysis_settings["frames_per_colormap_timestep"] == 3 + assert loaded.result.energy_data is not None + + +def test_cluster_dynamics_main_window_updates_preview_for_xyz_frames( + qapp, + tmp_path, +): + del qapp + frames_dir = _build_frames_dir(tmp_path) + + window = ClusterDynamicsMainWindow(initial_frames_dir=frames_dir) + window.time_panel.frame_timestep_spin.setValue(10.0) + window.time_panel.frames_per_colormap_timestep_spin.setValue(3) + + preview_text = window.run_panel.selection_box.toPlainText() + + assert window.trajectory_panel.mode_label.text() == "Mode: XYZ frames" + assert window.run_panel.title() == "Run Analysis" + assert window.dataset_panel.title() == "Saved Results" + assert ( + window.dataset_panel.save_dataset_button.text() + == "Save Current Result" + ) + assert ( + window.dataset_panel.load_dataset_button.text() == "Open Saved Result" + ) + assert window.definitions_panel.title() == "Cluster Definitions (XYZ mode)" + assert window.definitions_panel.parentWidget() is not None + assert "Frames selected: 6" in preview_text + assert "Time bins: 2" in preview_text + assert "Frame timestep: 10.000 fs" in preview_text + assert "Frames per colormap timestep: 3" in preview_text + assert "Colormap timestep: 30.000 fs" in preview_text + assert window.time_panel.colormap_timestep_value.text() == "30.000" + window.close() + + +def test_cluster_dynamics_main_window_inherits_project_dir_and_start_time( + qapp, + tmp_path, +): + del qapp + frames_dir = _build_offset_frames_dir(tmp_path) + + window = ClusterDynamicsMainWindow( + initial_frames_dir=frames_dir, + initial_project_dir=tmp_path, + ) + + preview_text = window.run_panel.selection_box.toPlainText() + + assert window.dataset_panel.project_dir() == tmp_path + assert window.time_panel.frame_timestep_fs() == pytest.approx(0.5) + assert window.time_panel.frames_per_colormap_timestep() == 1 + assert window.time_panel.colormap_timestep_fs() == pytest.approx(0.5) + assert window.time_panel.folder_start_time_fs() == pytest.approx(847.0) + assert "Time source: Frame filenames x timestep" in preview_text + assert "Colormap timestep: 0.500 fs" in preview_text + assert "Source frame index range: 1866 to 1867" in preview_text + window.close() + + +def test_cluster_dynamics_main_window_exports_colormap_and_lifetime_csv( + qapp, + tmp_path, + monkeypatch, +): + del qapp + frames_dir = _build_frames_dir(tmp_path) + result = ClusterDynamicsWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + shell_levels=(1,), + frame_timestep_fs=10.0, + frames_per_colormap_timestep=2, + ).analyze() + + window = ClusterDynamicsMainWindow(initial_frames_dir=frames_dir) + window._last_result = result + window.plot_panel.set_result(result) + window.plot_panel.display_mode_combo.setCurrentIndex( + window.plot_panel.display_mode_combo.findData("mean_count") + ) + window.plot_panel.time_unit_combo.setCurrentIndex( + window.plot_panel.time_unit_combo.findData("ps") + ) + + colormap_path = tmp_path / "exported_colormap.csv" + lifetime_path = tmp_path / "exported_lifetime.csv" + selected_paths = iter((str(colormap_path), str(lifetime_path))) + monkeypatch.setattr( + "saxshell.clusterdynamics.ui.main_window.QFileDialog.getSaveFileName", + lambda *args, **kwargs: (next(selected_paths), "CSV Files (*.csv)"), + ) + + window.save_colormap_data() + window.save_lifetime_table() + + with colormap_path.open(newline="", encoding="utf-8") as handle: + colormap_rows = list(csv.DictReader(handle)) + with lifetime_path.open(newline="", encoding="utf-8") as handle: + lifetime_rows = list(csv.DictReader(handle)) + + pb2i_lifetime = next( + row for row in lifetime_rows if row["label"] == "Pb2I" + ) + + assert ( + window.dataset_panel.save_colormap_button.text() + == "Save Colormap Data" + ) + assert ( + window.dataset_panel.save_lifetime_button.text() + == "Save Lifetime Table" + ) + assert len(colormap_rows) == len(result.cluster_labels) * result.bin_count + assert colormap_rows[0]["display_mode"] == "mean_count" + assert colormap_rows[0]["time_unit"] == "ps" + assert "colormap_value" in colormap_rows[0] + assert pb2i_lifetime["mean_lifetime_fs"] == "15" + assert pb2i_lifetime["std_lifetime_fs"] == "5" + window.close() + + +def test_cluster_dynamics_main_window_exports_powerpoint_report( + qapp, + tmp_path, + monkeypatch, +): + pytest.importorskip("pptx") + from pptx import Presentation + + del qapp + frames_dir = _build_frames_dir(tmp_path) + project_dir = tmp_path / "saxs_project" + SAXSProjectManager().create_project(project_dir) + result = ClusterDynamicsWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + shell_levels=(1,), + frame_timestep_fs=10.0, + frames_per_colormap_timestep=2, + ).analyze() + + window = ClusterDynamicsMainWindow( + initial_frames_dir=frames_dir, + initial_project_dir=project_dir, + ) + window._last_result = result + window.plot_panel.set_result(result) + window.run_panel.set_selection_summary( + window._format_preview_text(result.preview) + ) + window._populate_summary_box(result) + + report_path = ( + build_project_paths(project_dir).reports_dir + / f"{project_dir.name}_results.pptx" + ) + captured_default_path: dict[str, str] = {} + + def fake_get_save_file_name(*args, **kwargs): + captured_default_path["value"] = str(args[2]) + return (str(report_path), "PowerPoint Files (*.pptx)") + + monkeypatch.setattr( + "saxshell.clusterdynamics.ui.main_window.QFileDialog.getSaveFileName", + fake_get_save_file_name, + ) + + window.save_powerpoint_report() + + presentation = Presentation(str(report_path)) + + assert captured_default_path["value"] == str(report_path) + assert ( + window.dataset_panel.save_powerpoint_button.text() + == "Save PowerPoint Report" + ) + assert len(presentation.slides) >= 5 + assert "ClusterDynamics Report" in _presentation_text(presentation) + assert "Observed Cluster Lifetimes" in _presentation_text(presentation) + window.close() + + +def test_saxshell_cli_forwards_to_clusterdynamics_subcommand(monkeypatch): + captured: dict[str, object] = {} + + def fake_clusterdynamics_main(argv=None): + captured["argv"] = argv + return 27 + + monkeypatch.setattr( + clusterdynamics_cli_module, + "main", + fake_clusterdynamics_main, + ) + + exit_code = saxshell_module.main( + ["clusterdynamics", "--", "frames", "--energy-file", "traj.ener"] + ) + + assert exit_code == 27 + assert captured["argv"] == ["frames", "--energy-file", "traj.ener"] diff --git a/tests/test_clusterdynamicsml.py b/tests/test_clusterdynamicsml.py new file mode 100644 index 0000000..84755df --- /dev/null +++ b/tests/test_clusterdynamicsml.py @@ -0,0 +1,1974 @@ +from __future__ import annotations + +import csv +import json +import os +from pathlib import Path + +import numpy as np +import pytest +from PySide6.QtWidgets import QApplication + +import saxshell.clusterdynamicsml.cli as clusterdynamicsml_cli_module +import saxshell.clusterdynamicsml.workflow as clusterdynamicsml_workflow_module +from saxshell import saxshell as saxshell_module +from saxshell.clusterdynamicsml import ( + ClusterDynamicsMLWorkflow, + load_cluster_dynamicsai_dataset, + save_cluster_dynamicsai_dataset, +) +from saxshell.clusterdynamicsml.ui.main_window import ( + ClusterDynamicsMLMainWindow, + ClusterDynamicsMLSettingsPanel, + _combined_model_weight_rows, +) +from saxshell.clusterdynamicsml.ui.plot_panel import ( + _build_population_histogram_payload, + _distribution_entries, +) +from saxshell.saxs.project_manager import ( + SAXSProjectManager, + build_project_paths, +) +from saxshell.saxs.project_manager.prior_plot import ( + build_prior_histogram_export_payload, + list_secondary_filter_elements, +) + +ATOM_TYPE_DEFINITIONS = { + "node": [("Pb", None)], + "linker": [("I", None)], +} +PAIR_CUTOFFS = { + ("Pb", "I"): {0: 1.2}, +} + + +@pytest.fixture(scope="module") +def qapp(): + os.environ.setdefault("QT_QPA_PLATFORM", "offscreen") + app = QApplication.instance() + if app is None: + app = QApplication([]) + yield app + + +def _disconnected_xyz_lines() -> str: + return ( + "5\n" + "frame_disconnected\n" + "Pb 0.0 0.0 0.0\n" + "Pb 10.0 0.0 0.0\n" + "Pb 20.0 0.0 0.0\n" + "I 30.0 0.0 0.0\n" + "I 40.0 0.0 0.0\n" + ) + + +def _pair_xyz_lines() -> str: + return ( + "5\n" + "frame_pair\n" + "Pb 0.0 0.0 0.0\n" + "Pb 2.0 0.0 0.0\n" + "Pb 10.0 0.0 0.0\n" + "I 1.0 0.0 0.0\n" + "I 40.0 0.0 0.0\n" + ) + + +def _triple_xyz_lines() -> str: + return ( + "5\n" + "frame_triple\n" + "Pb 0.0 0.0 0.0\n" + "Pb 2.0 0.0 0.0\n" + "Pb 4.0 0.0 0.0\n" + "I 1.0 0.0 0.0\n" + "I 3.0 0.0 0.0\n" + ) + + +def _build_frames_dir(tmp_path: Path) -> Path: + frames_dir = tmp_path / "splitxyz_f0fs" + frames_dir.mkdir() + sequence = ( + _disconnected_xyz_lines(), + _pair_xyz_lines(), + _triple_xyz_lines(), + _triple_xyz_lines(), + _pair_xyz_lines(), + _disconnected_xyz_lines(), + ) + for index, content in enumerate(sequence): + (frames_dir / f"frame_{index:04d}.xyz").write_text(content) + return frames_dir + + +def _build_clusters_dir(tmp_path: Path) -> Path: + clusters_dir = tmp_path / "clusters_training" + (clusters_dir / "Pb").mkdir(parents=True) + (clusters_dir / "Pb2I").mkdir(parents=True) + (clusters_dir / "Pb3I2").mkdir(parents=True) + single = "1\n" "pb_single\n" "Pb 0.0 0.0 0.0\n" + pair_a = ( + "3\n" + "pb2i_pair_a\n" + "Pb 0.0 0.0 0.0\n" + "Pb 2.0 0.0 0.0\n" + "I 1.0 0.0 0.0\n" + ) + pair_b = ( + "3\n" + "pb2i_pair_b\n" + "Pb 0.0 0.1 0.0\n" + "Pb 2.1 -0.1 0.0\n" + "I 1.0 0.2 0.1\n" + ) + triple_a = ( + "5\n" + "pb3i2_chain_a\n" + "Pb 0.0 0.0 0.0\n" + "Pb 2.0 0.0 0.0\n" + "Pb 4.0 0.0 0.0\n" + "I 1.0 0.2 0.0\n" + "I 3.0 -0.2 0.0\n" + ) + triple_b = ( + "5\n" + "pb3i2_chain_b\n" + "Pb 0.0 0.0 0.0\n" + "Pb 2.1 0.1 0.0\n" + "Pb 4.2 0.0 0.1\n" + "I 1.1 0.3 0.1\n" + "I 3.1 -0.2 -0.1\n" + ) + (clusters_dir / "Pb" / "pb_0001.xyz").write_text(single) + (clusters_dir / "Pb" / "pb_0002.xyz").write_text(single) + (clusters_dir / "Pb2I" / "pb2i_0001.xyz").write_text(pair_a) + (clusters_dir / "Pb2I" / "pb2i_0002.xyz").write_text(pair_b) + (clusters_dir / "Pb3I2" / "pb3i2_0001.xyz").write_text(triple_a) + (clusters_dir / "Pb3I2" / "pb3i2_0002.xyz").write_text(triple_b) + return clusters_dir + + +def _build_node_only_label_clusters_dir(tmp_path: Path) -> Path: + clusters_dir = tmp_path / "clusters_training_node_only_labels" + (clusters_dir / "Pb").mkdir(parents=True) + (clusters_dir / "Pb2").mkdir(parents=True) + (clusters_dir / "Pb3").mkdir(parents=True) + single = "1\n" "pb_single\n" "Pb 0.0 0.0 0.0\n" + pair = ( + "3\n" + "pb2i_pair\n" + "Pb 0.0 0.0 0.0\n" + "Pb 2.0 0.0 0.0\n" + "I 1.0 0.0 0.0\n" + ) + triple = ( + "5\n" + "pb3i2_chain\n" + "Pb 0.0 0.0 0.0\n" + "Pb 2.0 0.0 0.0\n" + "Pb 4.0 0.0 0.0\n" + "I 1.0 0.2 0.0\n" + "I 3.0 -0.2 0.0\n" + ) + (clusters_dir / "Pb" / "pb_0001.xyz").write_text(single) + (clusters_dir / "Pb2" / "pb2_0001.xyz").write_text(pair) + (clusters_dir / "Pb2" / "pb2_0002.xyz").write_text(pair) + (clusters_dir / "Pb3" / "pb3_0001.xyz").write_text(triple) + (clusters_dir / "Pb3" / "pb3_0002.xyz").write_text(triple) + return clusters_dir + + +def _build_clusters_dir_with_secondary_atoms(tmp_path: Path) -> Path: + clusters_dir = tmp_path / "clusters_training_secondary" + (clusters_dir / "Pb").mkdir(parents=True) + (clusters_dir / "Pb2I").mkdir(parents=True) + (clusters_dir / "Pb3I2").mkdir(parents=True) + single_a = "1\n" "pb_single_a\n" "Pb 0.0 0.0 0.0\n" + single_b = "2\n" "pb_single_b_o\n" "Pb 0.0 0.0 0.0\n" "O 1.5 0.0 0.0\n" + pair_a = ( + "3\n" + "pb2i_pair_a\n" + "Pb 0.0 0.0 0.0\n" + "Pb 2.0 0.0 0.0\n" + "I 1.0 0.0 0.0\n" + ) + pair_b = ( + "5\n" + "pb2i_pair_b_o2\n" + "Pb 0.0 0.1 0.0\n" + "Pb 2.1 -0.1 0.0\n" + "I 1.0 0.2 0.1\n" + "O 1.0 1.8 0.0\n" + "O 1.0 -1.8 0.0\n" + ) + triple_a = ( + "6\n" + "pb3i2_chain_a_o\n" + "Pb 0.0 0.0 0.0\n" + "Pb 2.0 0.0 0.0\n" + "Pb 4.0 0.0 0.0\n" + "I 1.0 0.2 0.0\n" + "I 3.0 -0.2 0.0\n" + "O 2.0 2.0 0.0\n" + ) + triple_b = ( + "5\n" + "pb3i2_chain_b\n" + "Pb 0.0 0.0 0.0\n" + "Pb 2.1 0.1 0.0\n" + "Pb 4.2 0.0 0.1\n" + "I 1.1 0.3 0.1\n" + "I 3.1 -0.2 -0.1\n" + ) + (clusters_dir / "Pb" / "pb_0001.xyz").write_text(single_a) + (clusters_dir / "Pb" / "pb_0002.xyz").write_text(single_b) + (clusters_dir / "Pb2I" / "pb2i_0001.xyz").write_text(pair_a) + (clusters_dir / "Pb2I" / "pb2i_0002.xyz").write_text(pair_b) + (clusters_dir / "Pb3I2" / "pb3i2_0001.xyz").write_text(triple_a) + (clusters_dir / "Pb3I2" / "pb3i2_0002.xyz").write_text(triple_b) + return clusters_dir + + +def _write_experimental_data_file(tmp_path: Path) -> Path: + q_values = np.linspace(0.05, 1.0, 60) + intensities = np.exp(-2.0 * q_values) + 0.15 + output = tmp_path / "experimental.txt" + with output.open("w", encoding="utf-8") as handle: + for q_value, intensity in zip(q_values, intensities, strict=False): + handle.write(f"{q_value:.6f} {intensity:.8f}\n") + return output + + +def _build_project_dir( + tmp_path: Path, + *, + clusters_dir: Path, + experimental_data_file: Path, +) -> Path: + manager = SAXSProjectManager() + project_dir = tmp_path / "saxs_project" + settings = manager.create_project(project_dir) + settings.clusters_dir = str(clusters_dir) + settings.experimental_data_path = str(experimental_data_file) + settings.copied_experimental_data_file = str(experimental_data_file) + manager.save_project(settings) + return project_dir + + +def _write_component_profile_file( + output_path: Path, + *, + q_values: np.ndarray, + intensity: np.ndarray, +) -> None: + data = np.column_stack( + [ + np.asarray(q_values, dtype=float), + np.asarray(intensity, dtype=float), + np.zeros_like(intensity, dtype=float), + np.zeros_like(intensity, dtype=float), + ] + ) + np.savetxt( + output_path, + data, + comments="", + header="# Columns: q, S(q)_avg, S(q)_std, S(q)_se\n", + fmt=["%.8f", "%.8f", "%.8f", "%.8f"], + ) + + +def _write_project_component_artifacts( + project_dir: Path, + *, + q_values: np.ndarray, +) -> None: + paths = build_project_paths(project_dir) + component_specs = { + "Pb": {"motif_0000": np.full_like(q_values, 1.25, dtype=float)}, + "Pb2I": {"motif_0000": np.full_like(q_values, 2.75, dtype=float)}, + "Pb3I2": {"motif_0000": np.full_like(q_values, 4.50, dtype=float)}, + } + saxs_map: dict[str, dict[str, str]] = {} + prior_structures: dict[str, dict[str, object]] = {} + for structure, motifs in component_specs.items(): + saxs_map[structure] = {} + prior_structures[structure] = {} + motif_count = len(motifs) + for motif, intensity in motifs.items(): + filename = f"{structure}_{motif}.txt" + _write_component_profile_file( + paths.scattering_components_dir / filename, + q_values=q_values, + intensity=intensity, + ) + saxs_map[structure][motif] = filename + prior_structures[structure][motif] = { + "count": 1, + "weight": 1.0 / motif_count, + "profile_file": filename, + } + (project_dir / "md_saxs_map.json").write_text( + json.dumps({"saxs_map": saxs_map}, indent=2) + "\n", + encoding="utf-8", + ) + (project_dir / "md_prior_weights.json").write_text( + json.dumps( + { + "structures": prior_structures, + "q_range": { + "qmin": float(q_values.min()), + "qmax": float(q_values.max()), + "points": int(q_values.size), + }, + }, + indent=2, + ) + + "\n", + encoding="utf-8", + ) + + +def _presentation_text(presentation) -> str: + texts: list[str] = [] + for slide in presentation.slides: + for shape in slide.shapes: + if getattr(shape, "has_text_frame", False): + texts.append(shape.text) + return "\n".join(texts) + + +def test_clusterdynamicsml_prediction_panel_tooltips(qapp): + del qapp + panel = ClusterDynamicsMLSettingsPanel() + + assert "stoichiometry label" in panel.clusters_dir_edit.toolTip() + assert "Leave this blank" in panel.experimental_data_edit.toolTip() + assert "Lowest node count to predict" in panel.target_start_spin.toolTip() + assert "Every integer node count" in panel.target_stop_spin.toolTip() + assert ( + "ranked candidate stoichiometries" in panel.candidates_spin.toolTip() + ) + assert ( + "share among the predicted candidates" + in panel.share_threshold_spin.toolTip() + ) + assert "no experimental data file is loaded" in panel.q_min_spin.toolTip() + assert "no experimental data file is loaded" in panel.q_max_spin.toolTip() + assert "fallback SAXS grid" in panel.q_points_spin.toolTip() + assert "larger node counts to predict" in panel.toolTip() + + panel.close() + + +def test_clusterdynamicsml_workflow_predicts_larger_clusters(tmp_path): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + project_dir = _build_project_dir( + tmp_path, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + ) + + workflow = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + project_dir=project_dir, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ) + + preview = workflow.preview_selection() + result = workflow.analyze() + + assert preview.structure_label_count == 3 + assert preview.observed_node_counts == (1, 2, 3) + assert preview.target_node_counts == (4, 5) + + assert result.max_observed_node_count == 3 + assert result.max_predicted_node_count in {4, 5} + assert {entry.target_node_count for entry in result.predictions} == {4, 5} + assert all( + entry.label.startswith("Pb4") or entry.label.startswith("Pb5") + for entry in result.predictions + ) + assert {entry.label for entry in result.predictions}.isdisjoint( + {"Pb4", "Pb5"} + ) + assert sum( + entry.predicted_population_share for entry in result.predictions + ) == pytest.approx(1.0) + assert all( + entry.predicted_population_share > 0.0 for entry in result.predictions + ) + assert all( + len(entry.generated_elements) == sum(entry.element_counts.values()) + for entry in result.predictions + ) + assert result.saxs_comparison is not None + assert ( + result.saxs_comparison.experimental_data_path + == experimental_data_file.resolve() + ) + assert result.saxs_comparison.rmse is not None + assert len(result.saxs_comparison.component_weights) >= 3 + + +def test_clusterdynamicsml_structure_observations_use_reference_atom_counts( + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_node_only_label_clusters_dir(tmp_path) + + workflow = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4,), + ) + + observations = workflow._build_structure_observations(clusters_dir) + counts_by_label = { + entry.label: entry.element_counts for entry in observations + } + + assert counts_by_label["Pb"] == {"Pb": 1} + assert counts_by_label["Pb2"] == {"Pb": 2, "I": 1} + assert counts_by_label["Pb3"] == {"Pb": 3, "I": 2} + + +def test_clusterdynamicsml_learns_reference_geometry_statistics(tmp_path): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir_with_secondary_atoms(tmp_path) + + workflow = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions={ + **ATOM_TYPE_DEFINITIONS, + "solvent": [("O", None)], + }, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4,), + ) + result = workflow.analyze() + statistics = workflow._collect_training_geometry_statistics( + list(result.training_observations) + ) + + assert statistics.atom_type_by_element["Pb"] == "node" + assert statistics.atom_type_by_element["I"] == "linker" + assert statistics.atom_type_by_element["O"] == "shell" + assert statistics.node_bond_length == pytest.approx(2.05, abs=0.20) + assert statistics.bond_length_medians[("I", "Pb")] == pytest.approx( + 1.03, + abs=0.20, + ) + assert statistics.node_angle_medians[("node", "node")] > 170.0 + assert statistics.node_coordination_medians["node"] == pytest.approx( + 1.0, + abs=0.1, + ) + assert statistics.node_coordination_medians["linker"] == pytest.approx( + 1.0, + abs=0.1, + ) + assert statistics.contact_distance_medians[("I", "I")] == pytest.approx( + 2.05, + abs=0.20, + ) + assert statistics.contact_distance_medians[("I", "O")] == pytest.approx( + 2.03, + abs=0.30, + ) + assert statistics.non_node_node_coordination_medians["I"] == pytest.approx( + 2.0, + abs=0.1, + ) + assert statistics.non_node_node_coordination_medians["O"] == pytest.approx( + 1.0, + abs=0.1, + ) + assert statistics.atom_coordination_medians[ + ("shell", "linker") + ] == pytest.approx( + 1.0, + abs=0.1, + ) + assert statistics.atom_coordination_medians[("linker", "linker")] > 0.0 + assert statistics.atom_coordination_medians[("linker", "shell")] > 0.0 + + +def test_clusterdynamicsml_predicted_structure_follows_reference_geometry( + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + + workflow = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4,), + ) + result = workflow.analyze() + statistics = workflow._collect_training_geometry_statistics( + list(result.training_observations) + ) + prediction = max( + ( + entry + for entry in result.predictions + if entry.target_node_count == 4 + and entry.element_counts.get("I", 0) > 0 + ), + key=lambda entry: entry.predicted_population_share, + ) + + node_indices = [ + index + for index, element in enumerate(prediction.generated_elements) + if element == "Pb" + ] + linker_indices = [ + index + for index, element in enumerate(prediction.generated_elements) + if element == "I" + ] + node_coordinates = np.asarray( + prediction.generated_coordinates[node_indices], + dtype=float, + ) + linker_coordinates = np.asarray( + prediction.generated_coordinates[linker_indices], + dtype=float, + ) + scaffold_edges = clusterdynamicsml_workflow_module._node_scaffold_edges( + node_coordinates, + ["Pb"] * len(node_coordinates), + pair_cutoff_definitions=PAIR_CUTOFFS, + ) + edge_lengths = np.asarray( + [ + np.linalg.norm( + node_coordinates[index_a] - node_coordinates[index_b] + ) + for index_a, index_b in scaffold_edges + ], + dtype=float, + ) + scaffold_adjacency = ( + clusterdynamicsml_workflow_module._adjacency_from_edges( + len(node_coordinates), + scaffold_edges, + ) + ) + node_angles = [ + clusterdynamicsml_workflow_module._angle_between_vectors( + node_coordinates[neighbors[0]] - node_coordinates[node_index], + node_coordinates[neighbors[1]] - node_coordinates[node_index], + ) + for node_index, neighbors in ( + (index, sorted(entries)) + for index, entries in scaffold_adjacency.items() + if len(entries) == 2 + ) + ] + pb_i_distance = statistics.bond_length_medians[("I", "Pb")] + linker_distances = np.asarray( + [ + np.linalg.norm(node_coordinates - coordinate, axis=1) + for coordinate in linker_coordinates + ], + dtype=float, + ) + linker_bridge_counts = [ + int(np.sum(distances <= pb_i_distance * 1.25)) + for distances in linker_distances + ] + + assert len(scaffold_edges) == len(node_coordinates) - 1 + assert np.median(edge_lengths) == pytest.approx( + statistics.node_bond_length, + abs=0.35, + ) + assert node_angles + assert min(angle for angle in node_angles if angle is not None) > 140.0 + assert np.median(np.min(linker_distances, axis=1)) == pytest.approx( + pb_i_distance, + abs=0.25, + ) + assert max(linker_bridge_counts) >= 2 + + +def test_clusterdynamicsml_predicted_structure_respects_non_node_contacts( + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir_with_secondary_atoms(tmp_path) + + workflow = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions={ + **ATOM_TYPE_DEFINITIONS, + "solvent": [("O", None)], + }, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4,), + ) + result = workflow.analyze() + statistics = workflow._collect_training_geometry_statistics( + list(result.training_observations) + ) + source_observation = max( + result.training_observations, + key=lambda row: row.node_count, + ) + generated_elements, generated_coordinates = ( + workflow._generate_predicted_structure( + source_observation, + target_counts={"Pb": 4, "I": 3, "O": 2}, + predicted_max_radius=max( + source_observation.mean_max_radius * 1.3, 1.0 + ), + geometry_statistics=statistics, + ) + ) + + coordinates = np.asarray(generated_coordinates, dtype=float) + linker_indices = [ + index + for index, element in enumerate(generated_elements) + if element == "I" + ] + shell_indices = [ + index + for index, element in enumerate(generated_elements) + if element == "O" + ] + linker_coordinates = np.asarray(coordinates[linker_indices], dtype=float) + shell_coordinates = np.asarray(coordinates[shell_indices], dtype=float) + linker_linker_distances = [ + float( + np.linalg.norm( + linker_coordinates[index_a] - linker_coordinates[index_b] + ) + ) + for index_a in range(len(linker_coordinates)) + for index_b in range(index_a + 1, len(linker_coordinates)) + ] + shell_to_linker_distances = [ + float(np.min(np.linalg.norm(linker_coordinates - coordinate, axis=1))) + for coordinate in shell_coordinates + ] + linker_shell_cutoff = ( + clusterdynamicsml_workflow_module._contact_distance_cutoff( + "I", + "O", + geometry_statistics=statistics, + default_distance=statistics.node_bond_length, + pair_cutoff_definitions=PAIR_CUTOFFS, + ) + ) + linker_linker_cutoff = ( + clusterdynamicsml_workflow_module._contact_distance_cutoff( + "I", + "I", + geometry_statistics=statistics, + default_distance=statistics.node_bond_length, + pair_cutoff_definitions=PAIR_CUTOFFS, + ) + ) + linker_linker_contact_counts = [ + int( + np.sum( + np.linalg.norm(linker_coordinates - coordinate, axis=1) + <= linker_linker_cutoff + ) + - 1 + ) + for coordinate in linker_coordinates + ] + + assert linker_indices + assert shell_indices + assert min(linker_linker_distances) == pytest.approx( + statistics.contact_distance_medians[("I", "I")], + abs=0.35, + ) + assert np.median( + np.asarray(shell_to_linker_distances, dtype=float) + ) == pytest.approx( + statistics.contact_distance_medians[("I", "O")], + abs=0.35, + ) + assert all( + distance <= linker_shell_cutoff + for distance in shell_to_linker_distances + ) + assert sum(count >= 1 for count in linker_linker_contact_counts) >= 2 + + +def test_clusterdynamicsml_histogram_binning_keeps_secondary_oxygen_out_of_labels( + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir_with_secondary_atoms(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions={ + **ATOM_TYPE_DEFINITIONS, + "solvent": [("O", None)], + }, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + oxygen_predictions = [ + entry + for entry in result.predictions + if entry.element_counts.get("O", 0) > 0 + ] + assert oxygen_predictions + + payload = _build_population_histogram_payload( + result, + include_predictions=True, + ) + + assert payload is not None + assert "O" in payload["available_elements"] + assert all("O" not in label for label in payload["structures"]) + + predicted_payloads = [ + motif_payload + for motifs in payload["structures"].values() + for motif_name, motif_payload in motifs.items() + if motif_name.startswith("predicted_rank_") + ] + assert any( + "O" in payload.get("secondary_atom_distributions", {}) + and any( + int(segment) > 0 + for segment in payload["secondary_atom_distributions"]["O"] + ) + for payload in predicted_payloads + ) + + +def test_clusterdynamicsml_solvent_sort_histograms_stay_normalized_with_predictions( + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir_with_secondary_atoms(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions={ + **ATOM_TYPE_DEFINITIONS, + "solvent": [("O", None)], + }, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + payload = _build_population_histogram_payload( + result, + include_predictions=True, + ) + + assert payload is not None + + structure_fraction = build_prior_histogram_export_payload( + payload, + mode="solvent_sort_structure_fraction", + value_mode="percent", + secondary_element="O", + ) + atom_fraction = build_prior_histogram_export_payload( + payload, + mode="solvent_sort_atom_fraction", + value_mode="percent", + secondary_element="O", + ) + + assert float(np.sum(structure_fraction["totals"])) == pytest.approx(100.0) + assert float(np.sum(atom_fraction["totals"])) == pytest.approx(100.0) + assert float(np.max(structure_fraction["totals"])) <= 100.0 + 1.0e-9 + assert float(np.max(atom_fraction["totals"])) <= 100.0 + 1.0e-9 + + +def test_clusterdynamicsml_predictions_keep_required_linkers_and_drop_tiny_tail( + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir_with_secondary_atoms(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions={ + **ATOM_TYPE_DEFINITIONS, + "solvent": [("O", None)], + }, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + minimum_iodides = {4: 2, 5: 3} + + assert result.predictions + assert {entry.target_node_count for entry in result.predictions} == {4, 5} + assert {entry.label for entry in result.predictions}.isdisjoint( + {"Pb4O4", "Pb5O5"} + ) + assert all( + entry.element_counts.get("I", 0) + >= minimum_iodides.get(entry.target_node_count, 0) + for entry in result.predictions + ) + assert all( + entry.predicted_population_share + >= result.prediction_population_share_threshold + for entry in result.predictions + ) + + +def test_clusterdynamicsml_secondary_atom_predictions_preserve_reference_bonds( + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir_with_secondary_atoms(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + workflow = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions={ + **ATOM_TYPE_DEFINITIONS, + "solvent": [("O", None)], + }, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + prediction_population_share_threshold=0.0, + ) + result = workflow.analyze() + statistics = workflow._collect_training_geometry_statistics( + list(result.training_observations) + ) + + assert any( + entry.element_counts.get("O", 0) > 0 for entry in result.predictions + ) + + for prediction in result.predictions: + node_indices = [ + index + for index, element in enumerate(prediction.generated_elements) + if element == "Pb" + ] + node_coordinates = np.asarray( + prediction.generated_coordinates[node_indices], + dtype=float, + ) + scaffold_edges = ( + clusterdynamicsml_workflow_module._node_scaffold_edges( + node_coordinates, + ["Pb"] * len(node_coordinates), + pair_cutoff_definitions=PAIR_CUTOFFS, + ) + ) + edge_lengths = np.asarray( + [ + np.linalg.norm( + node_coordinates[index_a] - node_coordinates[index_b] + ) + for index_a, index_b in scaffold_edges + ], + dtype=float, + ) + assert edge_lengths.size > 0 + assert np.median(edge_lengths) == pytest.approx( + statistics.node_bond_length, + abs=0.20, + ) + + for element in {"I", "O"}: + atom_indices = [ + index + for index, atom_element in enumerate( + prediction.generated_elements + ) + if atom_element == element + ] + if not atom_indices: + continue + nearest_node_distances = np.asarray( + [ + np.min( + np.linalg.norm( + node_coordinates + - prediction.generated_coordinates[atom_index], + axis=1, + ) + ) + for atom_index in atom_indices + ], + dtype=float, + ) + assert np.median(nearest_node_distances) == pytest.approx( + statistics.bond_length_medians[(element, "Pb")], + abs=0.20, + ) + + +def test_clusterdynamicsml_reuses_project_component_profiles_for_observed_saxs( + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + project_dir = _build_project_dir( + tmp_path, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + ) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.experimental_data_path = None + settings.copied_experimental_data_file = None + manager.save_project(settings) + q_values = np.linspace(0.20, 0.60, 9) + _write_project_component_artifacts(project_dir, q_values=q_values) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + project_dir=project_dir, + experimental_data_file=None, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + q_points=q_values.size, + ).analyze() + + assert result.saxs_comparison is not None + observed_entries = [ + entry + for entry in result.saxs_comparison.component_weights + if entry.source == "observed_project" + ] + assert len(observed_entries) == len(result.training_observations) + assert result.saxs_comparison.q_values[0] == pytest.approx( + float(q_values[0]) + ) + assert result.saxs_comparison.q_values[-1] == pytest.approx( + float(q_values[-1]) + ) + + expected_weights: list[float] = [] + expected_traces: list[np.ndarray] = [] + for row in result.training_observations: + profile_entry = next( + entry for entry in observed_entries if entry.label == row.label + ) + assert profile_entry.profile_path is not None + profile_data = np.loadtxt(profile_entry.profile_path, comments="#") + expected_traces.append(np.asarray(profile_data[:, 1], dtype=float)) + expected_weights.append( + max(float(row.mean_count_per_frame), 0.0) + * max(float(row.occupancy_fraction), 0.05) + ) + normalized = np.asarray(expected_weights, dtype=float) + normalized = normalized / np.sum(normalized) + expected_model = np.einsum( + "i,ij->j", normalized, np.asarray(expected_traces) + ) + assert np.allclose( + result.saxs_comparison.observed_raw_model_intensity, + expected_model, + ) + + +def test_clusterdynamicsml_maps_prediction_shares_into_combined_weights( + tmp_path, + monkeypatch, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + original_fit = ( + clusterdynamicsml_workflow_module.ClusterDynamicsMLWorkflow._fit_candidate_property_models + ) + + def _collapsed_count_models( + self, + training_observations, + feature_matrix, + weights, + *, + non_node_elements, + node_elements, + ): + models = original_fit( + self, + training_observations, + feature_matrix, + weights, + non_node_elements=non_node_elements, + node_elements=node_elements, + ) + models["mean_count_per_frame"] = ( + clusterdynamicsml_workflow_module._PropertyModel( + coefficients=None, + constant_value=0.0, + transform="identity", + default_value=0.0, + lower_bound=0.0, + ) + ) + models["occupancy_fraction"] = ( + clusterdynamicsml_workflow_module._PropertyModel( + coefficients=None, + constant_value=0.35, + transform="identity", + default_value=0.35, + lower_bound=0.0, + upper_bound=1.0, + ) + ) + models["mean_lifetime_fs"] = ( + clusterdynamicsml_workflow_module._PropertyModel( + coefficients=None, + constant_value=25.0, + transform="identity", + default_value=25.0, + lower_bound=self.frame_timestep_fs, + ) + ) + return models + + monkeypatch.setattr( + clusterdynamicsml_workflow_module.ClusterDynamicsMLWorkflow, + "_fit_candidate_property_models", + _collapsed_count_models, + ) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + assert all( + entry.predicted_mean_count_per_frame == pytest.approx(0.0) + for entry in result.predictions + ) + assert sum( + entry.predicted_population_share for entry in result.predictions + ) == pytest.approx(1.0) + assert result.saxs_comparison is not None + + predicted_component_weights = { + entry.label: float(entry.weight) + for entry in result.saxs_comparison.component_weights + if entry.source == "predicted" + } + assert set(predicted_component_weights) == { + entry.label for entry in result.predictions + } + assert all(weight > 0.0 for weight in predicted_component_weights.values()) + + normalized_predicted_model_weights = np.asarray( + [ + predicted_component_weights[entry.label] + for entry in result.predictions + ], + dtype=float, + ) + normalized_predicted_model_weights = ( + normalized_predicted_model_weights + / np.sum(normalized_predicted_model_weights) + ) + expected_shares = np.asarray( + [entry.predicted_population_share for entry in result.predictions], + dtype=float, + ) + expected_shares = expected_shares / np.sum(expected_shares) + assert np.allclose(normalized_predicted_model_weights, expected_shares) + + combined_entries = _distribution_entries(result, include_predictions=True) + predicted_entries = [ + entry + for entry in combined_entries + if int(entry["node_count"]) > result.max_observed_node_count + ] + assert len(predicted_entries) == len(result.predictions) + assert all( + float(entry["normalized_weight"]) > 0.0 for entry in predicted_entries + ) + + combined_payload = _build_population_histogram_payload( + result, + include_predictions=True, + ) + assert combined_payload is not None + assert {entry.label for entry in result.predictions}.issubset( + set(combined_payload["structures"]) + ) + + +def test_clusterdynamicsml_caps_predicted_weight_takeover_from_extreme_shares( + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + assert result.predictions + result.predictions[0].predicted_population_share = 0.98 + residual_share = 0.02 / max(len(result.predictions) - 1, 1) + for entry in result.predictions[1:]: + entry.predicted_population_share = residual_share + result.predictions[0].predicted_mean_count_per_frame = 1.0e6 + result.predictions[0].predicted_occupancy_fraction = 1.0 + + observed_weights, predicted_weights = ( + clusterdynamicsml_workflow_module._resolved_population_weights( + result.training_observations, + result.predictions, + frame_timestep_fs=10.0, + ) + ) + observed_size_totals: dict[int, float] = {} + for observation, weight in zip( + result.training_observations, + observed_weights, + strict=False, + ): + observed_size_totals[int(observation.node_count)] = ( + observed_size_totals.get(int(observation.node_count), 0.0) + + float(weight) + ) + + assert float(np.sum(predicted_weights)) <= ( + float(observed_size_totals[max(observed_size_totals)]) + 1.0e-12 + ) + combined_total = float( + np.sum(observed_weights) + np.sum(predicted_weights) + ) + assert combined_total > 0.0 + assert float(np.max(predicted_weights) / combined_total) < 0.5 + + +def test_clusterdynamicsml_writes_surrogate_xyz_and_component_profiles( + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + assert result.saxs_comparison is not None + assert result.saxs_comparison.component_output_dir is not None + assert result.saxs_comparison.surrogate_structure_dir is not None + assert result.saxs_comparison.component_output_dir.is_dir() + assert result.saxs_comparison.surrogate_structure_dir.is_dir() + + observed_entries = [ + entry + for entry in result.saxs_comparison.component_weights + if entry.source == "observed_direct" + ] + predicted_entries = [ + entry + for entry in result.saxs_comparison.component_weights + if entry.source == "predicted" + ] + assert observed_entries + assert predicted_entries + assert all( + entry.profile_path is not None and entry.profile_path.is_file() + for entry in observed_entries + predicted_entries + ) + assert all( + entry.structure_path is not None and entry.structure_path.is_file() + for entry in predicted_entries + ) + first_xyz = predicted_entries[0].structure_path.read_text(encoding="utf-8") + assert predicted_entries[0].structure_path.suffix == ".xyz" + assert first_xyz.splitlines()[0].strip().isdigit() + + +def test_clusterdynamicsml_writes_xyz_for_every_prediction( + tmp_path, + monkeypatch, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + original_resolve = ( + clusterdynamicsml_workflow_module._resolved_population_weights + ) + + def _drop_predicted_model_members( + training_observations, + predictions, + *, + frame_timestep_fs, + ): + observed_weights, predicted_weights = original_resolve( + training_observations, + predictions, + frame_timestep_fs=frame_timestep_fs, + ) + filtered = np.zeros_like(predicted_weights) + for index, entry in enumerate(predictions): + if entry.rank == 1: + filtered[index] = predicted_weights[index] + return observed_weights, filtered + + monkeypatch.setattr( + clusterdynamicsml_workflow_module, + "_resolved_population_weights", + _drop_predicted_model_members, + ) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + assert result.saxs_comparison is not None + assert result.saxs_comparison.surrogate_structure_dir is not None + written_paths = sorted( + result.saxs_comparison.surrogate_structure_dir.glob("*.xyz") + ) + + assert len(result.predictions) > len( + {entry.target_node_count for entry in result.predictions} + ) + assert len(written_paths) == len(result.predictions) + assert {path.name for path in written_paths} == { + f"{entry.target_node_count:02d}_rank{entry.rank:02d}_{entry.label}.xyz" + for entry in result.predictions + } + + +def test_clusterdynamicsml_dataset_round_trip(tmp_path): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + saved = save_cluster_dynamicsai_dataset( + result, + tmp_path / "clusterdynamicsml_saved.json", + analysis_settings={ + "frame_timestep_fs": 10.0, + "target_node_counts": [4, 5], + }, + ) + loaded = load_cluster_dynamicsai_dataset(saved.dataset_file) + + assert saved.dataset_file.exists() + assert any( + path.name.endswith("_cluster_distribution.csv") + for path in saved.written_files + ) + assert any( + path.name.endswith("_lifetime.csv") for path in saved.written_files + ) + assert any( + path.name.endswith("_predictions.csv") for path in saved.written_files + ) + assert any( + path.name.endswith("_observed_histogram.csv") + for path in saved.written_files + ) + assert any( + path.name.endswith("_observed_plus_surrogate_histogram.csv") + for path in saved.written_files + ) + assert any(path.suffix == ".xyz" for path in saved.written_files) + assert any( + path.parent.name.endswith("_saxs_components") + for path in saved.written_files + ) + assert ( + loaded.result.max_observed_node_count == result.max_observed_node_count + ) + assert [entry.label for entry in loaded.result.predictions] == [ + entry.label for entry in result.predictions + ] + assert loaded.analysis_settings["target_node_counts"] == [4, 5] + assert loaded.result.saxs_comparison is not None + assert np.allclose( + loaded.result.saxs_comparison.fitted_model_intensity, + result.saxs_comparison.fitted_model_intensity, + ) + assert np.allclose( + loaded.result.saxs_comparison.observed_fitted_model_intensity, + result.saxs_comparison.observed_fitted_model_intensity, + ) + assert ( + loaded.result.saxs_comparison.component_output_dir + == result.saxs_comparison.component_output_dir + ) + + +def test_clusterdynamicsml_window_autosaves_and_restores_project_result_bundle( + qapp, + tmp_path, +): + del qapp + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + project_dir = _build_project_dir( + tmp_path, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + ) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + project_dir=project_dir, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + window = ClusterDynamicsMLMainWindow( + initial_frames_dir=frames_dir, + initial_project_dir=project_dir, + ) + window.time_panel.set_frame_timestep_fs(10.0) + window.time_panel.set_frames_per_colormap_timestep(1) + window.prediction_panel.set_target_node_counts((4, 5)) + window._on_run_finished(result) + + saved_results_dir = ( + build_project_paths(project_dir).exported_data_dir + / "clusterdynamicsml" + / "saved_results" + ) + dataset_files = sorted(saved_results_dir.rglob("*_clusterdynamicsml.json")) + assert dataset_files + cached_dataset = dataset_files[-1] + bundle_dir = cached_dataset.parent + bundle_files = {path.name for path in bundle_dir.iterdir()} + assert f"{cached_dataset.stem}_selection_preview.txt" in bundle_files + assert f"{cached_dataset.stem}_summary.txt" in bundle_files + assert f"{cached_dataset.stem}_saxs.csv" in bundle_files + assert f"{cached_dataset.stem}_observed_histogram.csv" in bundle_files + assert ( + f"{cached_dataset.stem}_observed_plus_surrogate_histogram.csv" + in bundle_files + ) + assert (bundle_dir / f"{cached_dataset.stem}_saxs_components").is_dir() + + window.close() + + reopened = ClusterDynamicsMLMainWindow(initial_project_dir=project_dir) + + assert reopened._last_result is not None + assert reopened._last_dataset_file == cached_dataset + assert reopened.trajectory_panel.get_frames_dir() == frames_dir + assert reopened.lifetime_table.rowCount() == ( + len(result.training_observations) + len(result.predictions) + ) + assert "SAXS components in mixture" in reopened.summary_box.toPlainText() + observed_hist_patches = sum( + len(axis.patches) for axis in reopened.histogram_panel.figure.axes + ) + assert observed_hist_patches > 0 + reopened.close() + + +def test_clusterdynamicsml_window_compares_prediction_history_and_defaults_to_latest( + qapp, + tmp_path, +): + del qapp + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + project_dir = _build_project_dir( + tmp_path, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + ) + + result_one = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + project_dir=project_dir, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4,), + prediction_population_share_threshold=0.01, + ).analyze() + result_two = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + project_dir=project_dir, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + prediction_population_share_threshold=0.05, + ).analyze() + + window = ClusterDynamicsMLMainWindow( + initial_frames_dir=frames_dir, + initial_project_dir=project_dir, + ) + window.time_panel.set_frame_timestep_fs(10.0) + window.time_panel.set_frames_per_colormap_timestep(1) + window.prediction_panel.set_target_node_counts((4,)) + window.prediction_panel.set_prediction_population_share_threshold(0.01) + window._on_run_finished(result_one) + first_dataset = window._last_dataset_file + + window.prediction_panel.set_target_node_counts((4, 5)) + window.prediction_panel.set_prediction_population_share_threshold(0.05) + window._on_run_finished(result_two) + second_dataset = window._last_dataset_file + + assert first_dataset is not None + assert second_dataset is not None + assert first_dataset != second_dataset + assert window.history_table.rowCount() == 2 + assert window._selected_history_dataset_file() == second_dataset + assert window._last_result is not None + assert window._last_result.preview.target_node_counts == (4, 5) + + older_row = next( + row + for row in range(window.history_table.rowCount()) + if window._history_dataset_file_for_row(row) == first_dataset + ) + window.history_table.selectRow(older_row) + window._load_selected_history_entry() + + assert window._last_dataset_file == first_dataset + assert window._selected_history_dataset_file() == first_dataset + assert window._last_result is not None + assert window._last_result.preview.target_node_counts == (4,) + window.close() + + reopened = ClusterDynamicsMLMainWindow(initial_project_dir=project_dir) + + assert reopened._last_result is not None + assert reopened._last_dataset_file == second_dataset + assert reopened._selected_history_dataset_file() == second_dataset + assert reopened._last_result.preview.target_node_counts == (4, 5) + reopened.close() + + +def test_clusterdynamicsml_window_inherits_project_defaults( + qapp, + tmp_path, +): + del qapp + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + project_dir = _build_project_dir( + tmp_path, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + ) + + window = ClusterDynamicsMLMainWindow( + initial_frames_dir=frames_dir, + initial_project_dir=project_dir, + ) + window.time_panel.set_frame_timestep_fs(10.0) + window.time_panel.set_frames_per_colormap_timestep(1) + window.prediction_panel.set_target_node_counts((4, 5)) + window._refresh_selection_preview() + preview_text = window.run_panel.selection_box.toPlainText() + + assert window.dataset_panel.project_dir() == project_dir + assert window.prediction_panel.clusters_dir() == clusters_dir + assert ( + window.prediction_panel.experimental_data_file() + == experimental_data_file + ) + assert "Observed node counts: (1, 2, 3)" in preview_text + assert "Target node counts: (4, 5)" in preview_text + window.close() + + +def test_clusterdynamicsml_window_exports_colormap_and_lifetime_csv( + qapp, + tmp_path, + monkeypatch, +): + del qapp + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + window = ClusterDynamicsMLMainWindow(initial_frames_dir=frames_dir) + window._last_result = result + window.dynamics_plot_panel.set_result(result.dynamics_result) + window.dynamics_plot_panel.display_mode_combo.setCurrentIndex( + window.dynamics_plot_panel.display_mode_combo.findData("count") + ) + window.dynamics_plot_panel.time_unit_combo.setCurrentIndex( + window.dynamics_plot_panel.time_unit_combo.findData("ps") + ) + + colormap_path = tmp_path / "ai_colormap.csv" + lifetime_path = tmp_path / "ai_lifetime.csv" + selected_paths = iter((str(colormap_path), str(lifetime_path))) + monkeypatch.setattr( + "saxshell.clusterdynamicsml.ui.main_window.QFileDialog.getSaveFileName", + lambda *args, **kwargs: (next(selected_paths), "CSV Files (*.csv)"), + ) + + window.save_colormap_data() + window.save_lifetime_table() + + with colormap_path.open(newline="", encoding="utf-8") as handle: + colormap_rows = list(csv.DictReader(handle)) + with lifetime_path.open(newline="", encoding="utf-8") as handle: + lifetime_rows = list(csv.DictReader(handle)) + + assert len(colormap_rows) == ( + len(result.dynamics_result.cluster_labels) + * result.dynamics_result.bin_count + ) + assert colormap_rows[0]["display_mode"] == "count" + assert colormap_rows[0]["time_unit"] == "ps" + assert any(row["Label"] == "Pb3I2" for row in lifetime_rows) + assert any(row["Type"] == "Predicted" for row in lifetime_rows) + window.close() + + +def test_clusterdynamicsml_window_shows_observed_lifetime_tab( + qapp, + tmp_path, +): + del qapp + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + window = ClusterDynamicsMLMainWindow(initial_frames_dir=frames_dir) + window._on_run_finished(result) + + tab_titles = [ + window.results_tabs.tabText(index) + for index in range(window.results_tabs.count()) + ] + lifetime_labels = [ + window.lifetime_table.item(row, 4).text() + for row in range(window.lifetime_table.rowCount()) + ] + lifetime_types = [ + window.lifetime_table.item(row, 0).text() + for row in range(window.lifetime_table.rowCount()) + ] + observed_only_weights = [ + window.lifetime_table.item(row, 5).text() + for row in range(window.lifetime_table.rowCount()) + if window.lifetime_table.item(row, 0).text() == "Observed" + ] + + assert tab_titles == ["Summary", "Lifetimes", "Histograms", "SAXS"] + assert window.lifetime_table.rowCount() == ( + len(result.training_observations) + len(result.predictions) + ) + assert "Pb3I2" in lifetime_labels + assert "Predicted" in lifetime_types + assert all(weight != "n/a" for weight in observed_only_weights) + assert window.lifetime_table.item(0, 8) is not None + window.close() + + +def test_clusterdynamicsml_window_shows_histogram_tabs_and_saxs_model_overlay( + qapp, + tmp_path, +): + del qapp + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + window = ClusterDynamicsMLMainWindow(initial_frames_dir=frames_dir) + window._on_run_finished(result) + + tab_titles = [ + window.results_tabs.tabText(index) + for index in range(window.results_tabs.count()) + ] + observed_hist_patches = sum( + len(axis.patches) for axis in window.histogram_panel.figure.axes + ) + window.histogram_panel.population_combo.setCurrentIndex(1) + combined_hist_patches = sum( + len(axis.patches) for axis in window.histogram_panel.figure.axes + ) + observed_weight_sum = sum( + float(entry["normalized_weight"]) + for entry in _distribution_entries(result, include_predictions=False) + ) + combined_weight_sum = sum( + float(entry["normalized_weight"]) + for entry in _distribution_entries(result, include_predictions=True) + ) + line_labels = { + line.get_label() + for axis in window.saxs_panel.figure.axes + for line in axis.lines + } + surrogate_axes = window.saxs_panel.figure.axes + prediction_size_ranks: dict[int, set[int]] = {} + for row in range(window.lifetime_table.rowCount()): + if window.lifetime_table.item(row, 0).text() != "Predicted": + continue + node_count = int(window.lifetime_table.item(row, 1).text()) + size_rank = int(window.lifetime_table.item(row, 2).text()) + prediction_size_ranks.setdefault(node_count, set()).add(size_rank) + model_rows = _combined_model_weight_rows(result) + combined_weight_percents = [ + float(window.lifetime_table.item(row, 6).text()) + for row in range(window.lifetime_table.rowCount()) + ] + window.saxs_panel._toggle_all_component_traces() + component_line_visibility = { + line.get_label(): line.get_visible() + for axis in window.saxs_panel.figure.axes + for line in axis.lines + if "component:" in str(line.get_gid() or "") + } + + assert tab_titles == ["Summary", "Lifetimes", "Histograms", "SAXS"] + assert observed_hist_patches > 0 + assert combined_hist_patches > 0 + assert observed_weight_sum == pytest.approx(1.0) + assert combined_weight_sum == pytest.approx(1.0) + assert len(surrogate_axes) == 1 + assert surrogate_axes[0].get_xscale() == "log" + assert surrogate_axes[0].get_yscale() == "log" + assert prediction_size_ranks == {4: {2}, 5: {1}} + assert window.lifetime_table.rowCount() == ( + len(result.training_observations) + len(result.predictions) + ) + assert sum( + row["normalized_weight"] for row in model_rows + ) == pytest.approx(1.0) + assert sum(combined_weight_percents) == pytest.approx(100.0, abs=0.2) + assert "observed-only model" in line_labels + assert "observed + surrogate model" in line_labels + assert any("surrogate component:" in label for label in line_labels) + assert component_line_visibility + assert not any(component_line_visibility.values()) + window.close() + + +def test_clusterdynamicsml_histogram_tabs_match_project_setup_modes( + qapp, + tmp_path, +): + del qapp + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir_with_secondary_atoms(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + observed_payload = _build_population_histogram_payload( + result, + include_predictions=False, + ) + combined_payload = _build_population_histogram_payload( + result, + include_predictions=True, + ) + + assert observed_payload is not None + assert combined_payload is not None + assert list_secondary_filter_elements(observed_payload) == ["O"] + assert set(combined_payload["structures"]) > set( + observed_payload["structures"] + ) + + window = ClusterDynamicsMLMainWindow(initial_frames_dir=frames_dir) + window._on_run_finished(result) + + observed_panel = window.histogram_panel + mode_labels = [ + observed_panel.mode_combo.itemText(index) + for index in range(observed_panel.mode_combo.count()) + ] + assert mode_labels == [ + "Structure Fraction", + "Atom Fraction", + "Solvent Sort - Structure Fraction", + "Solvent Sort - Atom Fraction", + ] + assert observed_panel.secondary_combo.count() == 1 + assert observed_panel.secondary_combo.itemText(0) == "O" + + observed_panel.mode_combo.setCurrentIndex(2) + observed_axis = observed_panel.figure.axes[0] + assert observed_axis.get_title() == ( + "Solvent-Sort Structure Fraction Prior Histogram (O)" + ) + assert len(observed_axis.patches) > 0 + + combined_panel = window.histogram_panel + combined_panel.population_combo.setCurrentIndex(1) + combined_panel.mode_combo.setCurrentIndex(3) + combined_axis = combined_panel.figure.axes[0] + combined_tick_labels = { + tick.get_text() for tick in combined_axis.get_xticklabels() + } + + assert ( + combined_axis.get_title() + == "Solvent-Sort Atom Fraction Prior Histogram (O)" + ) + assert len(combined_axis.patches) > 0 + assert any("Pb" in label for label in combined_tick_labels) + window.close() + + +def test_clusterdynamicsml_window_appends_powerpoint_report_to_existing_project_report( + qapp, + tmp_path, + monkeypatch, +): + pytest.importorskip("pptx") + from pptx import Presentation + + del qapp + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + project_dir = _build_project_dir( + tmp_path, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + ) + existing_report = ( + build_project_paths(project_dir).reports_dir + / "existing_project_results.pptx" + ) + presentation = Presentation() + presentation.slides.add_slide(presentation.slide_layouts[0]) + presentation.save(str(existing_report)) + initial_slide_count = len(presentation.slides) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + project_dir=project_dir, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + window = ClusterDynamicsMLMainWindow( + initial_frames_dir=frames_dir, + initial_project_dir=project_dir, + ) + window._last_result = result + window.dynamics_plot_panel.set_result(result.dynamics_result) + window.surrogate_plot_panel.set_result(result) + window.run_panel.set_selection_summary( + window._format_preview_text(result.preview) + ) + window._populate_summary_box(result) + + captured_default_path: dict[str, str] = {} + + def fake_get_save_file_name(*args, **kwargs): + captured_default_path["value"] = str(args[2]) + return (str(existing_report), "PowerPoint Files (*.pptx)") + + monkeypatch.setattr( + "saxshell.clusterdynamicsml.ui.main_window.QFileDialog.getSaveFileName", + fake_get_save_file_name, + ) + + window.save_powerpoint_report() + + updated_presentation = Presentation(str(existing_report)) + + assert captured_default_path["value"] == str(existing_report) + assert len(updated_presentation.slides) > initial_slide_count + assert "ClusterDynamicsML Report" in _presentation_text( + updated_presentation + ) + assert "Predicted Larger Clusters" in _presentation_text( + updated_presentation + ) + window.close() + + +def test_saxshell_cli_forwards_to_clusterdynamicsml_subcommand(monkeypatch): + captured: dict[str, object] = {} + + def fake_clusterdynamicsml_main(argv=None): + captured["argv"] = argv + return 31 + + monkeypatch.setattr( + clusterdynamicsml_cli_module, + "main", + fake_clusterdynamicsml_main, + ) + + exit_code = saxshell_module.main( + [ + "clusterdynamicsml", + "--", + "frames", + "--clusters-dir", + "clusters", + "--experimental-data", + "exp.txt", + ] + ) + + assert exit_code == 31 + assert captured["argv"] == [ + "frames", + "--clusters-dir", + "clusters", + "--experimental-data", + "exp.txt", + ] From f02f56e8b6b789431bfc76a195a7b1c4249f1bf3 Mon Sep 17 00:00:00 2001 From: kewh5868 Date: Mon, 30 Mar 2026 11:31:21 -0600 Subject: [PATCH 2/3] Expand SAXS project setup, prefit, DREAM, and reporting --- .../default_beam_geometry_presets.json | 31 + .../template_likelihood_monosq.json | 2 +- .../_deprecated/template_likelihood_monosq.py | 44 +- .../template_pd_likelihood_monosq.json | 2 +- .../template_pd_likelihood_monosq.py | 36 +- ...mplate_pd_likelihood_monosq_decoupled.json | 2 +- ...template_pd_likelihood_monosq_decoupled.py | 36 +- .../template_pydream_poly_lma_hs_legacy.py | 5 +- .../template_pydream_monosq_normalized.json | 2 +- .../template_pydream_monosq_normalized.py | 36 +- .../template_pydream_poly_lma_hs.json | 2 +- .../template_pydream_poly_lma_hs.py | 120 +- ...mplate_pydream_poly_lma_hs_mix_approx.json | 2 +- ...template_pydream_poly_lma_hs_mix_approx.py | 78 +- .../saxs/_ui_assets/saxshell_icon.svg | 38 + src/saxshell/saxs/beam_geometry_presets.py | 210 + src/saxshell/saxs/dream/distributions.py | 13 +- src/saxshell/saxs/dream/results.py | 136 +- src/saxshell/saxs/dream/runtime.py | 18 + src/saxshell/saxs/model_report.py | 3032 +++++++++ src/saxshell/saxs/prefit/__init__.py | 4 + src/saxshell/saxs/prefit/workflow.py | 789 ++- src/saxshell/saxs/project_manager/__init__.py | 2 + .../saxs/project_manager/prior_plot.py | 39 +- src/saxshell/saxs/project_manager/project.py | 223 +- src/saxshell/saxs/solute_volume_fraction.py | 23 +- .../saxs/solution_scattering_estimator.py | 1270 ++++ src/saxshell/saxs/ui/branding.py | 202 + src/saxshell/saxs/ui/distribution_window.py | 1595 ++++- src/saxshell/saxs/ui/dream_tab.py | 548 +- src/saxshell/saxs/ui/main_window.py | 3009 ++++++++- src/saxshell/saxs/ui/prefit_tab.py | 968 ++- src/saxshell/saxs/ui/project_setup_tab.py | 346 +- .../saxs/ui/solute_volume_fraction_widget.py | 674 +- .../saxs/ui/solution_scattering_widget.py | 1354 ++++ src/saxshell/saxshell.py | 42 + tests/test_saxs_dream_runtime.py | 2 - tests/test_saxs_model_report.py | 214 + tests/test_saxs_prefit.py | 479 +- tests/test_saxs_template_installation.py | 2 +- tests/test_saxs_ui.py | 5836 ++++++++++++----- 41 files changed, 18426 insertions(+), 3040 deletions(-) create mode 100644 src/saxshell/saxs/_beam_geometry_presets/default_beam_geometry_presets.json create mode 100644 src/saxshell/saxs/_ui_assets/saxshell_icon.svg create mode 100644 src/saxshell/saxs/beam_geometry_presets.py create mode 100644 src/saxshell/saxs/model_report.py create mode 100644 src/saxshell/saxs/solution_scattering_estimator.py create mode 100644 src/saxshell/saxs/ui/branding.py create mode 100644 src/saxshell/saxs/ui/solution_scattering_widget.py create mode 100644 tests/test_saxs_model_report.py diff --git a/src/saxshell/saxs/_beam_geometry_presets/default_beam_geometry_presets.json b/src/saxshell/saxs/_beam_geometry_presets/default_beam_geometry_presets.json new file mode 100644 index 0000000..56c016f --- /dev/null +++ b/src/saxshell/saxs/_beam_geometry_presets/default_beam_geometry_presets.json @@ -0,0 +1,31 @@ +{ + "presets": { + "NSLS-II 28-ID-1 (default)": { + "incident_energy_kev": 74.0, + "capillary_size_mm": 1.0, + "capillary_geometry": "cylindrical", + "beam_profile": "uniform", + "beam_footprint_width_mm": 0.4, + "beam_footprint_height_mm": 0.4, + "notes": "Nominal beam size 0.4 mm x 0.4 mm." + }, + "APS 5-IDD (default - focused)": { + "incident_energy_kev": 17.0, + "capillary_size_mm": 1.0, + "capillary_geometry": "cylindrical", + "beam_profile": "uniform", + "beam_footprint_width_mm": 0.05, + "beam_footprint_height_mm": 1.0, + "notes": "Nominal beam size 50 micron x 1000 micron." + }, + "APS 5-IDD (default - unfocused)": { + "incident_energy_kev": 17.5, + "capillary_size_mm": 1.0, + "capillary_geometry": "cylindrical", + "beam_profile": "uniform", + "beam_footprint_width_mm": 1.0, + "beam_footprint_height_mm": 1.0, + "notes": "Nominal beam size 1 mm x 1 mm." + } + } +} diff --git a/src/saxshell/saxs/_model_templates/_deprecated/template_likelihood_monosq.json b/src/saxshell/saxs/_model_templates/_deprecated/template_likelihood_monosq.json index a3d8972..7ec369a 100644 --- a/src/saxshell/saxs/_model_templates/_deprecated/template_likelihood_monosq.json +++ b/src/saxshell/saxs/_model_templates/_deprecated/template_likelihood_monosq.json @@ -1,4 +1,4 @@ { "display_name": "MonoSQ Basic (archived)", - "description": "MonoSQ Basic\n\nStructure Factor:\nMonodisperse hard-sphere structure factor evaluated with the template's calc_monodisperse_sq function. The implementation uses the Percus-Yevick approximation, so eff_r controls the effective hard-sphere size and vol_frac controls the strength of interparticle packing correlations.\n\nForm Factor:\nWeighted mixture of MD-derived SAXS component profiles loaded from the project scattering_components directory. Each component trace represents an averaged scattering profile for one recognized cluster class, and the component weights are combined before the structure-factor correction is applied.\n\nModel Parameters:\nw1, w2, ..., wN: Relative contribution of each MD-derived component profile to the total model intensity.\nsolv_w: Solvent-trace mixing weight applied before the final scale and offset terms.\noffset: Additive baseline shift used to account for residual background or instrument offsets.\neff_r: Effective hard-sphere radius used inside the structure-factor calculation.\nvol_frac: Hard-sphere volume fraction that sets the correlation strength in the Percus-Yevick structure factor.\nscale: Overall multiplicative scale factor applied to the assembled intensity profile." + "description": "MonoSQ Basic\n\nStructure Factor:\nMonodisperse hard-sphere structure factor evaluated with the template's calc_monodisperse_sq function. The implementation uses the Percus-Yevick approximation, so eff_r controls the effective hard-sphere size and vol_frac controls the strength of interparticle packing correlations.\n\nForm Factor:\nWeighted mixture of MD-derived SAXS component profiles loaded from the project scattering_components directory. Each component trace represents an averaged scattering profile for one recognized cluster class, and the component weights are combined before the structure-factor correction is applied.\n\nModel Parameters:\nw1, w2, ..., wN: Relative contribution of each MD-derived component profile to the total solute intensity.\nsolv_w: Solvent weight applied directly to the experimental solvent trace and constrained to the interval [0, 1].\noffset: Additive baseline shift used to account for residual background or instrument offsets.\neff_r: Effective hard-sphere radius used inside the structure-factor calculation.\nvol_frac: Hard-sphere volume fraction that sets the correlation strength in the Percus-Yevick structure factor.\nscale: Multiplicative scale factor applied only to the solute scattering contribution." } diff --git a/src/saxshell/saxs/_model_templates/_deprecated/template_likelihood_monosq.py b/src/saxshell/saxs/_model_templates/_deprecated/template_likelihood_monosq.py index 2397b94..94fd80e 100644 --- a/src/saxshell/saxs/_model_templates/_deprecated/template_likelihood_monosq.py +++ b/src/saxshell/saxs/_model_templates/_deprecated/template_likelihood_monosq.py @@ -40,6 +40,20 @@ def calc_monodisperse_sq(r, vol_frac, q_values): return np.asarray(sqs) +def _bounded_solvent_weight(value): + return float(np.clip(float(value), 0.0, 1.0)) + + +def structure_factor_profile(q, solvent_data, model_data, **params): + """Return the pure hard-sphere structure-factor trace S(q).""" + del solvent_data, model_data + return calc_monodisperse_sq( + params["eff_r"], + params["vol_frac"], + np.asarray(q, dtype=float), + ) + + def lmfit_model_profile(q, solvent_data, model_data, **params): """Evaluate the lmfit SAXS profile model.""" weight_keys = sorted( @@ -47,7 +61,7 @@ def lmfit_model_profile(q, solvent_data, model_data, **params): key=lambda key: int(key.lstrip("w")), ) weights = [params[key] for key in weight_keys] - solv_w = params["solv_w"] + solv_w = _bounded_solvent_weight(params["solv_w"]) offset = params["offset"] eff_r = params["eff_r"] vol_frac = params["vol_frac"] @@ -57,15 +71,15 @@ def lmfit_model_profile(q, solvent_data, model_data, **params): for weight, component in zip(weights, model_data): mixture += weight * component - iq = mixture * calc_monodisperse_sq(eff_r, vol_frac, q) - iq += solv_w * solvent_data - return iq * scale + offset + solute_intensity = mixture * calc_monodisperse_sq(eff_r, vol_frac, q) + solvent_contribution = solv_w * np.asarray(solvent_data, dtype=float) + return scale * solute_intensity + solvent_contribution + offset def log_likelihood_monosq(params): """Compute the legacy monodisperse SAXS log-likelihood.""" weights = params[:-5] - solv_w = params[-5] + solv_w = _bounded_solvent_weight(params[-5]) offset = params[-4] eff_r = params[-3] vol_frac = params[-2] @@ -82,11 +96,13 @@ def log_likelihood_monosq(params): for index, weight in enumerate(weights): mixture_intensity += weight * theoretical_intensities[index] - model_intensity = mixture_intensity * calc_monodisperse_sq( + solute_intensity = mixture_intensity * calc_monodisperse_sq( eff_r, vol_frac, q_values ) - model_intensity += solv_w * solvent_intensities - model_intensity = model_intensity * scale + offset + solvent_contribution = solv_w * np.asarray( + solvent_intensities, dtype=float + ) + model_intensity = scale * solute_intensity + solvent_contribution + offset return np.sum( norm.logpdf( @@ -100,7 +116,7 @@ def log_likelihood_monosq(params): def compute_model_profile(params): """Return the model intensity and q grid for a parameter vector.""" weights = params[:-5] - solv_w = params[-5] + solv_w = _bounded_solvent_weight(params[-5]) offset = params[-4] eff_r = params[-3] vol_frac = params[-2] @@ -116,9 +132,11 @@ def compute_model_profile(params): for index, weight in enumerate(weights): mixture_intensity += weight * theoretical_intensities[index] - iq_intensity = mixture_intensity * calc_monodisperse_sq( + solute_intensity = mixture_intensity * calc_monodisperse_sq( eff_r, vol_frac, q_values ) - iq_intensity += solv_w * solvent_intensities - iq_intensity = iq_intensity * scale + offset - return iq_intensity, q_values + solvent_contribution = solv_w * np.asarray( + solvent_intensities, dtype=float + ) + model_intensity = scale * solute_intensity + solvent_contribution + offset + return model_intensity, q_values diff --git a/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq.json b/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq.json index ced7e50..4291f34 100644 --- a/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq.json +++ b/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq.json @@ -1,4 +1,4 @@ { "display_name": "MonoSQ PD (archived)", - "description": "MonoSQ PD\n\nStructure Factor:\nMonodisperse hard-sphere structure factor evaluated with calc_monodisperse_sq using the Percus-Yevick approximation. The hard-sphere correlation term is controlled by eff_r for the characteristic particle radius and vol_frac for the effective packing fraction.\n\nForm Factor:\nWeighted mixture of MD-derived SAXS component profiles constructed from the project's averaged cluster scattering curves. The component traces are linearly combined and then modulated by the hard-sphere structure factor to produce the model intensity.\n\nLikelihood Form:\nThis template keeps the same forward scattering expression as the monodisperse hard-sphere model, but the probability-density evaluation is written for posterior-density workflows and returns a point-normalized log-likelihood for downstream DREAM-style refinement.\n\nModel Parameters:\nw1, w2, ..., wN: Relative contribution of each MD-derived component profile.\nsolv_w: Solvent intensity mixing weight added before the global scale and offset.\noffset: Additive baseline shift applied after the scattering profile is assembled.\neff_r: Effective hard-sphere radius used in the structure-factor term.\nvol_frac: Effective hard-sphere volume fraction controlling interparticle correlations.\nscale: Overall multiplicative intensity scale used to match the experimental SAXS amplitude." + "description": "MonoSQ PD\n\nStructure Factor:\nMonodisperse hard-sphere structure factor evaluated with calc_monodisperse_sq using the Percus-Yevick approximation. The hard-sphere correlation term is controlled by eff_r for the characteristic particle radius and vol_frac for the effective packing fraction.\n\nForm Factor:\nWeighted mixture of MD-derived SAXS component profiles constructed from the project's averaged cluster scattering curves. The component traces are linearly combined and then modulated by the hard-sphere structure factor to produce the solute contribution, which is scaled separately from the solvent trace.\n\nLikelihood Form:\nThis template keeps the same forward scattering expression as the monodisperse hard-sphere model, but the probability-density evaluation is written for posterior-density workflows and returns a point-normalized log-likelihood for downstream DREAM-style refinement.\n\nModel Parameters:\nw1, w2, ..., wN: Relative contribution of each MD-derived component profile.\nsolv_w: Solvent weight applied directly to the experimental solvent trace and constrained to the interval [0, 1].\noffset: Additive baseline shift applied after the scattering profile is assembled.\neff_r: Effective hard-sphere radius used in the structure-factor term.\nvol_frac: Effective hard-sphere volume fraction controlling interparticle correlations.\nscale: Multiplicative intensity scale used only for the solute scattering contribution." } diff --git a/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq.py b/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq.py index f3c7827..897a9f3 100644 --- a/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq.py +++ b/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq.py @@ -40,6 +40,20 @@ def calc_monodisperse_sq(r, vol_frac, q_values): return np.asarray(sqs) +def _bounded_solvent_weight(value): + return float(np.clip(float(value), 0.0, 1.0)) + + +def structure_factor_profile(q, solvent_data, model_data, **params): + """Return the pure hard-sphere structure-factor trace S(q).""" + del solvent_data, model_data + return calc_monodisperse_sq( + params["eff_r"], + params["vol_frac"], + np.asarray(q, dtype=float), + ) + + def lmfit_model_profile(q, solvent_data, model_data, **params): """Evaluate the lmfit SAXS profile model.""" weight_keys = sorted( @@ -47,7 +61,7 @@ def lmfit_model_profile(q, solvent_data, model_data, **params): key=lambda key: int(key.lstrip("w")), ) weights = [params[key] for key in weight_keys] - solv_w = params["solv_w"] + solv_w = _bounded_solvent_weight(params["solv_w"]) offset = params["offset"] eff_r = params["eff_r"] vol_frac = params["vol_frac"] @@ -57,9 +71,9 @@ def lmfit_model_profile(q, solvent_data, model_data, **params): for weight, component in zip(weights, model_data): mixture += weight * component - iq = mixture * calc_monodisperse_sq(eff_r, vol_frac, q) - iq += solv_w * solvent_data - return iq * scale + offset + solute_intensity = mixture * calc_monodisperse_sq(eff_r, vol_frac, q) + solvent_contribution = solv_w * np.asarray(solvent_data, dtype=float) + return scale * solute_intensity + solvent_contribution + offset def log_likelihood_monosq(params): @@ -69,7 +83,7 @@ def log_likelihood_monosq(params): n_profiles = len(theoretical_intensities) weights = params[:n_profiles] - solv_w = params[n_profiles] + solv_w = _bounded_solvent_weight(params[n_profiles]) offset = params[n_profiles + 1] eff_r = params[n_profiles + 2] vol_frac = params[n_profiles + 3] @@ -79,9 +93,15 @@ def log_likelihood_monosq(params): for weight, component in zip(weights, theoretical_intensities): mixture += weight * component - iq = mixture * calc_monodisperse_sq(eff_r, vol_frac, q_values) - iq += solv_w * solvent_intensities - model_intensity = iq * scale + offset + solute_intensity = mixture * calc_monodisperse_sq( + eff_r, + vol_frac, + q_values, + ) + solvent_contribution = solv_w * np.asarray( + solvent_intensities, dtype=float + ) + model_intensity = scale * solute_intensity + solvent_contribution + offset n_points = len(experimental_intensities) log_likelihood = np.sum( diff --git a/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq_decoupled.json b/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq_decoupled.json index 494ec8e..41ff4a9 100644 --- a/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq_decoupled.json +++ b/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq_decoupled.json @@ -1,4 +1,4 @@ { "display_name": "MonoSQ Decoupled (archived)", - "description": "MonoSQ Decoupled\n\nStructure Factor:\nMonodisperse hard-sphere structure factor evaluated with calc_monodisperse_sq using the Percus-Yevick approximation. The effective particle size is controlled by eff_r and the interparticle correlation strength is controlled by vol_frac.\n\nForm Factor:\nWeighted mixture of MD-derived SAXS component profiles assembled from the project's averaged cluster scattering curves. The weighted component mixture is multiplied by the hard-sphere structure factor and then combined with solvent and baseline terms.\n\nModel Organization:\nThis template separates the forward model into a dedicated model_monosq helper before evaluating the likelihood. That decoupled structure makes the intensity expression easier to reuse in runtime scripts while preserving the same physical scattering model used by the prefit calculation.\n\nModel Parameters:\nw1, w2, ..., wN: Relative contribution of each MD-derived component profile.\nsolv_w: Solvent intensity mixing weight applied before the final scale and offset.\noffset: Additive baseline term used to capture residual background intensity.\neff_r: Effective hard-sphere radius used in the structure-factor calculation.\nvol_frac: Hard-sphere volume fraction that controls the Percus-Yevick correlation term.\nscale: Overall multiplicative scale factor applied to the final model intensity." + "description": "MonoSQ Decoupled\n\nStructure Factor:\nMonodisperse hard-sphere structure factor evaluated with calc_monodisperse_sq using the Percus-Yevick approximation. The effective particle size is controlled by eff_r and the interparticle correlation strength is controlled by vol_frac.\n\nForm Factor:\nWeighted mixture of MD-derived SAXS component profiles assembled from the project's averaged cluster scattering curves. The weighted component mixture is multiplied by the hard-sphere structure factor, scaled by the solute scale, and then combined with a separately weighted solvent trace and baseline term.\n\nModel Organization:\nThis template separates the forward model into a dedicated model_monosq helper before evaluating the likelihood. That decoupled structure makes the intensity expression easier to reuse in runtime scripts while preserving the same physical scattering model used by the prefit calculation.\n\nModel Parameters:\nw1, w2, ..., wN: Relative contribution of each MD-derived component profile.\nsolv_w: Solvent weight applied directly to the experimental solvent trace and constrained to the interval [0, 1].\noffset: Additive baseline term used to capture residual background intensity.\neff_r: Effective hard-sphere radius used in the structure-factor calculation.\nvol_frac: Hard-sphere volume fraction that controls the Percus-Yevick correlation term.\nscale: Multiplicative scale factor applied only to the solute scattering contribution." } diff --git a/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq_decoupled.py b/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq_decoupled.py index 88a4941..00ebb68 100644 --- a/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq_decoupled.py +++ b/src/saxshell/saxs/_model_templates/_deprecated/template_pd_likelihood_monosq_decoupled.py @@ -40,6 +40,20 @@ def calc_monodisperse_sq(r, vol_frac, q_values): return np.asarray(sqs) +def _bounded_solvent_weight(value): + return float(np.clip(float(value), 0.0, 1.0)) + + +def structure_factor_profile(q, solvent_data, model_data, **params): + """Return the pure hard-sphere structure-factor trace S(q).""" + del solvent_data, model_data + return calc_monodisperse_sq( + params["eff_r"], + params["vol_frac"], + np.asarray(q, dtype=float), + ) + + def lmfit_model_profile(q, solvent_data, model_data, **params): """Evaluate the lmfit SAXS profile model.""" weight_keys = sorted( @@ -47,7 +61,7 @@ def lmfit_model_profile(q, solvent_data, model_data, **params): key=lambda key: int(key.lstrip("w")), ) weights = [params[key] for key in weight_keys] - solv_w = params["solv_w"] + solv_w = _bounded_solvent_weight(params["solv_w"]) offset = params["offset"] eff_r = params["eff_r"] vol_frac = params["vol_frac"] @@ -57,9 +71,9 @@ def lmfit_model_profile(q, solvent_data, model_data, **params): for weight, component in zip(weights, model_data): mixture += weight * component - iq = mixture * calc_monodisperse_sq(eff_r, vol_frac, q) - iq += solv_w * solvent_data - return iq * scale + offset + solute_intensity = mixture * calc_monodisperse_sq(eff_r, vol_frac, q) + solvent_contribution = solv_w * np.asarray(solvent_data, dtype=float) + return scale * solute_intensity + solvent_contribution + offset def model_monosq(params): @@ -68,7 +82,7 @@ def model_monosq(params): n_profiles = len(theoretical_intensities) weights = params[:n_profiles] - solv_w = params[n_profiles] + solv_w = _bounded_solvent_weight(params[n_profiles]) offset = params[n_profiles + 1] eff_r = params[n_profiles + 2] vol_frac = params[n_profiles + 3] @@ -78,9 +92,15 @@ def model_monosq(params): for weight, component in zip(weights, theoretical_intensities): mixture += weight * component - iq = mixture * calc_monodisperse_sq(eff_r, vol_frac, q_values) - iq += solv_w * solvent_intensities - return iq * scale + offset + solute_intensity = mixture * calc_monodisperse_sq( + eff_r, + vol_frac, + q_values, + ) + solvent_contribution = solv_w * np.asarray( + solvent_intensities, dtype=float + ) + return scale * solute_intensity + solvent_contribution + offset def log_likelihood_monosq(params): diff --git a/src/saxshell/saxs/_model_templates/_deprecated/template_pydream_poly_lma_hs_legacy.py b/src/saxshell/saxs/_model_templates/_deprecated/template_pydream_poly_lma_hs_legacy.py index ea15409..53886b4 100644 --- a/src/saxshell/saxs/_model_templates/_deprecated/template_pydream_poly_lma_hs_legacy.py +++ b/src/saxshell/saxs/_model_templates/_deprecated/template_pydream_poly_lma_hs_legacy.py @@ -8,9 +8,9 @@ # inputs_pydream: q_values, experimental_intensities, solvent_intensities, theoretical_intensities, effective_radii, params # param_columns: Structure, Motif, Param, Value, Vary, Min, Max # cluster_geometry_metadata: true -# param: phi_solute,0.02,True,0.0,0.5 +# param: phi_solute,0.02,False,0.0,0.5 # param: phi_int,0.02,True,0.0,0.4 -# param: solvent_scale,1.0,True,0.0,5.0 +# param: solvent_scale,1.0,False,0.0,1.0 # param: scale,1.0,True,1e-8,1e8 # param: offset,0.0,True,-1e6,1e6 # param: log_sigma,-9.21,True,-20.0,5.0 @@ -19,6 +19,7 @@ normalize_profile_fractions = _mixed.normalize_profile_fractions equivalent_volume_radius = _mixed.equivalent_volume_radius polydisperse_lma_hs_model = _mixed.polydisperse_lma_hs_model +structure_factor_profile = _mixed.structure_factor_profile def lmfit_model_profile( diff --git a/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized.json b/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized.json index 224bd58..4b86def 100644 --- a/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized.json +++ b/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized.json @@ -1,4 +1,4 @@ { "display_name": "pyDREAM MonoSQ Normalized", - "description": "pyDREAM MonoSQ Normalized\n\nPurpose:\nCanonical pyDREAM likelihood template for monodisperse hard-sphere SAXS fitting when the fitted q-range, interpolation density, or number of experimental points may vary between runs.\n\nStructure Factor:\nMonodisperse hard-sphere structure factor evaluated with calc_monodisperse_sq using the Percus-Yevick approximation. The effective particle size is controlled by eff_r and the interparticle correlation strength is controlled by vol_frac.\n\nForm Factor:\nWeighted mixture of MD-derived SAXS component profiles assembled from the project's averaged cluster scattering curves. The weighted component mixture is multiplied by the hard-sphere structure factor and then combined with solvent and baseline terms.\n\nLikelihood Convention:\nThe pyDREAM likelihood is normalized by the number of experimental points. This keeps the log-likelihood on a per-point basis so that its magnitude is less sensitive to changes in fitted q-range, resampling density, or dataset length.\n\nImplementation Notes:\nThis template uses explicit n_profiles slicing for parameter unpacking and separates the forward model into a dedicated model_monosq helper before evaluating the likelihood.\n\nModel Parameters:\nw1, w2, ..., wN: Relative contribution of each MD-derived component profile.\nsolv_w: Solvent intensity mixing weight applied before the final scale and offset.\noffset: Additive baseline term used to capture residual background intensity.\neff_r: Effective hard-sphere radius used in the structure-factor calculation.\nvol_frac: Hard-sphere volume fraction that controls the Percus-Yevick correlation term.\nscale: Overall multiplicative scale factor applied to the final model intensity." + "description": "pyDREAM MonoSQ Normalized\n\nPurpose:\nCanonical pyDREAM likelihood template for monodisperse hard-sphere SAXS fitting when the fitted q-range, interpolation density, or number of experimental points may vary between runs.\n\nStructure Factor:\nMonodisperse hard-sphere structure factor evaluated with calc_monodisperse_sq using the Percus-Yevick approximation. The effective particle size is controlled by eff_r and the interparticle correlation strength is controlled by vol_frac.\n\nForm Factor:\nWeighted mixture of MD-derived SAXS component profiles assembled from the project's averaged cluster scattering curves. The weighted component mixture is multiplied by the hard-sphere structure factor, scaled by the global solute scale, and then combined with a separately weighted solvent trace and baseline term.\n\nLikelihood Convention:\nThe pyDREAM likelihood is normalized by the number of experimental points. This keeps the log-likelihood on a per-point basis so that its magnitude is less sensitive to changes in fitted q-range, resampling density, or dataset length.\n\nImplementation Notes:\nThis template uses explicit n_profiles slicing for parameter unpacking and separates the forward model into a dedicated model_monosq helper before evaluating the likelihood.\n\nModel Parameters:\nw1, w2, ..., wN: Relative contribution of each MD-derived component profile.\nsolv_w: Solvent weight applied directly to the experimental solvent trace and constrained to the interval [0, 1].\noffset: Additive baseline term used to capture residual background intensity.\neff_r: Effective hard-sphere radius used in the structure-factor calculation.\nvol_frac: Hard-sphere volume fraction that controls the Percus-Yevick correlation term.\nscale: Multiplicative scale factor applied only to the solute scattering contribution." } diff --git a/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized.py b/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized.py index ba604ef..8e669ac 100644 --- a/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized.py +++ b/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized.py @@ -48,6 +48,20 @@ def calc_monodisperse_sq(r, vol_frac, q_values): return np.asarray(sqs) +def _bounded_solvent_weight(value): + return float(np.clip(float(value), 0.0, 1.0)) + + +def structure_factor_profile(q, solvent_data, model_data, **params): + """Return the pure hard-sphere structure-factor trace S(q).""" + del solvent_data, model_data + return calc_monodisperse_sq( + params["eff_r"], + params["vol_frac"], + np.asarray(q, dtype=float), + ) + + def lmfit_model_profile(q, solvent_data, model_data, **params): """Evaluate the monodisperse SAXS model for lmfit.""" weight_keys = sorted( @@ -56,7 +70,7 @@ def lmfit_model_profile(q, solvent_data, model_data, **params): ) weights = [params[key] for key in weight_keys] - solv_w = params["solv_w"] + solv_w = _bounded_solvent_weight(params["solv_w"]) offset = params["offset"] eff_r = params["eff_r"] vol_frac = params["vol_frac"] @@ -66,10 +80,10 @@ def lmfit_model_profile(q, solvent_data, model_data, **params): for weight, component in zip(weights, model_data): mixture += weight * component - iq = mixture * calc_monodisperse_sq(eff_r, vol_frac, q) - iq += solv_w * solvent_data + solute_intensity = mixture * calc_monodisperse_sq(eff_r, vol_frac, q) + solvent_contribution = solv_w * np.asarray(solvent_data, dtype=float) - return iq * scale + offset + return scale * solute_intensity + solvent_contribution + offset def model_monosq(params): @@ -81,7 +95,7 @@ def model_monosq(params): n_profiles = len(theoretical_intensities) weights = params[:n_profiles] - solv_w = params[n_profiles] + solv_w = _bounded_solvent_weight(params[n_profiles]) offset = params[n_profiles + 1] eff_r = params[n_profiles + 2] vol_frac = params[n_profiles + 3] @@ -91,10 +105,16 @@ def model_monosq(params): for weight, component in zip(weights, theoretical_intensities): mixture += weight * component - iq = mixture * calc_monodisperse_sq(eff_r, vol_frac, q_values) - iq += solv_w * solvent_intensities + solute_intensity = mixture * calc_monodisperse_sq( + eff_r, + vol_frac, + q_values, + ) + solvent_contribution = solv_w * np.asarray( + solvent_intensities, dtype=float + ) - return iq * scale + offset + return scale * solute_intensity + solvent_contribution + offset def log_likelihood_monosq(params): diff --git a/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs.json b/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs.json index fd5e2dc..5c315db 100644 --- a/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs.json +++ b/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs.json @@ -1,6 +1,6 @@ { "display_name": "pyDREAM Poly LMA Hard-Sphere", - "description": "pyDREAM Poly LMA Hard-Sphere\n\nStructure Factor:\nCluster-resolved hard-sphere Percus-Yevick structure factors are evaluated with one effective interaction radius per cluster profile and a shared interaction packing parameter phi_int.\n\nScientific Scope:\nThis strict template is the sphere-only hard-sphere workflow. It keeps the literature-consistent use case where the interaction model remains a mixture of hard spheres rather than an approximate mixed hard-sphere/ellipsoid route.\n\nForm Factor:\nThe solute form factor is a weighted mixture of the project's averaged cluster scattering profiles. SAXSShell generates weight parameters w0 ... wN-1 from the project component order, then this template normalizes those raw weights internally before applying the per-cluster hard-sphere structure factor.\n\nModel Organization:\nThis template is intended to run through the production Prefit and DREAM workflows, not as a stand-alone example. The Prefit workflow supplies the generated component weights, solvent trace, averaged cluster intensities, and the computed cluster-geometry metadata array effective_radii. It generates one sphere radius parameter r_eff_wN per mapped component.\n\nModel Equation:\nI_model(q) = scale * [ phi_solute * sum_i x_i I_i(q) S_HS(q; R_eff_i, phi_int) + solvent_scale * (1 - phi_solute) * I_solv(q) ] + offset\n\nInternal Abundance Normalization:\nx_i = w_i / sum_j w_j\nw_i >= 0 for all i\nsum_i x_i = 1 after normalization\n\nModel Parameters:\nw0 ... wN-1: Generated cluster-abundance coefficients aligned to the project component rows.\nr_eff_wN: Generated per-component effective-radius parameter used for the sphere-only hard-sphere interaction model.\nphi_solute: Solute volume fraction used to scale the cluster contribution relative to the solvent.\nphi_int: Effective packing fraction used only inside the hard-sphere interaction model.\nsolvent_scale: Solvent trace multiplier applied before the final scale and offset.\nscale: Global multiplicative scale factor applied to the combined model.\noffset: Additive q-independent background term.\nlog_sigma: Natural logarithm of the Gaussian noise standard deviation used by the DREAM likelihood.\n\nCluster Geometry Metadata:\nThis template supports per-cluster geometry metadata, but it only enables sphere approximations in Prefit. Geometry rows that recommend ellipsoids are normalized back to sphere radii for this strict hard-sphere workflow.\n", + "description": "pyDREAM Poly LMA Hard-Sphere\n\nStructure Factor:\nCluster-resolved hard-sphere Percus-Yevick structure factors are evaluated with one effective interaction radius per cluster profile and a shared interaction packing parameter phi_int.\n\nScientific Scope:\nThis strict template is the sphere-only hard-sphere workflow. It keeps the literature-consistent use case where the interaction model remains a mixture of hard spheres rather than an approximate mixed hard-sphere/ellipsoid route.\n\nForm Factor:\nThe solute form factor is a weighted mixture of the project's averaged cluster scattering profiles. SAXSShell generates weight parameters w0 ... wN-1 from the project component order, then this template normalizes those raw weights internally before applying the per-cluster hard-sphere structure factor.\n\nModel Organization:\nThis template is intended to run through the production Prefit and DREAM workflows, not as a stand-alone example. The Prefit workflow supplies the generated component weights, solvent trace, averaged cluster intensities, and the computed cluster-geometry metadata array effective_radii. It generates one sphere radius parameter r_eff_wN per mapped component.\n\nModel Equation:\nI_model(q) = scale * phi_solute * sum_i x_i I_i(q) S_HS(q; R_eff_i, phi_int) + solvent_scale * (1 - phi_solute) * I_solv(q) + offset\n\nInternal Abundance Normalization:\nx_i = w_i / sum_j w_j\nw_i >= 0 for all i\nsum_i x_i = 1 after normalization\n\nModel Parameters:\nw0 ... wN-1: Generated cluster-abundance coefficients aligned to the project component rows.\nr_eff_wN: Generated per-component effective-radius parameter used for the sphere-only hard-sphere interaction model.\nphi_solute: SAXS-effective solute interaction ratio used in the solvent-complement term and fixed by default so the contrast-weighted solvent subtraction can be supplied from the solution-scattering estimator rather than co-fit with solvent_scale.\nphi_int: Effective packing fraction used only inside the hard-sphere interaction model.\nsolvent_scale: Solvent attenuation/subtraction weight applied to the solvent contribution, constrained to the interval [0, 1], and fixed by default so it does not become redundant with phi_solute.\nscale: Global multiplicative scale factor applied only to the solute contribution.\noffset: Additive q-independent background term.\nlog_sigma: Natural logarithm of the Gaussian noise standard deviation used by the DREAM likelihood.\n\nFitting Guidance:\nUse phi_solute and solvent_scale as prior-informed solvent-subtraction controls. They should normally be populated from the physical-volume, SAXS-effective interaction, and attenuation estimators, then kept fixed while the fit isolates the solute scattering parameters.\n\nCluster Geometry Metadata:\nThis template supports per-cluster geometry metadata, but it only enables sphere approximations in Prefit. Geometry rows that recommend ellipsoids are normalized back to sphere radii for this strict hard-sphere workflow.\n", "capabilities": { "cluster_geometry_metadata": { "supported": true, diff --git a/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs.py b/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs.py index 1e412b7..57ddb0d 100644 --- a/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs.py +++ b/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs.py @@ -18,9 +18,9 @@ # offset, # log_sigma # -# param: phi_solute,0.02,True,0.0,0.5 +# param: phi_solute,0.02,False,0.0,0.5 # param: phi_int,0.02,True,0.0,0.4 -# param: solvent_scale,1.0,True,0.0,5.0 +# param: solvent_scale,1.0,False,0.0,1.0 # param: scale,1.0,True,1e-8,1e8 # param: offset,0.0,True,-1e6,1e6 # param: log_sigma,-9.21,True,-20.0,5.0 @@ -31,13 +31,13 @@ # - decoupled forward model helper # - discrete local-monodisperse style cluster sum # - per-cluster hard-sphere S(Q) using effective radii -# - explicit solvent template, global scale, and offset +# - explicit solvent template, bounded solvent weight, and offset # # Model equation: -# I_model(q) = scale * [ -# phi_solute * sum_i x_i I_i(q) S_HS(q; R_eff_i, phi_int) +# I_model(q) = +# scale * phi_solute * sum_i x_i I_i(q) S_HS(q; R_eff_i, phi_int) # + solvent_scale * (1 - phi_solute) * I_solv(q) -# ] + offset +# + offset # # Internal abundance normalization: # x_i = f_i / sum_j f_j @@ -55,11 +55,13 @@ # abundances instead of redundant global scale factors. # # phi_solute -# Physical solute volume fraction in the measured solution. This term -# scales the cluster contribution relative to the solvent template. -# Good prior information can come from solution density, composition, -# and solute/solvent molar masses. Keep this fixed or tightly bounded -# unless the data are on a credible absolute scale. +# SAXS-effective solute interaction ratio in the measured solution. +# This is the contrast-weighted model-facing solute fraction, not the +# raw bulk-density volume fraction. Good prior information can come +# from the solution-scattering estimator, which reports both the +# physical occupancy fraction and the energy-dependent SAXS-effective +# interaction ratio. Keep this fixed or tightly bounded because the +# solvent subtraction already depends on (1 - phi_solute). # # phi_int # Effective structural volume fraction used only inside the hard- @@ -70,16 +72,20 @@ # loading for irregular or anisotropic clusters. # # solvent_scale -# Multiplicative coefficient applied to the experimental solvent SAXS -# template before the global scale and offset. This absorbs mismatch -# from transmission, thickness, normalization, or imperfect solvent -# subtraction. +# Bounded 0..1 attenuation/subtraction weight applied to the +# experimental solvent SAXS template contribution. In the split- +# fraction poly templates the solvent complement remains in +# (1 - phi_solute), so solvent_scale carries the attenuation and +# normalization part of the solvent-background correction without +# duplicating the contrast-weighted interaction ratio. Keep this fixed +# by default so it does not become redundant with phi_solute during +# fitting. # # scale -# Global multiplicative intensity factor applied to the full model. -# Keep this free when the measured SAXS data are not on absolute -# intensity scale or when the cluster I(Q) library is only known up to -# an arbitrary normalization. +# Global multiplicative intensity factor applied only to the solute +# scattering contribution. Keep this free when the measured SAXS data +# are not on absolute intensity scale or when the cluster I(Q) +# library is only known up to an arbitrary normalization. # # offset # Constant additive background term. Use this to absorb fluorescence @@ -259,6 +265,44 @@ def _full_params_to_param_dict(full_params): } +def _bounded_solvent_weight(value): + return float(np.clip(float(value), 0.0, 1.0)) + + +def _effective_structure_factor_profile( + q_values, + cluster_intensities, + effective_radii, + raw_weights, + phi_int, +): + """Return the mixture-equivalent S(q) that modulates the form + factor.""" + q_values = np.asarray(q_values, dtype=float) + effective_radii = np.asarray(effective_radii, dtype=float) + raw_weights = np.asarray(raw_weights, dtype=float) + fractions = normalize_profile_fractions(raw_weights) + numerator = np.zeros_like(q_values, dtype=float) + denominator = np.zeros_like(q_values, dtype=float) + fallback = np.zeros_like(q_values, dtype=float) + + for frac, iq_cluster, radius in zip( + fractions, cluster_intensities, effective_radii + ): + iq_cluster = np.asarray(iq_cluster, dtype=float) + sq = calc_hardsphere_sq(radius, phi_int, q_values) + numerator += frac * iq_cluster * sq + denominator += frac * iq_cluster + fallback += frac * sq + + structure_factor = fallback.copy() + valid_mask = np.abs(denominator) > 1e-12 + structure_factor[valid_mask] = ( + numerator[valid_mask] / denominator[valid_mask] + ) + return structure_factor + + def polydisperse_lma_hs_model( q_values, cluster_intensities, @@ -301,10 +345,15 @@ def polydisperse_lma_hs_model( sq = calc_hardsphere_sq(radius, phi_int, q_values) solute_sum += frac * iq_cluster * sq - model = phi_solute * solute_sum - model += solvent_scale * (1.0 - phi_solute) * solvent_intensities + solvent_weight = _bounded_solvent_weight(solvent_scale) + solute_contribution = scale * phi_solute * solute_sum + solvent_contribution = ( + solvent_weight + * (1.0 - phi_solute) + * np.asarray(solvent_intensities, dtype=float) + ) - return scale * model + offset + return solute_contribution + solvent_contribution + offset def lmfit_model_profile( @@ -321,7 +370,7 @@ def lmfit_model_profile( phi_solute = params["phi_solute"] phi_int = params["phi_int"] - solvent_scale = params["solvent_scale"] + solvent_scale = _bounded_solvent_weight(params["solvent_scale"]) scale = params["scale"] offset = params["offset"] @@ -339,6 +388,27 @@ def lmfit_model_profile( ) +def structure_factor_profile( + q, solvent_data, model_data, effective_radii, **params +): + """Return the mixture-equivalent structure-factor trace S(q).""" + del solvent_data + weight_keys = _weight_keys_from_params(params) + raw_weights = np.asarray([params[key] for key in weight_keys], dtype=float) + resolved_effective_radii = _resolve_effective_radii( + weight_keys, + params, + effective_radii, + ) + return _effective_structure_factor_profile( + q, + model_data, + resolved_effective_radii, + raw_weights, + params["phi_int"], + ) + + def model_poly_lma_hs(params): """Return the forward model intensity for pyDREAM.""" global q_values @@ -361,7 +431,7 @@ def model_poly_lma_hs(params): ) phi_solute = named_params["phi_solute"] phi_int = named_params["phi_int"] - solvent_scale = named_params["solvent_scale"] + solvent_scale = _bounded_solvent_weight(named_params["solvent_scale"]) scale = named_params["scale"] offset = named_params["offset"] else: @@ -370,7 +440,7 @@ def model_poly_lma_hs(params): resolved_effective_radii = np.asarray(effective_radii, dtype=float) phi_solute = params[n_profiles] phi_int = params[n_profiles + 1] - solvent_scale = params[n_profiles + 2] + solvent_scale = _bounded_solvent_weight(params[n_profiles + 2]) scale = params[n_profiles + 3] offset = params[n_profiles + 4] diff --git a/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs_mix_approx.json b/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs_mix_approx.json index ec74ae2..f182cce 100644 --- a/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs_mix_approx.json +++ b/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs_mix_approx.json @@ -1,6 +1,6 @@ { "display_name": "pyDREAM Poly LMA Hard-Sphere/Ellipsoid Mix (Approx.)", - "description": "pyDREAM Poly LMA Hard-Sphere/Ellipsoid Mix (Approx.)\n\nStructure Factor:\nCluster-resolved hard-sphere Percus-Yevick structure factors are evaluated with one effective interaction radius per cluster profile and a shared interaction packing parameter phi_int.\n\nScientific Scope:\nThis template is an approximate mixed-shape workflow for projects that want to keep both sphere and ellipsoid geometry options active in Prefit. Ellipsoid rows are mapped onto an equivalent-sphere interaction radius before the hard-sphere Percus-Yevick structure factor is evaluated. That makes this a Hansen-style approximate hard-body-to-sphere route rather than an exact hard-ellipsoid Percus-Yevick closure.\n\nForm Factor:\nThe solute form factor is a weighted mixture of the project's averaged cluster scattering profiles. SAXSShell generates weight parameters w0 ... wN-1 from the project component order, then this template normalizes those raw weights internally before applying the per-cluster hard-sphere structure factor.\n\nModel Organization:\nThis template is intended to run through the production Prefit and DREAM workflows, not as a stand-alone example. The Prefit workflow supplies the generated component weights, solvent trace, averaged cluster intensities, and the computed cluster-geometry metadata array effective_radii. It also generates cluster-geometry parameter rows that follow the active structure-factor approximation for each mapped component: sphere rows expose one fitted radius parameter, while ellipsoid rows expose three semiaxis parameters whose equivalent-volume sphere radius is used in S(Q).\n\nModel Equation:\nI_model(q) = scale * [ phi_solute * sum_i x_i I_i(q) S_HS(q; R_eff_i, phi_int) + solvent_scale * (1 - phi_solute) * I_solv(q) ] + offset\n\nInternal Abundance Normalization:\nx_i = w_i / sum_j w_j\nw_i >= 0 for all i\nsum_i x_i = 1 after normalization\n\nModel Parameters:\nw0 ... wN-1: Generated cluster-abundance coefficients aligned to the project component rows.\nr_eff_wN: Generated per-component effective-radius parameter used when that cluster row is approximated as a sphere.\na_eff_wN, b_eff_wN, c_eff_wN: Generated per-component semiaxis parameters used when that cluster row is approximated as an ellipsoid; the hard-sphere interaction radius is then taken from the equivalent-volume sphere.\nphi_solute: Solute volume fraction used to scale the cluster contribution relative to the solvent.\nphi_int: Effective packing fraction used only inside the hard-sphere interaction model.\nsolvent_scale: Solvent trace multiplier applied before the final scale and offset.\nscale: Global multiplicative scale factor applied to the combined model.\noffset: Additive q-independent background term.\nlog_sigma: Natural logarithm of the Gaussian noise standard deviation used by the DREAM likelihood.\n\nCluster Geometry Metadata:\nThis template supports per-cluster geometry metadata and keeps both sphere and ellipsoid approximations enabled in Prefit. Use this when you want mixed geometry selection while accepting the equivalent-sphere approximation behind the hard-sphere structure factor.\n", + "description": "pyDREAM Poly LMA Hard-Sphere/Ellipsoid Mix (Approx.)\n\nStructure Factor:\nCluster-resolved hard-sphere Percus-Yevick structure factors are evaluated with one effective interaction radius per cluster profile and a shared interaction packing parameter phi_int.\n\nScientific Scope:\nThis template is an approximate mixed-shape workflow for projects that want to keep both sphere and ellipsoid geometry options active in Prefit. Ellipsoid rows are mapped onto an equivalent-sphere interaction radius before the hard-sphere Percus-Yevick structure factor is evaluated. That makes this a Hansen-style approximate hard-body-to-sphere route rather than an exact hard-ellipsoid Percus-Yevick closure.\n\nForm Factor:\nThe solute form factor is a weighted mixture of the project's averaged cluster scattering profiles. SAXSShell generates weight parameters w0 ... wN-1 from the project component order, then this template normalizes those raw weights internally before applying the per-cluster hard-sphere structure factor.\n\nModel Organization:\nThis template is intended to run through the production Prefit and DREAM workflows, not as a stand-alone example. The Prefit workflow supplies the generated component weights, solvent trace, averaged cluster intensities, and the computed cluster-geometry metadata array effective_radii. It also generates cluster-geometry parameter rows that follow the active structure-factor approximation for each mapped component: sphere rows expose one fitted radius parameter, while ellipsoid rows expose three semiaxis parameters whose equivalent-volume sphere radius is used in S(Q).\n\nModel Equation:\nI_model(q) = scale * phi_solute * sum_i x_i I_i(q) S_HS(q; R_eff_i, phi_int) + solvent_scale * (1 - phi_solute) * I_solv(q) + offset\n\nInternal Abundance Normalization:\nx_i = w_i / sum_j w_j\nw_i >= 0 for all i\nsum_i x_i = 1 after normalization\n\nModel Parameters:\nw0 ... wN-1: Generated cluster-abundance coefficients aligned to the project component rows.\nr_eff_wN: Generated per-component effective-radius parameter used when that cluster row is approximated as a sphere.\na_eff_wN, b_eff_wN, c_eff_wN: Generated per-component semiaxis parameters used when that cluster row is approximated as an ellipsoid; the hard-sphere interaction radius is then taken from the equivalent-volume sphere.\nphi_solute: SAXS-effective solute interaction ratio used in the solvent-complement term and fixed by default so the contrast-weighted solvent subtraction can be supplied from the solution-scattering estimator rather than co-fit with solvent_scale.\nphi_int: Effective packing fraction used only inside the hard-sphere interaction model.\nsolvent_scale: Solvent attenuation/subtraction weight applied to the solvent contribution, constrained to the interval [0, 1], and fixed by default so it does not become redundant with phi_solute.\nscale: Global multiplicative scale factor applied only to the solute contribution.\noffset: Additive q-independent background term.\nlog_sigma: Natural logarithm of the Gaussian noise standard deviation used by the DREAM likelihood.\n\nFitting Guidance:\nUse phi_solute and solvent_scale as prior-informed solvent-subtraction controls. They should normally be populated from the physical-volume, SAXS-effective interaction, and attenuation estimators, then kept fixed while the fit isolates the solute scattering parameters.\n\nCluster Geometry Metadata:\nThis template supports per-cluster geometry metadata and keeps both sphere and ellipsoid approximations enabled in Prefit. Use this when you want mixed geometry selection while accepting the equivalent-sphere approximation behind the hard-sphere structure factor.\n", "capabilities": { "cluster_geometry_metadata": { "supported": true, diff --git a/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs_mix_approx.py b/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs_mix_approx.py index a1a4581..d2f7cbe 100644 --- a/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs_mix_approx.py +++ b/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs_mix_approx.py @@ -17,9 +17,9 @@ # nonspherical-body-to-polydisperse-sphere literature, not an exact # hard-ellipsoid PY closure # -# param: phi_solute,0.02,True,0.0,0.5 +# param: phi_solute,0.02,False,0.0,0.5 # param: phi_int,0.02,True,0.0,0.4 -# param: solvent_scale,1.0,True,0.0,5.0 +# param: solvent_scale,1.0,False,0.0,1.0 # param: scale,1.0,True,1e-8,1e8 # param: offset,0.0,True,-1e6,1e6 # param: log_sigma,-9.21,True,-20.0,5.0 @@ -160,6 +160,44 @@ def _full_params_to_param_dict(full_params): } +def _bounded_solvent_weight(value): + return float(np.clip(float(value), 0.0, 1.0)) + + +def _effective_structure_factor_profile( + q_values, + cluster_intensities, + effective_radii, + raw_weights, + phi_int, +): + """Return the mixture-equivalent S(q) that modulates the form + factor.""" + q_values = np.asarray(q_values, dtype=float) + effective_radii = np.asarray(effective_radii, dtype=float) + raw_weights = np.asarray(raw_weights, dtype=float) + fractions = normalize_profile_fractions(raw_weights) + numerator = np.zeros_like(q_values, dtype=float) + denominator = np.zeros_like(q_values, dtype=float) + fallback = np.zeros_like(q_values, dtype=float) + + for frac, iq_cluster, radius in zip( + fractions, cluster_intensities, effective_radii + ): + iq_cluster = np.asarray(iq_cluster, dtype=float) + sq = calc_hardsphere_sq(radius, phi_int, q_values) + numerator += frac * iq_cluster * sq + denominator += frac * iq_cluster + fallback += frac * sq + + structure_factor = fallback.copy() + valid_mask = np.abs(denominator) > 1e-12 + structure_factor[valid_mask] = ( + numerator[valid_mask] / denominator[valid_mask] + ) + return structure_factor + + def polydisperse_lma_hs_model( q_values, cluster_intensities, @@ -201,10 +239,15 @@ def polydisperse_lma_hs_model( sq = calc_hardsphere_sq(radius, phi_int, q_values) solute_sum += frac * iq_cluster * sq - model = phi_solute * solute_sum - model += solvent_scale * (1.0 - phi_solute) * solvent_intensities + solvent_weight = _bounded_solvent_weight(solvent_scale) + solute_contribution = scale * phi_solute * solute_sum + solvent_contribution = ( + solvent_weight + * (1.0 - phi_solute) + * np.asarray(solvent_intensities, dtype=float) + ) - return scale * model + offset + return solute_contribution + solvent_contribution + offset def lmfit_model_profile( @@ -232,6 +275,27 @@ def lmfit_model_profile( ) +def structure_factor_profile( + q, solvent_data, model_data, effective_radii, **params +): + """Return the mixture-equivalent structure-factor trace S(q).""" + del solvent_data + weight_keys = _weight_keys_from_params(params) + raw_weights = np.asarray([params[key] for key in weight_keys], dtype=float) + resolved_effective_radii = _resolve_effective_radii( + weight_keys, + params, + effective_radii, + ) + return _effective_structure_factor_profile( + q, + model_data, + resolved_effective_radii, + raw_weights, + params["phi_int"], + ) + + def model_poly_lma_hs(params): global q_values global theoretical_intensities @@ -253,7 +317,7 @@ def model_poly_lma_hs(params): ) phi_solute = named_params["phi_solute"] phi_int = named_params["phi_int"] - solvent_scale = named_params["solvent_scale"] + solvent_scale = _bounded_solvent_weight(named_params["solvent_scale"]) scale = named_params["scale"] offset = named_params["offset"] else: @@ -262,7 +326,7 @@ def model_poly_lma_hs(params): resolved_effective_radii = np.asarray(effective_radii, dtype=float) phi_solute = params[n_profiles] phi_int = params[n_profiles + 1] - solvent_scale = params[n_profiles + 2] + solvent_scale = _bounded_solvent_weight(params[n_profiles + 2]) scale = params[n_profiles + 3] offset = params[n_profiles + 4] diff --git a/src/saxshell/saxs/_ui_assets/saxshell_icon.svg b/src/saxshell/saxs/_ui_assets/saxshell_icon.svg new file mode 100644 index 0000000..1de3b70 --- /dev/null +++ b/src/saxshell/saxs/_ui_assets/saxshell_icon.svg @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/saxshell/saxs/beam_geometry_presets.py b/src/saxshell/saxs/beam_geometry_presets.py new file mode 100644 index 0000000..97730bb --- /dev/null +++ b/src/saxshell/saxs/beam_geometry_presets.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from pathlib import Path + +from saxshell.saxs.solution_scattering_estimator import ( + DEFAULT_BEAM_FOOTPRINT_HEIGHT_MM, + DEFAULT_BEAM_FOOTPRINT_WIDTH_MM, + DEFAULT_BEAM_PROFILE, + DEFAULT_CAPILLARY_GEOMETRY, + DEFAULT_CAPILLARY_SIZE_MM, + DEFAULT_INCIDENT_ENERGY_KEV, + BeamGeometrySettings, +) + +_PRESET_DIR_NAME = "_beam_geometry_presets" +_DEFAULT_PRESETS_FILENAME = "default_beam_geometry_presets.json" +_USER_PRESETS_FILENAME = "user_beam_geometry_presets.json" +DEFAULT_BEAM_GEOMETRY_PRESET_NAME = "NSLS-II 28-ID-1 (default)" + + +@dataclass(frozen=True, slots=True) +class BeamGeometryPreset: + name: str + beam: BeamGeometrySettings + notes: str = "" + builtin: bool = False + + def to_dict(self) -> dict[str, object]: + payload: dict[str, object] = { + "incident_energy_kev": float(self.beam.incident_energy_kev), + "capillary_size_mm": float(self.beam.capillary_size_mm), + "capillary_geometry": str(self.beam.capillary_geometry), + "beam_profile": str(self.beam.beam_profile), + "beam_footprint_width_mm": float( + self.beam.beam_footprint_width_mm + ), + "beam_footprint_height_mm": float( + self.beam.beam_footprint_height_mm + ), + } + if self.notes: + payload["notes"] = self.notes + return payload + + @classmethod + def from_dict( + cls, + name: str, + payload: dict[str, object], + *, + builtin: bool = False, + ) -> "BeamGeometryPreset": + return cls( + name=name, + beam=BeamGeometrySettings( + incident_energy_kev=_positive_float( + payload.get("incident_energy_kev"), + default=DEFAULT_INCIDENT_ENERGY_KEV, + ), + capillary_size_mm=_positive_float( + payload.get("capillary_size_mm"), + default=DEFAULT_CAPILLARY_SIZE_MM, + ), + capillary_geometry=str( + payload.get( + "capillary_geometry", + DEFAULT_CAPILLARY_GEOMETRY, + ) + or DEFAULT_CAPILLARY_GEOMETRY + ).strip() + or DEFAULT_CAPILLARY_GEOMETRY, + beam_profile=str( + payload.get("beam_profile", DEFAULT_BEAM_PROFILE) + or DEFAULT_BEAM_PROFILE + ).strip() + or DEFAULT_BEAM_PROFILE, + beam_footprint_width_mm=_positive_float( + payload.get("beam_footprint_width_mm"), + default=DEFAULT_BEAM_FOOTPRINT_WIDTH_MM, + ), + beam_footprint_height_mm=_positive_float( + payload.get("beam_footprint_height_mm"), + default=DEFAULT_BEAM_FOOTPRINT_HEIGHT_MM, + ), + ), + notes=str(payload.get("notes", "") or "").strip(), + builtin=builtin, + ) + + +def _positive_float(value: object, *, default: float) -> float: + try: + parsed = float(value) + except (TypeError, ValueError): + return float(default) + if parsed <= 0.0: + return float(default) + return parsed + + +def beam_geometry_presets_dir() -> Path: + return Path(__file__).resolve().parent / _PRESET_DIR_NAME + + +def default_beam_geometry_presets_path() -> Path: + return beam_geometry_presets_dir() / _DEFAULT_PRESETS_FILENAME + + +def beam_geometry_presets_path() -> Path: + configured = os.environ.get("SAXSHELL_BEAM_GEOMETRY_PRESETS_PATH", "") + if configured.strip(): + return Path(configured).expanduser() + return beam_geometry_presets_dir() / _USER_PRESETS_FILENAME + + +def default_beam_geometry_presets() -> dict[str, BeamGeometryPreset]: + payloads = _load_preset_payloads(default_beam_geometry_presets_path()) + return { + name: BeamGeometryPreset.from_dict(name, payload, builtin=True) + for name, payload in payloads.items() + if isinstance(payload, dict) + } + + +def load_beam_geometry_presets() -> dict[str, BeamGeometryPreset]: + presets = default_beam_geometry_presets() + for name, payload in _load_custom_preset_payloads().items(): + if not isinstance(payload, dict): + continue + try: + presets[name] = BeamGeometryPreset.from_dict(name, payload) + except (TypeError, ValueError): + continue + return presets + + +def save_custom_beam_geometry_preset(preset: BeamGeometryPreset) -> Path: + file_path = beam_geometry_presets_path() + payload = _load_custom_preset_payloads() + payload[preset.name] = preset.to_dict() + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text( + json.dumps({"presets": payload}, indent=2) + "\n", + encoding="utf-8", + ) + return file_path + + +def delete_custom_beam_geometry_preset(name: str) -> bool: + file_path = beam_geometry_presets_path() + payload = _load_custom_preset_payloads() + if name not in payload: + return False + del payload[name] + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text( + json.dumps({"presets": payload}, indent=2) + "\n", + encoding="utf-8", + ) + return True + + +def ordered_beam_geometry_preset_names( + presets: dict[str, BeamGeometryPreset], +) -> list[str]: + defaults = default_beam_geometry_presets() + ordered_names = [name for name in defaults if name in presets] + custom_names = sorted(name for name in presets if name not in defaults) + override_names = sorted( + name + for name, preset in presets.items() + if name in defaults and not preset.builtin + ) + if override_names: + ordered_names = [ + name for name in ordered_names if name not in override_names + ] + override_names + return ordered_names + custom_names + + +def _load_preset_payloads(file_path: Path) -> dict[str, object]: + if not file_path.is_file(): + return {} + try: + payload = json.loads(file_path.read_text(encoding="utf-8")) + except json.JSONDecodeError: + return {} + presets = payload.get("presets", {}) + return presets if isinstance(presets, dict) else {} + + +def _load_custom_preset_payloads() -> dict[str, object]: + return _load_preset_payloads(beam_geometry_presets_path()) + + +__all__ = [ + "DEFAULT_BEAM_GEOMETRY_PRESET_NAME", + "BeamGeometryPreset", + "beam_geometry_presets_dir", + "beam_geometry_presets_path", + "default_beam_geometry_presets", + "default_beam_geometry_presets_path", + "delete_custom_beam_geometry_preset", + "load_beam_geometry_presets", + "ordered_beam_geometry_preset_names", + "save_custom_beam_geometry_preset", +] diff --git a/src/saxshell/saxs/dream/distributions.py b/src/saxshell/saxs/dream/distributions.py index 1381d50..f486783 100644 --- a/src/saxshell/saxs/dream/distributions.py +++ b/src/saxshell/saxs/dream/distributions.py @@ -2,6 +2,7 @@ import json import math +import re from dataclasses import asdict, dataclass from pathlib import Path @@ -90,6 +91,7 @@ def build_default_parameter_map( for name, value in dict(prefit_payload.get("fit_parameters", {})).items(): meta = dict(fit_parameter_meta.get(name, {})) float_value = float(value) + distribution = _default_distribution_for_fit_parameter(str(name)) entries.append( DreamParameterEntry( structure="", @@ -98,12 +100,12 @@ def build_default_parameter_map( param=str(name), value=float_value, vary=bool(meta.get("vary", True)), - distribution="norm", + distribution=distribution, dist_params=_default_distribution_params( float_value, cv_default, eps, - )["norm"], + )[distribution], ) ) return entries @@ -220,6 +222,13 @@ def _default_distribution_params( } +def _default_distribution_for_fit_parameter(name: str) -> str: + normalized_name = str(name).strip() + if re.fullmatch(r"r_eff_w\d+", normalized_name): + return "lognorm" + return "norm" + + __all__ = [ "BASE_DISTRIBUTIONS", "DreamParameterEntry", diff --git a/src/saxshell/saxs/dream/results.py b/src/saxshell/saxs/dream/results.py index cbbff80..8f2f651 100644 --- a/src/saxshell/saxs/dream/results.py +++ b/src/saxshell/saxs/dream/results.py @@ -41,6 +41,7 @@ class DreamModelPlotData: experimental_intensities: np.ndarray model_intensities: np.ndarray solvent_contribution: np.ndarray | None + structure_factor_trace: np.ndarray | None bestfit_method: str template_name: str rmse: float @@ -170,6 +171,13 @@ def __init__( self._posterior_view_cache: dict[ tuple[object, ...], _PosteriorView ] = {} + self._summary_cache: dict[tuple[object, ...], DreamSummary] = {} + self._model_plot_cache: dict[ + tuple[object, ...], DreamModelPlotData + ] = {} + self._violin_data_cache: dict[ + tuple[object, ...], DreamViolinPlotData + ] = {} self._parameter_entry_lookup = self._build_parameter_entry_lookup() self._apply_burnin() @@ -209,8 +217,19 @@ def get_summary( credible_interval_low=credible_interval_low, credible_interval_high=credible_interval_high, ) + cache_key = ( + str(bestfit_method), + str(view.filter_mode), + round(float(view.top_percent), 6), + int(view.top_n), + round(float(view.credible_interval_low), 6), + round(float(view.credible_interval_high), 6), + ) + cached = self._summary_cache.get(cache_key) + if cached is not None: + return cached best_active = self._select_best_params(bestfit_method, view) - return DreamSummary( + summary = DreamSummary( bestfit_method=bestfit_method, bestfit_params=self.expand_params(best_active), map_params=self.expand_params(view.map_params), @@ -229,6 +248,8 @@ def get_summary( credible_interval_high=float(view.credible_interval_high), run_dir=self.run_dir, ) + self._summary_cache[cache_key] = summary + return summary def build_model_fit_data( self, @@ -240,6 +261,15 @@ def build_model_fit_data( credible_interval_low: float = 16.0, credible_interval_high: float = 84.0, ) -> DreamModelPlotData: + cache_key = ( + str(bestfit_method), + str(posterior_filter_mode), + round(float(posterior_top_percent), 6), + int(posterior_top_n), + ) + cached = self._model_plot_cache.get(cache_key) + if cached is not None: + return cached summary = self.get_summary( bestfit_method=bestfit_method, posterior_filter_mode=posterior_filter_mode, @@ -274,6 +304,11 @@ def build_model_fit_data( params=params, extra_inputs=extra_inputs, ) + structure_factor_trace = self._evaluate_structure_factor_trace( + model_module, + params=params, + extra_inputs=extra_inputs, + ) residuals = np.asarray( model_intensities - self.experimental_intensities, dtype=float, @@ -290,17 +325,20 @@ def build_model_fit_data( if total_sum_squares > 0.0 else 1.0 ) - return DreamModelPlotData( + plot_data = DreamModelPlotData( q_values=self.q_values, experimental_intensities=self.experimental_intensities, model_intensities=model_intensities, solvent_contribution=solvent_contribution, + structure_factor_trace=structure_factor_trace, bestfit_method=bestfit_method, template_name=self.template_name, rmse=rmse, mean_abs_residual=mean_abs_residual, r_squared=r_squared, ) + self._model_plot_cache[cache_key] = plot_data + return plot_data def build_violin_data( self, @@ -314,6 +352,17 @@ def build_violin_data( sample_source: str = "filtered_posterior", weight_order: str = "weight_index", ) -> DreamViolinPlotData: + cache_key = ( + str(mode), + str(posterior_filter_mode), + round(float(posterior_top_percent), 6), + int(posterior_top_n), + str(sample_source), + str(weight_order), + ) + cached = self._violin_data_cache.get(cache_key) + if cached is not None: + return cached view = self._posterior_view( filter_mode=posterior_filter_mode, top_percent=posterior_top_percent, @@ -329,7 +378,11 @@ def build_violin_data( samples = np.asarray(active_samples, dtype=float) names = list(self.active_parameter_names) else: - full_samples = self._expand_sample_matrix(active_samples) + full_samples = self._full_violin_samples( + view, + active_samples=active_samples, + sample_source=sample_source, + ) if mode == "all_parameters": samples = full_samples names = list(self.full_parameter_names) @@ -338,6 +391,19 @@ def build_violin_data( full_samples, include=lambda name: name.startswith("w"), ) + elif mode == "effective_radii_only": + names, samples = self._select_columns( + full_samples, + include=self._is_effective_radius_parameter, + ) + elif mode == "additional_parameters_only": + names, samples = self._select_columns( + full_samples, + include=lambda name: ( + not name.startswith("w") + and not self._is_effective_radius_parameter(name) + ), + ) elif mode == "fit_parameters": names, samples = self._select_columns( full_samples, @@ -348,14 +414,15 @@ def build_violin_data( "Unknown DREAM violin mode: " f"{mode}. Expected one of " "'varying_parameters', 'all_parameters', " - "'weights_only', or 'fit_parameters'." + "'weights_only', 'effective_radii_only', " + "'additional_parameters_only', or 'fit_parameters'." ) names, samples = self._ordered_violin_columns( names, samples, weight_order=weight_order, ) - return DreamViolinPlotData( + violin_data = DreamViolinPlotData( parameter_names=names, display_names=[ self._parameter_display_name(name) for name in names @@ -366,6 +433,8 @@ def build_violin_data( sample_count=int(np.asarray(samples).shape[0]), weight_order=weight_order, ) + self._violin_data_cache[cache_key] = violin_data + return violin_data def save_statistics_report( self, @@ -462,6 +531,37 @@ def _evaluate_solvent_contribution( ) return np.asarray(contribution, dtype=float) + def _evaluate_structure_factor_trace( + self, + model_module, + *, + params: dict[str, float], + extra_inputs: list[np.ndarray], + ) -> np.ndarray | None: + structure_factor_function = getattr( + model_module, + "structure_factor_profile", + None, + ) + if structure_factor_function is None: + return None + try: + structure_factor = structure_factor_function( + self.q_values, + self.solvent_intensities, + self.theoretical_intensities, + *extra_inputs, + **params, + ) + except Exception: + return None + structure_factor_array = np.asarray(structure_factor, dtype=float) + if structure_factor_array.shape != self.q_values.shape: + return None + if not np.all(np.isfinite(structure_factor_array)): + return None + return structure_factor_array + def _select_best_params( self, bestfit_method: str, @@ -621,6 +721,21 @@ def _active_samples_for_source( "'map_chain_only'." ) + def _full_violin_samples( + self, + view: _PosteriorView, + *, + active_samples: np.ndarray, + sample_source: str, + ) -> np.ndarray: + if sample_source == "filtered_posterior": + flat_mask = np.asarray(view.sample_mask, dtype=bool).reshape(-1) + return np.asarray( + self._expanded_flat_samples()[flat_mask], + dtype=float, + ) + return self._expand_sample_matrix(active_samples) + def _posterior_mask( self, *, @@ -762,6 +877,17 @@ def _build_parameter_entry_lookup(self) -> dict[str, dict[str, object]]: lookup[name] = dict(entry) return lookup + @staticmethod + def _is_effective_radius_parameter(parameter_name: str) -> bool: + name = str(parameter_name).strip() + return bool( + name == "eff_r" + or name.startswith("r_eff_") + or name.startswith("a_eff_") + or name.startswith("b_eff_") + or name.startswith("c_eff_") + ) + @staticmethod def _normalize_sampled_params( raw_sampled_params: np.ndarray, diff --git a/src/saxshell/saxs/dream/runtime.py b/src/saxshell/saxs/dream/runtime.py index 58fb8e3..a48cafa 100644 --- a/src/saxshell/saxs/dream/runtime.py +++ b/src/saxshell/saxs/dream/runtime.py @@ -27,6 +27,7 @@ ) from saxshell.saxs.prefit import SAXSPrefitWorkflow from saxshell.saxs.project_manager import ( + ProjectSettings, SAXSProjectManager, build_project_paths, ) @@ -67,6 +68,23 @@ def __init__( self.settings_presets_dir = self.paths.dream_dir / "settings_presets" self.settings_presets_dir.mkdir(parents=True, exist_ok=True) + def apply_project_settings( + self, + settings: ProjectSettings, + ) -> None: + incoming_settings = ProjectSettings.from_dict(settings.to_dict()) + if incoming_settings.resolved_project_dir != self.paths.project_dir: + raise ValueError( + "Cannot apply project settings from a different SAXS project." + ) + self.settings = incoming_settings + self.paths = build_project_paths(self.settings.project_dir) + self.dream_settings_path = self.paths.dream_dir / "pd_settings.json" + self.parameter_map_path = self.paths.dream_dir / "pd_param_map.json" + self.settings_presets_dir = self.paths.dream_dir / "settings_presets" + self.settings_presets_dir.mkdir(parents=True, exist_ok=True) + self.prefit_workflow.apply_project_settings(incoming_settings) + def load_settings(self) -> DreamRunSettings: settings = load_dream_settings(self.dream_settings_path) if settings.model_name is None: diff --git a/src/saxshell/saxs/model_report.py b/src/saxshell/saxs/model_report.py new file mode 100644 index 0000000..6cadd8a --- /dev/null +++ b/src/saxshell/saxs/model_report.py @@ -0,0 +1,3032 @@ +from __future__ import annotations + +import json +import textwrap +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from tempfile import TemporaryDirectory + +import numpy as np +from matplotlib import colormaps, rc_context +from matplotlib.colors import to_hex +from matplotlib.figure import Figure + +from saxshell.saxs.dream import ( + DreamModelPlotData, + DreamParameterEntry, + DreamRunSettings, + DreamSummary, + DreamViolinPlotData, +) +from saxshell.saxs.prefit import PrefitEvaluation, PrefitParameterEntry +from saxshell.saxs.prefit.cluster_geometry import ClusterGeometryMetadataRow +from saxshell.saxs.project_manager import PowerPointExportSettings +from saxshell.saxs.project_manager.prior_plot import plot_md_prior_histogram +from saxshell.saxs.solution_scattering_estimator import ( + SolutionScatteringEstimate, +) + +_SLIDE_WIDTH_INCHES = 13.333 +_SLIDE_HEIGHT_INCHES = 7.5 +_SLIDE_LEFT_INCHES = 0.45 +_SLIDE_TOP_INCHES = 1.1 +_SLIDE_CONTENT_WIDTH_INCHES = 12.43 +_SLIDE_CONTENT_HEIGHT_INCHES = 5.85 +_TABLE_TOP_INCHES = 1.42 +_TABLE_WIDTH_INCHES = 12.43 +_TABLE_HEIGHT_INCHES = 5.32 +_TABLE_NOTE_TOP_INCHES = 6.88 +_THICK_RULE_HEIGHT_INCHES = 0.06 +_TEXT_WRAP_FULL = 98 +_TEXT_WRAP_HALF = 43 +_TEXT_WRAP_SIDE = 37 +_TEXT_LINES_FULL = 24 +_TEXT_LINES_HALF = 23 +_TEXT_LINES_SIDE = 20 +_PREFIT_WINDOW_EXPERIMENTAL_COLOR = to_hex("black", keep_alpha=False) +_PREFIT_WINDOW_MODEL_COLOR = to_hex("tab:red", keep_alpha=False) +_PREFIT_WINDOW_SOLVENT_COLOR = to_hex("green", keep_alpha=False) +_PREFIT_WINDOW_RESIDUAL_COLOR = to_hex("tab:blue", keep_alpha=False) +_DREAM_OUTPUT_EXPERIMENTAL_COLOR = to_hex("black", keep_alpha=False) +_DREAM_OUTPUT_MODEL_COLOR = to_hex("tab:red", keep_alpha=False) +_DREAM_OUTPUT_SOLVENT_COLOR = to_hex("green", keep_alpha=False) +_DREAM_OUTPUT_STRUCTURE_FACTOR_COLOR = to_hex( + "tab:purple", + keep_alpha=False, +) + +ReportProgressCallback = Callable[[int, int, str], None] + + +@dataclass(frozen=True, slots=True) +class ReportComponentSeries: + label: str + q_values: np.ndarray + intensities: np.ndarray + color: str + + +@dataclass(frozen=True, slots=True) +class ReportComponentPlotData: + title: str + selected_q_min: float | None + selected_q_max: float | None + use_experimental_grid: bool + log_x: bool + log_y: bool + experimental_q_values: np.ndarray | None + experimental_intensities: np.ndarray | None + solvent_q_values: np.ndarray | None + solvent_intensities: np.ndarray | None + component_series: tuple[ReportComponentSeries, ...] + + +@dataclass(frozen=True, slots=True) +class PriorHistogramRequest: + title: str + json_path: Path + mode: str + cmap: str + secondary_element: str | None = None + + +@dataclass(frozen=True, slots=True) +class DreamFilterReportView: + title: str + description: str + filter_mode: str + is_active: bool + summary: DreamSummary + model_plot: DreamModelPlotData + violin_plot: DreamViolinPlotData + violin_payload: dict[str, object] + weights_violin_payload: dict[str, object] + effective_radii_violin_payload: dict[str, object] + + +@dataclass(frozen=True, slots=True) +class DreamModelReportContext: + output_path: Path + asset_dir: Path + project_name: str + project_dir: Path + generated_at: datetime + powerpoint_settings: PowerPointExportSettings + user_q_range_text: str + supported_q_range_text: str | None + q_sampling_text: str + template_name: str + template_display_name: str + template_module_path: Path | None + model_equation_text: str | None + model_context_lines: tuple[str, ...] + model_definition_lines: tuple[str, ...] + model_reference_lines: tuple[str, ...] + prior_histograms: tuple[PriorHistogramRequest, ...] + component_plot_without_solvent: ReportComponentPlotData | None + component_plot_with_solvent: ReportComponentPlotData | None + prefit_evaluation: PrefitEvaluation | None + prefit_parameter_entries: tuple[PrefitParameterEntry, ...] + prefit_statistics: dict[str, object] + cluster_geometry_rows: tuple[ClusterGeometryMetadataRow, ...] + solution_scattering_estimate: SolutionScatteringEstimate | None + dream_settings: DreamRunSettings + dream_summary: DreamSummary + dream_model_plot: DreamModelPlotData + dream_violin_plot: DreamViolinPlotData + dream_violin_payload: dict[str, object] + dream_parameter_map_entries: tuple[DreamParameterEntry, ...] + dream_filter_assessments: tuple[dict[str, object], ...] + dream_filter_views: tuple[DreamFilterReportView, ...] + output_summary_lines: tuple[str, ...] + directory_lines: tuple[str, ...] + + +@dataclass(frozen=True, slots=True) +class ModelReportExportResult: + report_path: Path + manifest_path: Path | None + figure_paths: tuple[Path, ...] + + +class _ReportProgressTracker: + def __init__( + self, + total_steps: int, + callback: ReportProgressCallback | None, + ) -> None: + self.total_steps = max(int(total_steps), 1) + self._callback = callback + self._processed = 0 + self.emit("Generating DREAM model report PowerPoint. Please wait...") + + def emit(self, message: str) -> None: + if self._callback is None: + return + self._callback(self._processed, self.total_steps, str(message)) + + def advance(self, message: str) -> None: + self._processed = min(self._processed + 1, self.total_steps) + self.emit(message) + + +def export_dream_model_report_pptx( + context: DreamModelReportContext, + *, + progress_callback: ReportProgressCallback | None = None, +) -> ModelReportExportResult: + Presentation, Inches, Pt = _load_pptx_api() + from pptx.enum.shapes import MSO_AUTO_SHAPE_TYPE + from pptx.enum.text import MSO_VERTICAL_ANCHOR, PP_ALIGN + + export_settings = PowerPointExportSettings.from_dict( + context.powerpoint_settings.to_dict() + ) + context.output_path.parent.mkdir(parents=True, exist_ok=True) + temporary_figure_dir: TemporaryDirectory[str] | None = None + if ( + export_settings.generate_manifest + or export_settings.export_figure_assets + ): + context.asset_dir.mkdir(parents=True, exist_ok=True) + if export_settings.export_figure_assets: + figure_dir = context.asset_dir / "figures" + figure_dir.mkdir(parents=True, exist_ok=True) + else: + temporary_figure_dir = TemporaryDirectory() + figure_dir = Path(temporary_figure_dir.name) + + prefit_parameter_rows = _prefit_parameter_rows( + context.prefit_parameter_entries + ) + geometry_parameter_rows = _cluster_geometry_rows( + context.cluster_geometry_rows + ) + dream_prior_rows = _dream_prior_rows(context.dream_parameter_map_entries) + prior_histogram_pages = _chunked(context.prior_histograms, 4) + prefit_table_pages = _table_row_chunks(prefit_parameter_rows, 12) + geometry_table_pages = _table_row_chunks(geometry_parameter_rows, 11) + dream_prior_table_pages = _table_row_chunks(dream_prior_rows, 11) + prefit_summary_pages = ( + _paginate_text_lines( + _prefit_summary_lines( + context.prefit_evaluation, + context.prefit_statistics, + ), + max_lines=_TEXT_LINES_SIDE, + wrap_at=_TEXT_WRAP_SIDE, + ) + if context.prefit_evaluation is not None + else [] + ) + estimator_pages = ( + _paginate_text_lines( + _solution_estimate_lines(context.solution_scattering_estimate), + max_lines=_TEXT_LINES_FULL, + wrap_at=_TEXT_WRAP_FULL, + ) + if context.solution_scattering_estimate is not None + else [] + ) + dream_settings_pages = _paginate_text_lines( + _dream_settings_lines( + context.dream_settings, + context.dream_summary, + ), + max_lines=_TEXT_LINES_HALF, + wrap_at=_TEXT_WRAP_HALF, + ) + dream_assessment_pages = _paginate_text_lines( + _dream_assessment_lines( + context.dream_filter_assessments, + context.dream_settings, + ), + max_lines=_TEXT_LINES_HALF, + wrap_at=_TEXT_WRAP_HALF, + ) + dream_output_pages = _paginate_text_lines( + _dream_output_lines( + context.dream_settings, + context.dream_summary, + context.dream_model_plot, + ), + max_lines=_TEXT_LINES_SIDE, + wrap_at=_TEXT_WRAP_SIDE, + ) + report_summary_pages = _paginate_text_lines( + list(context.output_summary_lines), + max_lines=_TEXT_LINES_HALF, + wrap_at=_TEXT_WRAP_HALF, + ) + directory_pages = _paginate_text_lines( + list(context.directory_lines), + max_lines=_TEXT_LINES_HALF, + wrap_at=_TEXT_WRAP_HALF, + ) + model_detail_lines: list[str] = [] + if context.model_equation_text: + model_detail_lines.extend( + [ + "Model equation:", + context.model_equation_text, + ] + ) + if context.model_definition_lines: + if model_detail_lines: + model_detail_lines.append("") + model_detail_lines.append("Term definitions:") + model_detail_lines.extend(context.model_definition_lines) + if context.model_reference_lines: + if model_detail_lines: + model_detail_lines.append("") + model_detail_lines.append("References:") + model_detail_lines.extend(context.model_reference_lines) + model_context_pages = _paginate_text_lines( + list(context.model_context_lines), + max_lines=_TEXT_LINES_HALF, + wrap_at=_TEXT_WRAP_HALF, + ) + model_detail_pages = _paginate_text_lines( + model_detail_lines, + max_lines=_TEXT_LINES_HALF, + wrap_at=_TEXT_WRAP_HALF, + ) + prefit_has_solvent_trace = ( + context.prefit_evaluation is not None + and _has_prefit_solvent_trace(context.prefit_evaluation) + ) + + total_slides = 1 + if context.model_context_lines or model_detail_lines: + total_slides += max(len(model_context_pages), len(model_detail_pages)) + if export_settings.include_prior_histograms and prior_histogram_pages: + total_slides += len(prior_histogram_pages) + if export_settings.include_initial_traces and ( + context.component_plot_without_solvent is not None + or context.component_plot_with_solvent is not None + ): + total_slides += 1 + if ( + export_settings.include_prefit_model + and context.prefit_evaluation is not None + ): + total_slides += len(prefit_summary_pages) + total_slides += int(prefit_has_solvent_trace) + if export_settings.include_prefit_parameters: + total_slides += len(prefit_table_pages) + if ( + export_settings.include_geometry_table + and context.cluster_geometry_rows + ): + total_slides += len(geometry_table_pages) + if ( + export_settings.include_estimator_metrics + and context.solution_scattering_estimate is not None + ): + total_slides += len(estimator_pages) + if export_settings.include_dream_settings: + total_slides += max( + len(dream_settings_pages), + len(dream_assessment_pages), + ) + if export_settings.include_dream_prior_table: + total_slides += len(dream_prior_table_pages) + if export_settings.include_dream_output_model: + total_slides += len(dream_output_pages) + if ( + export_settings.include_posterior_comparisons + and context.dream_filter_views + ): + total_slides += 4 + if ( + export_settings.include_output_summary + and export_settings.include_directory_summary + ): + total_slides += max(len(report_summary_pages), len(directory_pages)) + elif export_settings.include_output_summary: + total_slides += len(report_summary_pages) + elif export_settings.include_directory_summary: + total_slides += len(directory_pages) + + progress = _ReportProgressTracker( + _count_report_figures(context) + + total_slides + + 1 + + int(export_settings.generate_manifest), + progress_callback, + ) + + figure_paths: list[Path] = [] + rendered_figures: dict[str, Path] = {} + + if export_settings.include_prior_histograms: + for index, request in enumerate(context.prior_histograms, start=1): + figure_path = figure_dir / f"{_slugify(request.title)}.png" + rendered_figures[request.title] = _render_prior_histogram( + request, + figure_path, + settings=export_settings, + ) + if export_settings.export_figure_assets: + figure_paths.append(figure_path) + progress.advance( + "Rendered prior histogram " + f"{index}/{len(context.prior_histograms)}." + ) + + if ( + export_settings.include_initial_traces + and context.component_plot_without_solvent is not None + ): + figure_path = figure_dir / "initial_traces_no_solvent.png" + rendered_figures["initial_traces_no_solvent"] = _render_component_plot( + context.component_plot_without_solvent, + figure_path, + settings=export_settings, + ) + if export_settings.export_figure_assets: + figure_paths.append(figure_path) + progress.advance("Rendered initial SAXS traces without solvent.") + + if ( + export_settings.include_initial_traces + and context.component_plot_with_solvent is not None + ): + figure_path = figure_dir / "initial_traces_with_solvent.png" + rendered_figures["initial_traces_with_solvent"] = ( + _render_component_plot( + context.component_plot_with_solvent, + figure_path, + settings=export_settings, + ) + ) + if export_settings.export_figure_assets: + figure_paths.append(figure_path) + progress.advance("Rendered initial SAXS traces with solvent.") + + if ( + export_settings.include_prefit_model + and context.prefit_evaluation is not None + ): + figure_path = figure_dir / "prefit_model_without_solvent.png" + rendered_figures["prefit_model_without_solvent"] = _render_prefit_plot( + context.prefit_evaluation, + context.prefit_statistics, + figure_path, + settings=export_settings, + include_solvent=False, + ) + if export_settings.export_figure_assets: + figure_paths.append(figure_path) + progress.advance("Rendered prefit model plot without solvent trace.") + + if prefit_has_solvent_trace: + figure_path = figure_dir / "prefit_model_with_solvent.png" + rendered_figures["prefit_model_with_solvent"] = ( + _render_prefit_plot( + context.prefit_evaluation, + context.prefit_statistics, + figure_path, + settings=export_settings, + include_solvent=True, + ) + ) + if export_settings.export_figure_assets: + figure_paths.append(figure_path) + progress.advance("Rendered prefit model plot with solvent trace.") + + if export_settings.include_dream_output_model: + figure_path = figure_dir / "dream_model.png" + rendered_figures["dream_model"] = _render_dream_model_plot( + context.dream_model_plot, + figure_path, + settings=export_settings, + ) + if export_settings.export_figure_assets: + figure_paths.append(figure_path) + progress.advance("Rendered DREAM output model plot.") + + if ( + export_settings.include_posterior_comparisons + and context.dream_filter_views + ): + figure_path = figure_dir / "dream_filter_violin_comparison.png" + rendered_figures["dream_filter_violin_comparison"] = ( + _render_filter_violin_comparison( + context.dream_filter_views, + figure_path, + settings=export_settings, + ) + ) + if export_settings.export_figure_assets: + figure_paths.append(figure_path) + progress.advance("Rendered posterior violin comparison plot.") + + figure_path = figure_dir / "dream_filter_violin_comparison_weights.png" + rendered_figures["dream_filter_violin_comparison_weights"] = ( + _render_filter_violin_comparison( + context.dream_filter_views, + figure_path, + settings=export_settings, + payload_variant="weights", + ) + ) + if export_settings.export_figure_assets: + figure_paths.append(figure_path) + progress.advance( + "Rendered posterior violin comparison plot for weight parameters." + ) + + figure_path = ( + figure_dir / "dream_filter_violin_comparison_effective_radii.png" + ) + rendered_figures["dream_filter_violin_comparison_effective_radii"] = ( + _render_filter_violin_comparison( + context.dream_filter_views, + figure_path, + settings=export_settings, + payload_variant="effective_radii", + ) + ) + if export_settings.export_figure_assets: + figure_paths.append(figure_path) + progress.advance( + "Rendered posterior violin comparison plot for effective radii." + ) + + figure_path = figure_dir / "dream_filter_fit_comparison.png" + rendered_figures["dream_filter_fit_comparison"] = ( + _render_filter_fit_comparison( + context.dream_filter_views, + figure_path, + settings=export_settings, + ) + ) + if export_settings.export_figure_assets: + figure_paths.append(figure_path) + progress.advance("Rendered posterior fit comparison plot.") + + manifest_path: Path | None = None + if export_settings.generate_manifest: + context.asset_dir.mkdir(parents=True, exist_ok=True) + manifest_path = context.asset_dir / "report_manifest.json" + manifest_path.write_text( + json.dumps( + _manifest_payload( + context, + figure_paths=figure_paths, + ), + indent=2, + ) + + "\n", + encoding="utf-8", + ) + progress.advance("Wrote report manifest.") + + presentation = Presentation() + presentation.slide_width = Inches(_SLIDE_WIDTH_INCHES) + presentation.slide_height = Inches(_SLIDE_HEIGHT_INCHES) + blank_layout = presentation.slide_layouts[6] + slide_index = 0 + + def first_run(paragraph): + return paragraph.runs[0] if paragraph.runs else paragraph.add_run() + + def apply_run_style( + run, + *, + font_size: float, + color: str | None = None, + bold: bool = False, + ) -> None: + run.font.name = export_settings.font_family + run.font.size = Pt(font_size) + run.font.bold = bold + run.font.color.rgb = _rgb_color(color or export_settings.text_color) + + def add_title(slide, title: str, subtitle: str | None = None) -> None: + title_box = slide.shapes.add_textbox( + Inches(_SLIDE_LEFT_INCHES), + Inches(0.25), + Inches(_SLIDE_CONTENT_WIDTH_INCHES), + Inches(0.52), + ) + title_frame = title_box.text_frame + title_frame.clear() + title_frame.word_wrap = True + title_paragraph = title_frame.paragraphs[0] + title_paragraph.text = title + title_paragraph.space_after = Pt(0) + apply_run_style(first_run(title_paragraph), font_size=24, bold=True) + if subtitle: + subtitle_box = slide.shapes.add_textbox( + Inches(_SLIDE_LEFT_INCHES), + Inches(0.72), + Inches(_SLIDE_CONTENT_WIDTH_INCHES), + Inches(0.3), + ) + subtitle_frame = subtitle_box.text_frame + subtitle_frame.clear() + subtitle_frame.word_wrap = True + subtitle_paragraph = subtitle_frame.paragraphs[0] + subtitle_paragraph.text = subtitle + subtitle_paragraph.space_after = Pt(0) + apply_run_style( + first_run(subtitle_paragraph), + font_size=11, + color=export_settings.text_color, + ) + + def add_text_block( + slide, + *, + left: float, + top: float, + width: float, + height: float, + lines: list[str], + font_size: float = 13, + bold_first: bool = False, + align=PP_ALIGN.LEFT, + ) -> None: + box = slide.shapes.add_textbox( + Inches(left), + Inches(top), + Inches(width), + Inches(height), + ) + frame = box.text_frame + frame.clear() + frame.word_wrap = True + frame.margin_left = 0 + frame.margin_right = 0 + frame.margin_top = 0 + frame.margin_bottom = 0 + if not lines: + lines = [""] + for index, line in enumerate(lines): + paragraph = ( + frame.paragraphs[0] if index == 0 else frame.add_paragraph() + ) + paragraph.text = str(line) + paragraph.alignment = align + paragraph.space_after = Pt(0) + paragraph.space_before = Pt(0) + apply_run_style( + first_run(paragraph), + font_size=font_size, + bold=bool(bold_first and index == 0), + ) + + def add_picture( + slide, + image_path: Path, + *, + left: float, + top: float, + width: float, + height: float, + ) -> None: + fitted_left, fitted_top, fitted_width, fitted_height = ( + _fit_image_in_box( + image_path, + left=left, + top=top, + max_width=width, + max_height=height, + ) + ) + slide.shapes.add_picture( + str(image_path), + Inches(fitted_left), + Inches(fitted_top), + width=Inches(fitted_width), + height=Inches(fitted_height), + ) + + def add_table_header_rule( + slide, + *, + left: float, + top: float, + width: float, + ) -> None: + rule = slide.shapes.add_shape( + MSO_AUTO_SHAPE_TYPE.RECTANGLE, + Inches(left), + Inches(top), + Inches(width), + Inches(_THICK_RULE_HEIGHT_INCHES), + ) + rule.fill.solid() + rule.fill.fore_color.rgb = _rgb_color(export_settings.table_rule_color) + rule.line.fill.background() + + def style_table_cell( + cell, + *, + text: str, + font_size: float, + fill_color: str, + bold: bool = False, + align=PP_ALIGN.LEFT, + ) -> None: + cell.text = str(text) + cell.fill.solid() + cell.fill.fore_color.rgb = _rgb_color(fill_color) + cell.margin_left = Inches(0.04) + cell.margin_right = Inches(0.04) + cell.margin_top = Inches(0.02) + cell.margin_bottom = Inches(0.02) + cell.vertical_anchor = MSO_VERTICAL_ANCHOR.MIDDLE + frame = cell.text_frame + frame.word_wrap = True + paragraph = frame.paragraphs[0] + paragraph.alignment = align + paragraph.space_after = Pt(0) + paragraph.space_before = Pt(0) + apply_run_style(first_run(paragraph), font_size=font_size, bold=bold) + + def register_slide(message: str) -> None: + nonlocal slide_index + slide_index += 1 + progress.advance( + f"Built slide {slide_index}/{total_slides}: {message}" + ) + + def add_table_slides( + title: str, + columns: list[str], + rows: list[list[str]], + *, + rows_per_slide: int = 12, + subtitle: str | None = None, + note: str | None = None, + column_width_weights: Sequence[float] | None = None, + ) -> None: + if not rows: + slide = presentation.slides.add_slide(blank_layout) + add_title(slide, title, subtitle) + add_text_block( + slide, + left=_SLIDE_LEFT_INCHES, + top=1.3, + width=_SLIDE_CONTENT_WIDTH_INCHES, + height=1.0, + lines=[note or "No data are available for this section."], + font_size=14, + ) + register_slide(title) + return + chunks = _table_row_chunks(rows, rows_per_slide) + for chunk_index, chunk in enumerate(chunks, start=1): + slide = presentation.slides.add_slide(blank_layout) + effective_title = _page_title(title, chunk_index, len(chunks)) + add_title(slide, effective_title, subtitle) + table_height = ( + _TABLE_HEIGHT_INCHES - 0.28 if note else _TABLE_HEIGHT_INCHES + ) + add_table_header_rule( + slide, + left=_SLIDE_LEFT_INCHES, + top=_TABLE_TOP_INCHES, + width=_TABLE_WIDTH_INCHES, + ) + table_shape = slide.shapes.add_table( + len(chunk) + 1, + len(columns), + Inches(_SLIDE_LEFT_INCHES), + Inches(_TABLE_TOP_INCHES), + Inches(_TABLE_WIDTH_INCHES), + Inches(table_height), + ) + table = table_shape.table + column_widths = _resolve_column_widths( + columns, + chunk, + total_width=_TABLE_WIDTH_INCHES, + column_width_weights=column_width_weights, + ) + for column_index, column_width in enumerate(column_widths): + table.columns[column_index].width = Inches(column_width) + row_height = table_height / max(len(chunk) + 1, 1) + for row_index in range(len(chunk) + 1): + table.rows[row_index].height = Inches(row_height) + for column_index, column_name in enumerate(columns): + style_table_cell( + table.cell(0, column_index), + text=column_name, + font_size=11, + fill_color=export_settings.table_header_fill, + bold=True, + align=PP_ALIGN.CENTER, + ) + for row_index, row in enumerate(chunk, start=1): + row_fill = ( + export_settings.table_even_row_fill + if row_index % 2 == 1 + else export_settings.table_odd_row_fill + ) + for column_index, value in enumerate(row): + style_table_cell( + table.cell(row_index, column_index), + text=str(value), + font_size=10, + fill_color=row_fill, + ) + if note and chunk_index == len(chunks): + add_text_block( + slide, + left=_SLIDE_LEFT_INCHES, + top=_TABLE_NOTE_TOP_INCHES, + width=_SLIDE_CONTENT_WIDTH_INCHES, + height=0.24, + lines=[note], + font_size=10, + ) + register_slide(effective_title) + + def add_full_width_text_slides( + title: str, + subtitle: str | None, + pages: list[list[str]], + *, + font_size: float = 13, + ) -> None: + total_pages = len(pages) + for page_index, page_lines in enumerate(pages, start=1): + slide = presentation.slides.add_slide(blank_layout) + effective_title = _page_title(title, page_index, total_pages) + add_title(slide, effective_title, subtitle) + add_text_block( + slide, + left=_SLIDE_LEFT_INCHES, + top=1.25, + width=_SLIDE_CONTENT_WIDTH_INCHES, + height=_SLIDE_CONTENT_HEIGHT_INCHES, + lines=page_lines, + font_size=font_size, + ) + register_slide(effective_title) + + def add_two_column_text_slides( + title: str, + subtitle: str | None, + left_pages: list[list[str]], + right_pages: list[list[str]], + ) -> None: + total_pages = max(len(left_pages), len(right_pages)) + for page_index in range(total_pages): + slide = presentation.slides.add_slide(blank_layout) + effective_title = _page_title(title, page_index + 1, total_pages) + add_title(slide, effective_title, subtitle) + add_text_block( + slide, + left=0.55, + top=1.12, + width=5.85, + height=5.82, + lines=( + left_pages[page_index] + if page_index < len(left_pages) + else [] + ), + font_size=11.5, + ) + add_text_block( + slide, + left=6.7, + top=1.12, + width=5.78, + height=5.82, + lines=( + right_pages[page_index] + if page_index < len(right_pages) + else [] + ), + font_size=11.5, + ) + register_slide(effective_title) + + def add_picture_with_summary_slides( + title: str, + subtitle: str | None, + image_path: Path, + summary_pages: list[list[str]], + ) -> None: + total_pages = len(summary_pages) + for page_index, page_lines in enumerate(summary_pages, start=1): + slide = presentation.slides.add_slide(blank_layout) + effective_title = _page_title(title, page_index, total_pages) + add_title(slide, effective_title, subtitle) + if page_index == 1: + add_picture( + slide, + image_path, + left=0.45, + top=1.16, + width=7.55, + height=5.5, + ) + add_text_block( + slide, + left=8.2, + top=1.2, + width=4.22, + height=5.26, + lines=page_lines, + font_size=11.5, + ) + else: + add_text_block( + slide, + left=_SLIDE_LEFT_INCHES, + top=1.25, + width=_SLIDE_CONTENT_WIDTH_INCHES, + height=_SLIDE_CONTENT_HEIGHT_INCHES, + lines=page_lines, + font_size=13, + ) + register_slide(effective_title) + + def add_full_width_picture_slide( + title: str, + subtitle: str | None, + image_path: Path, + ) -> None: + slide = presentation.slides.add_slide(blank_layout) + add_title(slide, title, subtitle) + add_picture( + slide, + image_path, + left=0.45, + top=1.12, + width=12.43, + height=5.58, + ) + register_slide(title) + + slide = presentation.slides.add_slide(blank_layout) + add_title(slide, "SAXS Model Report", "DREAM fit export") + add_text_block( + slide, + left=0.6, + top=1.25, + width=_SLIDE_CONTENT_WIDTH_INCHES, + height=_SLIDE_CONTENT_HEIGHT_INCHES, + lines=_paginate_text_lines( + [ + f"Project: {context.project_name}", + f"Project directory: {context.project_dir}", + ( + "Generated: " + f"{context.generated_at.strftime('%Y-%m-%d %H:%M:%S')}" + ), + f"Template: {context.template_name}", + f"User selected q-range: {context.user_q_range_text}", + ( + "Supported component q-range: " + f"{context.supported_q_range_text or 'Unavailable'}" + ), + f"q-grid: {context.q_sampling_text}", + f"DREAM run directory: {context.dream_summary.run_dir}", + ( + "Posterior filter: " + f"{_describe_posterior_filter(context.dream_settings)}" + ), + ( + "Posterior samples kept: " + f"{context.dream_summary.posterior_sample_count}" + ), + ], + max_lines=_TEXT_LINES_FULL, + wrap_at=_TEXT_WRAP_FULL, + )[0], + font_size=15, + ) + register_slide("SAXS Model Report") + + if context.model_context_lines or model_detail_lines: + add_two_column_text_slides( + "Model Information", + (f"{context.template_display_name} " f"({context.template_name})"), + model_context_pages, + model_detail_pages, + ) + + if export_settings.include_prior_histograms: + for page_index, request_page in enumerate( + prior_histogram_pages, + start=1, + ): + slide = presentation.slides.add_slide(blank_layout) + effective_title = _page_title( + "Prior Histograms", + page_index, + len(prior_histogram_pages), + ) + add_title( + slide, + effective_title, + "Configured palettes for regular and solvent-sort views.", + ) + grid_positions = [ + (0.45, 1.08, 5.86, 2.52), + (6.97, 1.08, 5.86, 2.52), + (0.45, 4.0, 5.86, 2.52), + (6.97, 4.0, 5.86, 2.52), + ] + for request, (left, top, width, height) in zip( + request_page, + grid_positions, + strict=False, + ): + figure_path = rendered_figures.get(request.title) + if figure_path is None: + continue + add_picture( + slide, + figure_path, + left=left, + top=top, + width=width, + height=height, + ) + register_slide(effective_title) + + if export_settings.include_initial_traces and ( + context.component_plot_without_solvent is not None + or context.component_plot_with_solvent is not None + ): + slide = presentation.slides.add_slide(blank_layout) + add_title( + slide, + "Initial SAXS Traces", + "Dual-axis rescaled views of the selected q-range.", + ) + component_figure_keys = [ + key + for key in ( + "initial_traces_no_solvent", + "initial_traces_with_solvent", + ) + if key in rendered_figures + ] + if len(component_figure_keys) == 1: + add_picture( + slide, + rendered_figures[component_figure_keys[0]], + left=0.65, + top=1.2, + width=12.0, + height=5.35, + ) + else: + if "initial_traces_no_solvent" in rendered_figures: + add_picture( + slide, + rendered_figures["initial_traces_no_solvent"], + left=0.45, + top=1.2, + width=5.88, + height=5.25, + ) + if "initial_traces_with_solvent" in rendered_figures: + add_picture( + slide, + rendered_figures["initial_traces_with_solvent"], + left=6.95, + top=1.2, + width=5.88, + height=5.25, + ) + register_slide("Initial SAXS Traces") + + if ( + export_settings.include_prefit_model + and context.prefit_evaluation is not None + and "prefit_model_without_solvent" in rendered_figures + ): + add_picture_with_summary_slides( + "Prefit Model", + "Prefit window default view without the solvent trace.", + rendered_figures["prefit_model_without_solvent"], + prefit_summary_pages, + ) + if "prefit_model_with_solvent" in rendered_figures: + add_full_width_picture_slide( + "Prefit Model With Solvent", + "Same prefit fit with the solvent contribution trace enabled.", + rendered_figures["prefit_model_with_solvent"], + ) + + if export_settings.include_prefit_parameters: + add_table_slides( + "Prefit Parameters", + [ + "Parameter", + "Category", + "Value", + "Vary", + "Min", + "Max", + "Structure", + "Motif", + ], + prefit_parameter_rows, + rows_per_slide=12, + column_width_weights=[ + 1.65, + 1.2, + 0.8, + 0.72, + 0.82, + 0.82, + 1.05, + 1.05, + ], + ) + + if ( + export_settings.include_geometry_table + and context.cluster_geometry_rows + ): + add_table_slides( + "Computed Geometry Parameters", + [ + "Cluster", + "Mapped Parameter", + "Approx.", + "Eff. Radius", + "Rg", + "Max Radius", + "Anisotropy", + "Axes (a/b/c)", + ], + geometry_parameter_rows, + rows_per_slide=11, + note=( + "Geometry metrics come from the prefit estimator table " + "saved with the active project." + ), + column_width_weights=[1.0, 1.6, 0.8, 0.9, 0.8, 0.95, 0.95, 1.7], + ) + + if ( + export_settings.include_estimator_metrics + and context.solution_scattering_estimate is not None + ): + add_full_width_text_slides( + "Estimator Metrics", + "Volume fraction, attenuation, fluorescence, and number density.", + estimator_pages, + font_size=12.5, + ) + + if export_settings.include_dream_settings: + add_two_column_text_slides( + "DREAM Settings", + "Sampler settings, posterior filtering, and active DREAM summary.", + dream_settings_pages, + dream_assessment_pages, + ) + + if export_settings.include_dream_prior_table: + add_table_slides( + "DREAM Prior Distributions", + [ + "Parameter", + "Type", + "Structure", + "Motif", + "Value", + "Vary", + "Distribution", + "Distribution Params", + ], + dream_prior_rows, + rows_per_slide=11, + column_width_weights=[1.2, 1.0, 0.9, 0.9, 0.78, 0.72, 1.08, 2.3], + ) + + if ( + export_settings.include_dream_output_model + and "dream_model" in rendered_figures + ): + add_picture_with_summary_slides( + "DREAM Output Model", + "Best-fit model and posterior summary statistics.", + rendered_figures["dream_model"], + dream_output_pages, + ) + + if ( + export_settings.include_posterior_comparisons + and "dream_filter_violin_comparison" in rendered_figures + ): + slide = presentation.slides.add_slide(blank_layout) + add_title( + slide, + "Posterior Violin Comparison", + "Default posterior filters with the active selection labeled.", + ) + add_picture( + slide, + rendered_figures["dream_filter_violin_comparison"], + left=0.45, + top=1.12, + width=12.43, + height=5.58, + ) + register_slide("Posterior Violin Comparison") + + if ( + export_settings.include_posterior_comparisons + and "dream_filter_violin_comparison_weights" in rendered_figures + ): + slide = presentation.slides.add_slide(blank_layout) + add_title( + slide, + "Posterior Violin Comparison - Weights", + "Weight parameters (w##) shown on a dedicated y-axis scale.", + ) + add_picture( + slide, + rendered_figures["dream_filter_violin_comparison_weights"], + left=0.45, + top=1.12, + width=12.43, + height=5.58, + ) + register_slide("Posterior Violin Comparison - Weights") + + if ( + export_settings.include_posterior_comparisons + and "dream_filter_violin_comparison_effective_radii" + in rendered_figures + ): + slide = presentation.slides.add_slide(blank_layout) + add_title( + slide, + "Posterior Violin Comparison - Effective Radii", + "Effective-radius parameters shown on a dedicated y-axis scale.", + ) + add_picture( + slide, + rendered_figures["dream_filter_violin_comparison_effective_radii"], + left=0.45, + top=1.12, + width=12.43, + height=5.58, + ) + register_slide("Posterior Violin Comparison - Effective Radii") + + if ( + export_settings.include_posterior_comparisons + and "dream_filter_fit_comparison" in rendered_figures + ): + slide = presentation.slides.add_slide(blank_layout) + add_title( + slide, + "Filter Fit Comparison", + "Corresponding fits and metrics for each posterior filter view.", + ) + add_picture( + slide, + rendered_figures["dream_filter_fit_comparison"], + left=0.45, + top=1.12, + width=12.43, + height=5.58, + ) + register_slide("Filter Fit Comparison") + + if ( + export_settings.include_output_summary + and export_settings.include_directory_summary + ): + add_two_column_text_slides( + "Output Summary", + "Where to find the exported report and related figure data.", + report_summary_pages, + directory_pages, + ) + elif export_settings.include_output_summary: + add_full_width_text_slides( + "Output Summary", + "Summary information for the exported report.", + report_summary_pages, + ) + elif export_settings.include_directory_summary: + add_full_width_text_slides( + "Output Directories", + "Project and report paths related to this export.", + directory_pages, + ) + + presentation.save(str(context.output_path)) + progress.advance("Saved PowerPoint report.") + result = ModelReportExportResult( + report_path=context.output_path, + manifest_path=manifest_path, + figure_paths=tuple(figure_paths), + ) + if temporary_figure_dir is not None: + temporary_figure_dir.cleanup() + return result + + +def _load_pptx_api(): + try: + from pptx import Presentation + from pptx.util import Inches, Pt + except ImportError as exc: + raise RuntimeError( + "PowerPoint export requires the optional dependency " + "`python-pptx`. Install it and retry." + ) from exc + return Presentation, Inches, Pt + + +def _count_report_figures(context: DreamModelReportContext) -> int: + settings = PowerPointExportSettings.from_dict( + context.powerpoint_settings.to_dict() + ) + return ( + ( + len(context.prior_histograms) + if settings.include_prior_histograms + else 0 + ) + + ( + int(context.component_plot_without_solvent is not None) + + int(context.component_plot_with_solvent is not None) + if settings.include_initial_traces + else 0 + ) + + ( + int(context.prefit_evaluation is not None) + + int( + context.prefit_evaluation is not None + and _has_prefit_solvent_trace(context.prefit_evaluation) + ) + if settings.include_prefit_model + else 0 + ) + + int(settings.include_dream_output_model) + + ( + 4 + if settings.include_posterior_comparisons + and context.dream_filter_views + else 0 + ) + ) + + +def _chunked(values: Sequence[object], size: int) -> list[list[object]]: + if size <= 0: + raise ValueError("chunk size must be positive") + if not values: + return [] + return [ + list(values[index : index + size]) + for index in range(0, len(values), size) + ] + + +def _table_row_chunks( + rows: Sequence[list[str]], + rows_per_slide: int, +) -> list[list[list[str]]]: + if rows_per_slide <= 0: + raise ValueError("rows_per_slide must be positive") + if not rows: + return [[]] + return [ + list(rows[index : index + rows_per_slide]) + for index in range(0, len(rows), rows_per_slide) + ] + + +def _paginate_text_lines( + lines: Sequence[str], + *, + max_lines: int, + wrap_at: int, +) -> list[list[str]]: + if max_lines <= 0: + raise ValueError("max_lines must be positive") + wrapped_lines: list[str] = [] + for raw_line in lines: + line = str(raw_line).strip() + if not line: + wrapped_lines.append("") + continue + wrapped_lines.extend( + textwrap.wrap( + line, + width=max(int(wrap_at), 1), + break_long_words=True, + break_on_hyphens=False, + ) + or [line] + ) + if not wrapped_lines: + return [[]] + return [ + wrapped_lines[index : index + max_lines] + for index in range(0, len(wrapped_lines), max_lines) + ] + + +def _page_title(title: str, page_index: int, total_pages: int) -> str: + if total_pages <= 1: + return title + return f"{title} ({page_index}/{total_pages})" + + +def _resolve_column_widths( + columns: Sequence[str], + rows: Sequence[Sequence[str]], + *, + total_width: float, + column_width_weights: Sequence[float] | None = None, +) -> list[float]: + if not columns: + return [] + if column_width_weights is not None: + weights = [max(float(weight), 0.2) for weight in column_width_weights] + else: + weights = [] + for column_index, column_name in enumerate(columns): + max_length = len(str(column_name)) + for row in rows: + if column_index >= len(row): + continue + max_length = max(max_length, len(str(row[column_index]))) + weights.append(max(0.8, min(float(max_length) ** 0.78, 3.2))) + if len(weights) != len(columns): + raise ValueError("column width weights must match the column count") + scale = total_width / sum(weights) + widths = [weight * scale for weight in weights] + widths[-1] += total_width - sum(widths) + return widths + + +def _fit_image_in_box( + image_path: Path, + *, + left: float, + top: float, + max_width: float, + max_height: float, +) -> tuple[float, float, float, float]: + try: + from PIL import Image + except ImportError: + return left, top, max_width, max_height + + with Image.open(image_path) as image: + width_px, height_px = image.size + if width_px <= 0 or height_px <= 0: + return left, top, max_width, max_height + image_ratio = width_px / height_px + box_ratio = max_width / max_height if max_height > 0 else image_ratio + if image_ratio >= box_ratio: + fitted_width = max_width + fitted_height = fitted_width / image_ratio + else: + fitted_height = max_height + fitted_width = fitted_height * image_ratio + fitted_left = left + (max_width - fitted_width) / 2.0 + fitted_top = top + (max_height - fitted_height) / 2.0 + return fitted_left, fitted_top, fitted_width, fitted_height + + +def _report_figure_context( + settings: PowerPointExportSettings, + *, + compact: bool = False, +): + return rc_context( + { + "font.family": settings.font_family, + "font.sans-serif": [ + settings.font_family, + "DejaVu Sans", + "Liberation Sans", + "sans-serif", + ], + "axes.titlesize": 10 if compact else 12, + "axes.labelsize": 9 if compact else 10.5, + "xtick.labelsize": 8 if compact else 9, + "ytick.labelsize": 8 if compact else 9, + "legend.fontsize": 7 if compact else 8, + "figure.facecolor": "white", + "axes.facecolor": "white", + } + ) + + +def _save_figure(fig: Figure, output_path: Path) -> Path: + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=240, bbox_inches="tight") + return output_path + + +def _render_prior_histogram( + request: PriorHistogramRequest, + output_path: Path, + *, + settings: PowerPointExportSettings, +) -> Path: + with _report_figure_context(settings): + figure = Figure(figsize=(6.8, 3.4)) + axis = figure.add_subplot(111) + try: + plot_md_prior_histogram( + request.json_path, + mode=request.mode, + secondary_element=request.secondary_element, + cmap=request.cmap, + ax=axis, + ) + except Exception as exc: + axis.text( + 0.5, + 0.5, + f"Unable to render {request.title}.\n{exc}", + ha="center", + va="center", + ) + axis.set_axis_off() + figure.tight_layout() + return _save_figure(figure, output_path) + + +def _render_component_plot( + plot_data: ReportComponentPlotData, + output_path: Path, + *, + settings: PowerPointExportSettings, +) -> Path: + with _report_figure_context(settings): + figure = Figure(figsize=(7.2, 4.4)) + experimental_axis = figure.add_subplot(111) + component_axis = ( + experimental_axis.twinx() if plot_data.component_series else None + ) + legend_lines: list[object] = [] + legend_labels: list[str] = [] + + if ( + plot_data.experimental_q_values is not None + and plot_data.experimental_intensities is not None + ): + q_values = np.asarray(plot_data.experimental_q_values, dtype=float) + intensities = np.asarray( + plot_data.experimental_intensities, + dtype=float, + ) + (full_line,) = experimental_axis.plot( + q_values, + intensities, + color=settings.experimental_trace_color, + alpha=0.32, + linewidth=1.2, + label="Experimental data", + ) + legend_lines.append(full_line) + legend_labels.append("Experimental data") + selected_mask = _selected_q_mask( + q_values, + plot_data.selected_q_min, + plot_data.selected_q_max, + use_experimental_grid=plot_data.use_experimental_grid, + ) + if np.any(selected_mask) and not np.all(selected_mask): + (selected_line,) = experimental_axis.plot( + q_values[selected_mask], + intensities[selected_mask], + color=settings.experimental_trace_color, + linewidth=1.8, + label="Selected q-range", + ) + legend_lines.append(selected_line) + legend_labels.append("Selected q-range") + + if ( + plot_data.solvent_q_values is not None + and plot_data.solvent_intensities is not None + ): + q_values = np.asarray(plot_data.solvent_q_values, dtype=float) + intensities = np.asarray( + plot_data.solvent_intensities, dtype=float + ) + (solvent_line,) = experimental_axis.plot( + q_values, + intensities, + color=settings.solvent_trace_color, + alpha=0.42, + linewidth=1.2, + label="Solvent data", + ) + legend_lines.append(solvent_line) + legend_labels.append("Solvent data") + selected_mask = _selected_q_mask( + q_values, + plot_data.selected_q_min, + plot_data.selected_q_max, + use_experimental_grid=plot_data.use_experimental_grid, + ) + if np.any(selected_mask) and not np.all(selected_mask): + (selected_solvent_line,) = experimental_axis.plot( + q_values[selected_mask], + intensities[selected_mask], + color=settings.solvent_trace_color, + linewidth=1.8, + label="Selected solvent q-range", + ) + legend_lines.append(selected_solvent_line) + legend_labels.append("Selected solvent q-range") + + if plot_data.component_series: + target_axis = component_axis or experimental_axis + for series in plot_data.component_series: + (line,) = target_axis.plot( + np.asarray(series.q_values, dtype=float), + np.asarray(series.intensities, dtype=float), + color=series.color, + linewidth=1.35, + label=series.label, + ) + legend_lines.append(line) + legend_labels.append(series.label) + + _apply_axis_scale( + experimental_axis, + log_x=plot_data.log_x, + log_y=plot_data.log_y, + ) + if component_axis is not None: + _apply_axis_scale( + component_axis, + log_x=plot_data.log_x, + log_y=plot_data.log_y, + ) + component_axis.set_ylabel("Component intensity (arb. units)") + _autoscale_component_plot_to_model_range( + experimental_axis, + component_axis, + plot_data, + ) + else: + q_bounds = _component_q_bounds(plot_data) + if q_bounds is not None: + experimental_axis.set_xlim(*q_bounds) + + experimental_axis.set_xlabel("q (1/A)") + experimental_axis.set_ylabel("Intensity (arb. units)") + experimental_axis.set_title(plot_data.title, fontsize=12) + experimental_axis.tick_params(labelsize=9) + if component_axis is not None: + component_axis.tick_params(labelsize=9) + if legend_lines: + columns = max(1, int(np.ceil(len(legend_lines) / 6.0))) + experimental_axis.legend( + legend_lines, + legend_labels, + fontsize=8, + loc="upper right", + ncols=columns, + framealpha=0.92, + ) + figure.tight_layout() + return _save_figure(figure, output_path) + + +def _render_prefit_plot( + evaluation: PrefitEvaluation, + statistics: dict[str, object], + output_path: Path, + *, + settings: PowerPointExportSettings, + include_solvent: bool, +) -> Path: + with _report_figure_context(settings): + figure = Figure(figsize=(7.6, 4.95)) + has_experimental = evaluation.experimental_intensities is not None + has_residuals = evaluation.residuals is not None + if has_experimental and has_residuals: + grid = figure.add_gridspec(2, 1, height_ratios=[3, 1]) + top_axis = figure.add_subplot(grid[0, 0]) + bottom_axis = figure.add_subplot(grid[1, 0], sharex=top_axis) + else: + top_axis = figure.add_subplot(111) + bottom_axis = None + _draw_prefit_plot_axes( + top_axis, + bottom_axis, + evaluation, + include_solvent=include_solvent, + settings=settings, + ) + figure.tight_layout() + return _save_figure(figure, output_path) + + +def _render_dream_model_plot( + model_plot: DreamModelPlotData, + output_path: Path, + *, + settings: PowerPointExportSettings, +) -> Path: + with _report_figure_context(settings): + figure = Figure(figsize=(7.6, 4.95)) + grid = figure.add_gridspec(2, 1, height_ratios=[3, 1]) + top_axis = figure.add_subplot(grid[0, 0]) + bottom_axis = figure.add_subplot(grid[1, 0], sharex=top_axis) + _draw_model_fit_axis( + top_axis, + q_values=np.asarray(model_plot.q_values, dtype=float), + experimental=np.asarray( + model_plot.experimental_intensities, + dtype=float, + ), + model=np.asarray(model_plot.model_intensities, dtype=float), + solvent=( + None + if model_plot.solvent_contribution is None + else np.asarray(model_plot.solvent_contribution, dtype=float) + ), + structure_factor=( + None + if model_plot.structure_factor_trace is None + else np.asarray(model_plot.structure_factor_trace, dtype=float) + ), + title=f"DREAM output ({model_plot.bestfit_method})", + metrics_lines=[ + f"Template: {model_plot.template_name}", + f"RMSE: {model_plot.rmse:.6g}", + ("Mean |res|: " f"{model_plot.mean_abs_residual:.6g}"), + f"R^2: {model_plot.r_squared:.6g}", + ], + show_legend=True, + settings=settings, + ) + residuals = np.asarray( + model_plot.model_intensities - model_plot.experimental_intensities, + dtype=float, + ) + bottom_axis.axhline(0.0, color="#6b7280", linewidth=1.0) + bottom_axis.plot( + np.asarray(model_plot.q_values, dtype=float), + residuals, + color=settings.residual_trace_color, + linewidth=1.2, + ) + bottom_axis.set_xscale("log") + bottom_axis.set_xlabel("q (1/A)") + bottom_axis.set_ylabel("Residual") + top_axis.tick_params(labelsize=9) + bottom_axis.tick_params(labelsize=9) + figure.tight_layout() + return _save_figure(figure, output_path) + + +def _render_filter_violin_comparison( + filter_views: tuple[DreamFilterReportView, ...], + output_path: Path, + *, + settings: PowerPointExportSettings, + payload_variant: str = "default", +) -> Path: + with _report_figure_context(settings, compact=True): + rows, columns = _comparison_grid_shape(len(filter_views)) + figure = Figure(figsize=(12.0, 6.8)) + axes = figure.subplots(rows, columns, squeeze=False) + axes_flat = list(axes.ravel()) + payloads = _comparison_violin_payloads( + filter_views, + payload_variant=payload_variant, + ) + for axis, view, payload in zip( + axes_flat, + filter_views, + payloads, + strict=False, + ): + _draw_violin_axis( + axis, + view, + settings=settings, + payload=payload, + ) + for axis in axes_flat[len(filter_views) :]: + axis.set_axis_off() + figure.tight_layout() + return _save_figure(figure, output_path) + + +def _render_filter_fit_comparison( + filter_views: tuple[DreamFilterReportView, ...], + output_path: Path, + *, + settings: PowerPointExportSettings, +) -> Path: + with _report_figure_context(settings, compact=True): + rows, columns = _comparison_grid_shape(len(filter_views)) + figure = Figure(figsize=(12.0, 6.8)) + axes = figure.subplots(rows, columns, squeeze=False) + axes_flat = list(axes.ravel()) + for axis, view in zip(axes_flat, filter_views, strict=False): + _draw_model_fit_axis( + axis, + q_values=np.asarray(view.model_plot.q_values, dtype=float), + experimental=np.asarray( + view.model_plot.experimental_intensities, + dtype=float, + ), + model=np.asarray( + view.model_plot.model_intensities, dtype=float + ), + solvent=None, + structure_factor=( + None + if view.model_plot.structure_factor_trace is None + else np.asarray( + view.model_plot.structure_factor_trace, + dtype=float, + ) + ), + title=view.title, + metrics_lines=[ + f"RMSE: {view.model_plot.rmse:.5g}", + ( + "Mean |res|: " + f"{view.model_plot.mean_abs_residual:.5g}" + ), + f"R^2: {view.model_plot.r_squared:.5g}", + f"Samples: {view.summary.posterior_sample_count}", + ], + show_legend=False, + compact=True, + dream_output_style=True, + settings=settings, + ) + for axis in axes_flat[len(filter_views) :]: + axis.set_axis_off() + figure.tight_layout() + return _save_figure(figure, output_path) + + +def _draw_prefit_plot_axes( + top_axis, + bottom_axis, + evaluation: PrefitEvaluation, + *, + include_solvent: bool, + settings: PowerPointExportSettings, +) -> None: + del settings + q_values = np.asarray(evaluation.q_values, dtype=float) + plotted_lines: list[object] = [] + if evaluation.experimental_intensities is not None: + experimental = np.asarray( + evaluation.experimental_intensities, + dtype=float, + ) + (experimental_line,) = top_axis.plot( + q_values, + experimental, + color=_PREFIT_WINDOW_EXPERIMENTAL_COLOR, + label="Experimental", + ) + plotted_lines.append(experimental_line) + if include_solvent and evaluation.solvent_contribution is not None: + solvent_values = np.asarray( + evaluation.solvent_contribution, + dtype=float, + ) + solvent_mask = np.isfinite(solvent_values) & (solvent_values > 0.0) + if np.any(solvent_mask): + (solvent_line,) = top_axis.plot( + q_values[solvent_mask], + solvent_values[solvent_mask], + color=_PREFIT_WINDOW_SOLVENT_COLOR, + linewidth=1.5, + label="Solvent contribution", + ) + plotted_lines.append(solvent_line) + (model_line,) = top_axis.plot( + q_values, + np.asarray(evaluation.model_intensities, dtype=float), + color=_PREFIT_WINDOW_MODEL_COLOR, + label="Model", + ) + plotted_lines.append(model_line) + top_axis.set_xscale("log") + top_axis.set_yscale("log") + top_axis.set_ylabel("Intensity (arb. units)") + top_axis.set_title( + "Prefit model + solvent trace" if include_solvent else "Prefit model" + ) + top_axis.text( + 0.02, + 0.02, + "\n".join(_prefit_metric_lines(evaluation)), + transform=top_axis.transAxes, + ha="left", + va="bottom", + fontsize=9, + bbox={ + "boxstyle": "round,pad=0.35", + "facecolor": "white", + "edgecolor": "0.6", + "alpha": 0.85, + }, + ) + if plotted_lines: + top_axis.legend( + plotted_lines, + [str(line.get_label()) for line in plotted_lines], + fontsize=8, + loc="best", + framealpha=0.92, + ) + if bottom_axis is not None and evaluation.residuals is not None: + bottom_axis.axhline(0.0, color="0.5", linewidth=1.0) + bottom_axis.plot( + np.asarray(evaluation.q_values, dtype=float), + np.asarray(evaluation.residuals, dtype=float), + color=_PREFIT_WINDOW_RESIDUAL_COLOR, + linewidth=1.2, + ) + bottom_axis.set_xscale("log") + bottom_axis.set_xlabel("q (1/A)") + bottom_axis.set_ylabel("Residual") + bottom_axis.tick_params(labelsize=9) + top_axis.set_xlabel("") + else: + top_axis.set_xlabel("q (1/A)") + top_axis.tick_params(labelsize=9) + _autoscale_visible_axis_limits( + top_axis, + log_x=True, + log_y=True, + ) + if bottom_axis is not None: + _autoscale_visible_axis_limits( + bottom_axis, + log_x=True, + log_y=False, + ) + + +def _draw_model_fit_axis( + axis, + *, + q_values: np.ndarray, + experimental: np.ndarray | None, + model: np.ndarray, + solvent: np.ndarray | None, + structure_factor: np.ndarray | None, + title: str, + metrics_lines: list[str], + show_legend: bool, + compact: bool = False, + dream_output_style: bool = False, + settings: PowerPointExportSettings, +) -> None: + q_values = np.asarray(q_values, dtype=float) + model = np.asarray(model, dtype=float) + legend_lines: list[object] = [] + legend_labels: list[str] = [] + experimental_color = ( + _DREAM_OUTPUT_EXPERIMENTAL_COLOR + if dream_output_style + else settings.experimental_trace_color + ) + model_color = ( + _DREAM_OUTPUT_MODEL_COLOR + if dream_output_style + else settings.model_trace_color + ) + solvent_color = ( + _DREAM_OUTPUT_SOLVENT_COLOR + if dream_output_style + else settings.solvent_trace_color + ) + structure_factor_color = ( + _DREAM_OUTPUT_STRUCTURE_FACTOR_COLOR + if dream_output_style + else settings.structure_factor_color + ) + if experimental is not None: + experimental = np.asarray(experimental, dtype=float) + artist = axis.scatter( + q_values, + experimental, + color=experimental_color, + s=8 if compact else 14, + zorder=3, + label="Experimental", + ) + legend_lines.append(artist) + legend_labels.append("Experimental") + if solvent is not None: + solvent_mask = np.isfinite(solvent) & (solvent > 0.0) + if np.any(solvent_mask): + (solvent_line,) = axis.plot( + q_values[solvent_mask], + solvent[solvent_mask], + color=solvent_color, + linewidth=1.0 if compact else 1.4, + label="Solvent contribution", + ) + legend_lines.append(solvent_line) + legend_labels.append("Solvent") + if structure_factor is not None: + structure_mask = np.isfinite(structure_factor) + if np.any(structure_mask): + twin_axis = axis.twinx() + twin_axis.set_xscale("log") + (structure_line,) = twin_axis.plot( + q_values[structure_mask], + structure_factor[structure_mask], + color=structure_factor_color, + linestyle="--", + linewidth=1.0 if compact else 1.2, + label=( + "Structure factor S(q)" + if dream_output_style + else "Structure factor" + ), + ) + twin_axis.set_ylabel("S(q)", color=structure_factor_color) + twin_axis.tick_params( + axis="y", + colors=structure_factor_color, + ) + twin_axis.spines["right"].set_color(structure_factor_color) + legend_lines.append(structure_line) + legend_labels.append("Structure factor") + (model_line,) = axis.plot( + q_values, + model, + color=model_color, + linewidth=1.5 if compact else 2.0, + label="Model" if not dream_output_style else "Model", + ) + legend_lines.append(model_line) + legend_labels.append("Model") + axis.set_xscale("log") + axis.set_yscale("log") + axis.set_ylabel("Intensity") + axis.set_title(title, fontsize=10 if compact else None) + axis.tick_params(labelsize=8 if compact else 9) + metrics_font = 7 if compact else 9 + axis.text( + 0.02, + 0.02, + "\n".join(metrics_lines), + transform=axis.transAxes, + ha="left", + va="bottom", + fontsize=metrics_font, + bbox={ + "boxstyle": "round,pad=0.35", + "facecolor": "white", + "edgecolor": "#9ca3af", + "alpha": 0.9, + }, + ) + if show_legend and legend_lines: + axis.legend( + legend_lines, + legend_labels, + fontsize=7 if compact else 8, + loc="best", + framealpha=0.92, + ) + axis.set_xlabel("q (1/A)") + + +def _draw_violin_axis( + axis, + view: DreamFilterReportView, + *, + settings: PowerPointExportSettings, + payload: dict[str, object] | None = None, +) -> None: + violin_payload = view.violin_payload if payload is None else payload + samples = np.asarray(violin_payload.get("samples", []), dtype=float) + if samples.ndim == 1: + samples = samples.reshape(1, -1) + display_names = [ + str(label) for label in violin_payload.get("display_names", []) + ] + selected_values = np.asarray( + violin_payload.get("selected_values", []), + dtype=float, + ) + interval_low_values = np.asarray( + violin_payload.get("interval_low_values", []), + dtype=float, + ) + interval_high_values = np.asarray( + violin_payload.get("interval_high_values", []), + dtype=float, + ) + if samples.size == 0 or not display_names: + axis.text( + 0.5, + 0.5, + "No violin data available.", + ha="center", + va="center", + ) + axis.set_axis_off() + return + positions = np.arange(1, len(display_names) + 1) + violin_parts = axis.violinplot( + samples, + positions=positions, + showmedians=True, + ) + body_colors = _gradient_colors( + settings.component_color_map, + len(display_names), + ) + for body, color in zip(violin_parts["bodies"], body_colors, strict=False): + body.set_facecolor(color) + body.set_edgecolor("#374151") + body.set_alpha(0.62) + body.set_linewidth(0.7) + for key in ("cbars", "cmins", "cmaxes"): + artist = violin_parts.get(key) + if artist is not None: + artist.set_color("#4b5563") + artist.set_linewidth(1.0) + median_artist = violin_parts.get("cmedians") + if median_artist is not None: + median_artist.set_color("#111827") + median_artist.set_linewidth(1.2) + axis.vlines( + positions, + interval_low_values, + interval_high_values, + color="#4b5563", + linewidth=1.3, + ) + axis.scatter( + positions, + selected_values, + color=settings.model_trace_color, + s=12, + zorder=3, + ) + axis.set_xticks(positions) + axis.set_xticklabels(display_names, rotation=45, ha="right", fontsize=7) + axis.set_ylabel(str(violin_payload.get("ylabel", "Parameter value"))) + axis.set_title(view.title, fontsize=10) + axis.tick_params(labelsize=8) + y_limits = violin_payload.get("y_limits") + if y_limits is not None: + axis.set_ylim(float(y_limits[0]), float(y_limits[1])) + axis.grid(True, axis="y", alpha=0.15) + + +def _comparison_grid_shape(count: int) -> tuple[int, int]: + if count <= 1: + return (1, 1) + if count <= 2: + return (1, 2) + return (2, 2) + + +def _comparison_violin_payloads( + filter_views: tuple[DreamFilterReportView, ...], + *, + payload_variant: str, +) -> list[dict[str, object]]: + payload_attribute = { + "default": "violin_payload", + "weights": "weights_violin_payload", + "effective_radii": "effective_radii_violin_payload", + }.get(payload_variant) + if payload_attribute is None: + raise ValueError( + "Unknown violin payload variant: " + f"{payload_variant}. Expected 'default', 'weights', " + "or 'effective_radii'." + ) + payloads = [ + _copy_violin_payload(getattr(view, payload_attribute)) + for view in filter_views + ] + if payload_variant in {"weights", "effective_radii"}: + _apply_shared_violin_y_limits(payloads) + return payloads + + +def _copy_violin_payload(payload: dict[str, object]) -> dict[str, object]: + copied: dict[str, object] = {} + for key, value in payload.items(): + if isinstance(value, np.ndarray): + copied[key] = np.asarray(value, dtype=float).copy() + elif isinstance(value, list): + copied[key] = list(value) + elif isinstance(value, tuple): + copied[key] = tuple(value) + else: + copied[key] = value + return copied + + +def _apply_shared_violin_y_limits( + payloads: Sequence[dict[str, object]] +) -> None: + if not payloads: + return + if any(payload.get("y_limits") is not None for payload in payloads): + return + y_limits = _shared_violin_y_limits(payloads) + if y_limits is None: + return + for payload in payloads: + payload["y_limits"] = y_limits + + +def _shared_violin_y_limits( + payloads: Sequence[dict[str, object]], +) -> tuple[float, float] | None: + lower_bounds: list[float] = [] + upper_bounds: list[float] = [] + for payload in payloads: + for key in ( + "samples", + "selected_values", + "interval_low_values", + "interval_high_values", + ): + values = np.asarray(payload.get(key, []), dtype=float) + finite_values = values[np.isfinite(values)] + if finite_values.size == 0: + continue + lower_bounds.append(float(np.min(finite_values))) + upper_bounds.append(float(np.max(finite_values))) + if not lower_bounds or not upper_bounds: + return None + lower = min(lower_bounds) + upper = max(upper_bounds) + if np.isclose(lower, upper): + padding = max(abs(lower) * 0.08, 1.0) + else: + padding = (upper - lower) * 0.08 + shared_lower = lower - padding + shared_upper = upper + padding + if lower >= 0.0 and shared_lower < 0.0: + shared_lower = 0.0 + if shared_upper <= shared_lower: + shared_upper = shared_lower + 1.0 + return (shared_lower, shared_upper) + + +def _apply_axis_scale(axis, *, log_x: bool, log_y: bool) -> None: + axis.set_xscale("log" if log_x else "linear") + axis.set_yscale("log" if log_y else "linear") + + +def _has_prefit_solvent_trace(evaluation: PrefitEvaluation) -> bool: + if evaluation.solvent_contribution is None: + return False + solvent = np.asarray(evaluation.solvent_contribution, dtype=float) + return bool(np.any(np.isfinite(solvent) & (solvent > 0.0))) + + +def _autoscale_visible_axis_limits( + axis, + *, + log_x: bool, + log_y: bool, +) -> None: + try: + axis.relim(visible_only=True) + axis.autoscale_view() + except Exception: + pass + x_values: list[np.ndarray] = [] + y_values: list[np.ndarray] = [] + for line in axis.get_lines(): + if not line.get_visible(): + continue + line_x = np.asarray(line.get_xdata(orig=False), dtype=float) + line_y = np.asarray(line.get_ydata(orig=False), dtype=float) + mask = np.isfinite(line_x) & np.isfinite(line_y) + if log_x: + mask &= line_x > 0.0 + if log_y: + mask &= line_y > 0.0 + if not np.any(mask): + continue + x_values.append(line_x[mask]) + y_values.append(line_y[mask]) + if x_values: + axis.set_xlim( + *_autoscale_axis_limits(np.concatenate(x_values), log_scale=log_x) + ) + if y_values: + axis.set_ylim( + *_autoscale_axis_limits(np.concatenate(y_values), log_scale=log_y) + ) + + +def _autoscale_axis_limits( + values: np.ndarray, + *, + log_scale: bool, +) -> tuple[float, float]: + data = np.asarray(values, dtype=float) + finite = data[np.isfinite(data)] + if finite.size == 0: + return (0.0, 1.0) + lower = float(np.nanmin(finite)) + upper = float(np.nanmax(finite)) + if np.isclose(lower, upper): + if log_scale: + lower = lower / 1.15 + upper = upper * 1.15 + return (lower, upper) + padding = max(abs(lower) * 0.05, 1e-12) + return (lower - padding, upper + padding) + if log_scale: + return (lower / 1.05, upper * 1.05) + padding = 0.05 * (upper - lower) + return (lower - padding, upper + padding) + + +def _component_q_bounds( + plot_data: ReportComponentPlotData, +) -> tuple[float, float] | None: + q_segments: list[np.ndarray] = [] + if plot_data.experimental_q_values is not None: + q_segments.append( + np.asarray(plot_data.experimental_q_values, dtype=float) + ) + if plot_data.solvent_q_values is not None: + q_segments.append(np.asarray(plot_data.solvent_q_values, dtype=float)) + for series in plot_data.component_series: + q_segments.append(np.asarray(series.q_values, dtype=float)) + if not q_segments: + return None + merged = np.concatenate(q_segments) + finite = merged[np.isfinite(merged)] + if finite.size == 0: + return None + return (float(np.nanmin(finite)), float(np.nanmax(finite))) + + +def _autoscale_component_plot_to_model_range( + experimental_axis, + component_axis, + plot_data: ReportComponentPlotData, +) -> None: + model_q_bounds = _component_model_q_bounds(plot_data) + if model_q_bounds is None: + return + q_min, q_max = model_q_bounds + component_axis.set_xlim(q_min, q_max) + if experimental_axis is not None: + experimental_axis.set_xlim(q_min, q_max) + _autoscale_axis_y_for_plot( + experimental_axis, + q_min, + q_max, + log_scale=plot_data.log_y, + ) + _normalize_component_axis_to_experimental( + experimental_axis, + component_axis, + plot_data, + ) + return + _autoscale_axis_y_for_plot( + component_axis, + q_min, + q_max, + log_scale=plot_data.log_y, + ) + + +def _component_model_q_bounds( + plot_data: ReportComponentPlotData, +) -> tuple[float, float] | None: + q_segments = [ + np.asarray(series.q_values, dtype=float) + for series in plot_data.component_series + ] + if not q_segments: + return None + merged = np.concatenate(q_segments) + finite = merged[np.isfinite(merged)] + if finite.size == 0: + return None + return (float(np.nanmin(finite)), float(np.nanmax(finite))) + + +def _autoscale_axis_y_for_plot( + axis, + q_min: float, + q_max: float, + *, + log_scale: bool, +) -> None: + y_segments: list[np.ndarray] = [] + for line in axis.get_lines(): + if not line.get_visible(): + continue + x_data = np.asarray(line.get_xdata(orig=False), dtype=float) + y_data = np.asarray(line.get_ydata(orig=False), dtype=float) + mask = ( + np.isfinite(x_data) + & np.isfinite(y_data) + & (x_data >= q_min) + & (x_data <= q_max) + ) + if log_scale: + mask &= y_data > 0.0 + if np.any(mask): + y_segments.append(y_data[mask]) + if not y_segments: + return + y_values = np.concatenate(y_segments) + y_min = float(np.nanmin(y_values)) + y_max = float(np.nanmax(y_values)) + if np.isclose(y_min, y_max): + padding = max(abs(y_min) * 0.05, 1e-12) + axis.set_ylim(y_min - padding, y_max + padding) + return + if log_scale: + axis.set_ylim(y_min / 1.15, y_max * 1.15) + else: + padding = 0.05 * (y_max - y_min) + axis.set_ylim(y_min - padding, y_max + padding) + + +def _normalize_component_axis_to_experimental( + experimental_axis, + component_axis, + plot_data: ReportComponentPlotData, +) -> None: + if ( + plot_data.experimental_q_values is None + or plot_data.experimental_intensities is None + ): + return + component_values = [ + np.asarray(series.intensities, dtype=float) + for series in plot_data.component_series + ] + if not component_values: + return + component_data = np.concatenate(component_values) + component_data = component_data[np.isfinite(component_data)] + if plot_data.log_y: + component_data = component_data[component_data > 0.0] + if component_data.size == 0: + return + experimental_q_values = np.asarray( + plot_data.experimental_q_values, + dtype=float, + ) + experimental_intensities = np.asarray( + plot_data.experimental_intensities, + dtype=float, + ) + experimental_mask = _selected_q_mask( + experimental_q_values, + plot_data.selected_q_min, + plot_data.selected_q_max, + use_experimental_grid=plot_data.use_experimental_grid, + ) + if not np.any(experimental_mask): + return + filtered_q = experimental_q_values[experimental_mask] + filtered_i = experimental_intensities[experimental_mask] + model_q_bounds = _component_model_q_bounds(plot_data) + if model_q_bounds is None: + return + overlap_mask = (filtered_q >= model_q_bounds[0]) & ( + filtered_q <= model_q_bounds[1] + ) + if np.any(overlap_mask): + filtered_i = filtered_i[overlap_mask] + filtered_i = filtered_i[np.isfinite(filtered_i)] + if plot_data.log_y: + filtered_i = filtered_i[filtered_i > 0.0] + if filtered_i.size == 0: + return + left_limits = experimental_axis.get_ylim() + right_limits = _aligned_y_limits( + left_limits, + float(np.nanmin(filtered_i)), + float(np.nanmax(filtered_i)), + float(np.nanmin(component_data)), + float(np.nanmax(component_data)), + log_scale=plot_data.log_y, + ) + component_axis.set_ylim(right_limits) + + +def _aligned_y_limits( + left_limits: tuple[float, float], + experimental_min: float, + experimental_max: float, + component_min: float, + component_max: float, + *, + log_scale: bool, +) -> tuple[float, float]: + if log_scale: + if ( + min( + left_limits[0], + left_limits[1], + experimental_min, + experimental_max, + component_min, + component_max, + ) + <= 0.0 + ): + log_scale = False + if not log_scale: + left_low, left_high = left_limits + exp_low, exp_high = sorted((experimental_min, experimental_max)) + comp_low, comp_high = sorted((component_min, component_max)) + if np.isclose(left_high, left_low) or np.isclose(exp_high, exp_low): + padding = max(abs(comp_low) * 0.1, 1e-12) + return comp_low - padding, comp_high + padding + p0 = (exp_low - left_low) / (left_high - left_low) + p1 = (exp_high - left_low) / (left_high - left_low) + if np.isclose(p1, p0): + padding = max(abs(comp_low) * 0.1, 1e-12) + return comp_low - padding, comp_high + padding + delta = (comp_high - comp_low) / (p1 - p0) + right_low = comp_low - p0 * delta + right_high = right_low + delta + return right_low, right_high + left_logs = np.log10(np.asarray(left_limits, dtype=float)) + exp_logs = np.log10( + np.asarray(sorted((experimental_min, experimental_max)), dtype=float) + ) + comp_logs = np.log10( + np.asarray(sorted((component_min, component_max)), dtype=float) + ) + if np.isclose(left_logs[1], left_logs[0]) or np.isclose( + exp_logs[1], exp_logs[0] + ): + return component_min / 1.2, component_max * 1.2 + p0 = (exp_logs[0] - left_logs[0]) / (left_logs[1] - left_logs[0]) + p1 = (exp_logs[1] - left_logs[0]) / (left_logs[1] - left_logs[0]) + if np.isclose(p1, p0): + return component_min / 1.2, component_max * 1.2 + delta = (comp_logs[1] - comp_logs[0]) / (p1 - p0) + right_low_log = comp_logs[0] - p0 * delta + right_high_log = right_low_log + delta + return 10**right_low_log, 10**right_high_log + + +def _selected_q_mask( + q_values: np.ndarray, + lower: float | None, + upper: float | None, + *, + use_experimental_grid: bool, +) -> np.ndarray: + values = np.asarray(q_values, dtype=float) + if values.size == 0: + return np.zeros(0, dtype=bool) + if lower is None and upper is None: + return np.ones_like(values, dtype=bool) + lower_bound = lower if lower is not None else float(np.nanmin(values)) + upper_bound = upper if upper is not None else float(np.nanmax(values)) + if lower_bound > upper_bound: + return np.zeros_like(values, dtype=bool) + if use_experimental_grid: + start_index = int(np.argmin(np.abs(values - lower_bound))) + end_index = int(np.argmin(np.abs(values - upper_bound))) + lo_index, hi_index = sorted((start_index, end_index)) + mask = np.zeros_like(values, dtype=bool) + mask[lo_index : hi_index + 1] = True + return mask + return (values >= lower_bound) & (values <= upper_bound) + + +def _gradient_colors(cmap_name: str, count: int) -> list[str]: + try: + cmap = colormaps[cmap_name] + except Exception: + cmap = colormaps["viridis"] + if count <= 1: + return [to_hex(cmap(0.72), keep_alpha=False)] + positions = np.linspace(0.22, 0.9, count) + return [ + to_hex(cmap(float(position)), keep_alpha=False) + for position in positions + ] + + +def _prefit_summary_lines( + evaluation: PrefitEvaluation, + statistics: dict[str, object], +) -> list[str]: + lines = [ + f"Points: {len(np.asarray(evaluation.q_values, dtype=float))}", + ( + "q-range: " + f"{float(np.min(evaluation.q_values)):.6g} to " + f"{float(np.max(evaluation.q_values)):.6g}" + ), + ] + lines.extend(_prefit_metric_lines(evaluation)) + if statistics.get("method"): + lines.append(f"Method: {statistics['method']}") + if statistics.get("nfev") is not None: + lines.append(f"Function evals: {statistics['nfev']}") + if statistics.get("chi_square") is not None: + lines.append(f"Chi^2: {float(statistics['chi_square']):.6g}") + if statistics.get("reduced_chi_square") is not None: + lines.append( + f"Reduced chi^2: {float(statistics['reduced_chi_square']):.6g}" + ) + if statistics.get("r_squared") is not None: + lines.append(f"Saved-fit R^2: {float(statistics['r_squared']):.6g}") + if statistics.get("saved_at"): + lines.append(f"Saved state timestamp: {statistics['saved_at']}") + return lines + + +def _prefit_metric_lines(evaluation: PrefitEvaluation) -> list[str]: + if ( + evaluation.experimental_intensities is None + or evaluation.residuals is None + ): + return [ + "Model Only Mode", + "Experimental fit metrics unavailable", + ] + experimental = np.asarray(evaluation.experimental_intensities, dtype=float) + model = np.asarray(evaluation.model_intensities, dtype=float) + residuals = np.asarray(model - experimental, dtype=float) + rmse = float(np.sqrt(np.mean(residuals**2))) + mean_abs = float(np.mean(np.abs(residuals))) + mean_experimental = float(np.mean(experimental)) + total_sum_squares = float(np.sum((experimental - mean_experimental) ** 2)) + residual_sum_squares = float(np.sum(residuals**2)) + r_squared = ( + float(1.0 - (residual_sum_squares / total_sum_squares)) + if total_sum_squares > 0.0 + else 1.0 + ) + return [ + f"RMSE: {rmse:.6g}", + f"Mean |res|: {mean_abs:.6g}", + f"R^2: {r_squared:.6g}", + ] + + +def _prefit_parameter_rows( + entries: tuple[PrefitParameterEntry, ...], +) -> list[list[str]]: + rows: list[list[str]] = [] + for entry in entries: + rows.append( + [ + str(entry.name), + str(entry.category), + f"{float(entry.value):.6g}", + "Yes" if entry.vary else "No", + f"{float(entry.minimum):.6g}", + f"{float(entry.maximum):.6g}", + str(entry.structure or "-"), + str(entry.motif or "-"), + ] + ) + return rows + + +def _cluster_geometry_rows( + rows: tuple[ClusterGeometryMetadataRow, ...], +) -> list[list[str]]: + table_rows: list[list[str]] = [] + for row in rows: + axes_text = ( + f"{float(row.active_semiaxis_a):.4g} / " + f"{float(row.active_semiaxis_b):.4g} / " + f"{float(row.active_semiaxis_c):.4g}" + ) + table_rows.append( + [ + str(row.cluster_id), + str(row.mapped_parameter or "-"), + str(row.sf_approximation), + f"{float(row.effective_radius):.5g}", + f"{float(row.mean_radius_of_gyration):.5g}", + f"{float(row.mean_max_radius):.5g}", + f"{float(row.anisotropy_metric):.5g}", + axes_text, + ] + ) + return table_rows + + +def _solution_estimate_lines( + estimate: SolutionScatteringEstimate, +) -> list[str]: + lines = [ + "Solution scattering estimator summary", + f"Incident energy: {float(estimate.settings.beam.incident_energy_kev):.6g} keV", + ] + if estimate.number_density_estimate is not None: + lines.extend( + [ + "", + "Number density:", + ( + "Atoms/A^3: " + f"{estimate.number_density_estimate.number_density_a3:.6g}" + ), + ( + "Atoms/cm^3: " + f"{estimate.number_density_estimate.number_density_cm3:.6g}" + ), + ( + "Total atoms: " + f"{estimate.number_density_estimate.total_atoms:.6g}" + ), + ] + ) + if estimate.volume_fraction_estimate is not None: + lines.extend( + [ + "", + "Physical volume fraction:", + ( + "Physical solute-associated volume fraction: " + f"{float(estimate.volume_fraction_estimate.solute_volume_fraction):.6f}" + ), + ( + "Physical solvent-associated volume fraction: " + f"{float(estimate.volume_fraction_estimate.solvent_volume_fraction):.6f}" + ), + ] + ) + if estimate.interaction_contrast_estimate is not None: + lines.extend( + [ + "", + "SAXS-effective interaction ratio:", + ( + "Contrast weight factor: " + f"{float(estimate.interaction_contrast_estimate.contrast_weight_factor):.6g}" + ), + ( + "SAXS-effective solute interaction ratio: " + f"{float(estimate.interaction_contrast_estimate.saxs_effective_solute_interaction_ratio):.6f}" + ), + ( + "SAXS-effective solvent background ratio: " + f"{float(estimate.interaction_contrast_estimate.saxs_effective_solvent_background_ratio):.6f}" + ), + ] + ) + if estimate.attenuation_estimate is not None: + lines.extend( + [ + "", + "Attenuation:", + ( + "Sample transmission: " + f"{float(estimate.attenuation_estimate.sample_transmission):.6f}" + ), + ( + "Neat-solvent transmission: " + f"{float(estimate.attenuation_estimate.neat_solvent_transmission):.6f}" + ), + ( + "Solvent scale factor: " + f"{float(estimate.attenuation_estimate.solvent_scattering_scale_factor):.6f}" + ), + ] + ) + if estimate.interaction_contrast_estimate is not None: + lines.append( + "Single-weight solvent multiplier: " + f"{float(estimate.attenuation_estimate.solvent_scattering_scale_factor * estimate.interaction_contrast_estimate.saxs_effective_solvent_background_ratio):.6f}" + ) + if ( + estimate.attenuation_estimate.neat_solvent_to_sample_ratio + is not None + ): + lines.append( + "Neat-solvent/sample-solvent ratio: " + f"{float(estimate.attenuation_estimate.neat_solvent_to_sample_ratio):.6g}" + ) + if estimate.fluorescence_estimate is not None: + lines.extend( + [ + "", + "Fluorescence:", + ( + "Primary yield proxy: " + f"{float(estimate.fluorescence_estimate.total_primary_detected_yield):.6g}" + ), + ( + "Secondary yield proxy: " + f"{float(estimate.fluorescence_estimate.total_secondary_detected_yield):.6g}" + ), + ] + ) + for line_estimate in estimate.fluorescence_estimate.line_estimates[:5]: + lines.append( + f"{line_estimate.element} {line_estimate.family}: " + f"{float(line_estimate.total_detected_yield):.6g}" + ) + return lines + + +def _dream_settings_lines( + settings: DreamRunSettings, + summary: DreamSummary, +) -> list[str]: + return [ + f"Best-fit method: {summary.bestfit_method}", + ("Posterior filter: " f"{_describe_posterior_filter(settings)}"), + f"Posterior samples kept: {summary.posterior_sample_count}", + ( + "Credible interval: " + f"{summary.credible_interval_low:g} - " + f"{summary.credible_interval_high:g}" + ), + f"MAP location: chain {summary.map_chain + 1}, step {summary.map_step + 1}", + f"nchains: {settings.nchains}", + f"niterations: {settings.niterations}", + f"burn-in (%): {settings.burnin_percent}", + f"nseedchains: {settings.nseedchains}", + f"crossover burn-in: {settings.crossover_burnin}", + f"run label: {settings.run_label}", + f"violin mode: {settings.violin_parameter_mode}", + f"violin sample source: {settings.violin_sample_source}", + f"violin weight order: {settings.violin_weight_order}", + f"violin y-scale: {settings.violin_value_scale_mode}", + ( + "Auto-select best filter after run: " + f"{'on' if settings.auto_select_best_posterior_filter else 'off'}" + ), + ] + + +def _dream_assessment_lines( + assessments: tuple[dict[str, object], ...], + settings: DreamRunSettings, +) -> list[str]: + lines = [ + "Posterior filtering assessment", + ("Active selection: " f"{_describe_posterior_filter(settings)}"), + ] + if not assessments: + lines.append("No saved assessment metrics are available.") + return lines + for assessment in assessments: + lines.extend( + [ + "", + str(assessment.get("description", "Unnamed filter")), + f"RMSE: {float(assessment.get('rmse', 0.0)):.6g}", + ( + "Mean |res|: " + f"{float(assessment.get('mean_abs_residual', 0.0)):.6g}" + ), + f"R^2: {float(assessment.get('r_squared', 0.0)):.6g}", + ( + "Posterior samples: " + f"{int(assessment.get('posterior_sample_count', 0))}" + ), + ] + ) + return lines + + +def _dream_prior_rows( + entries: tuple[DreamParameterEntry, ...], +) -> list[list[str]]: + rows: list[list[str]] = [] + for entry in entries: + dist_params = ", ".join( + f"{key}={float(value):.6g}" + for key, value in sorted(entry.dist_params.items()) + ) + rows.append( + [ + str(entry.param), + str(entry.param_type), + str(entry.structure or "-"), + str(entry.motif or "-"), + f"{float(entry.value):.6g}", + "Yes" if entry.vary else "No", + str(entry.distribution), + dist_params, + ] + ) + return rows + + +def _dream_output_lines( + settings: DreamRunSettings, + summary: DreamSummary, + model_plot: DreamModelPlotData, +) -> list[str]: + lines = [ + f"Template: {model_plot.template_name}", + f"Best-fit method: {settings.bestfit_method}", + ("Posterior filter: " f"{_describe_posterior_filter(settings)}"), + f"Posterior samples kept: {summary.posterior_sample_count}", + f"RMSE: {model_plot.rmse:.6g}", + f"Mean |res|: {model_plot.mean_abs_residual:.6g}", + f"R^2: {model_plot.r_squared:.6g}", + ( + "Credible interval: " + f"{summary.credible_interval_low:g} - " + f"{summary.credible_interval_high:g}" + ), + ( + "Active parameters: " + f"{', '.join(summary.active_parameter_names) or 'None'}" + ), + ] + for index, name in enumerate(summary.full_parameter_names[:10]): + lines.append( + f"{name}: {float(summary.bestfit_params[index]):.6g} " + f"(p{summary.credible_interval_low:g}=" + f"{float(summary.interval_low_values[index]):.6g}, " + f"p{summary.credible_interval_high:g}=" + f"{float(summary.interval_high_values[index]):.6g})" + ) + if len(summary.full_parameter_names) > 10: + lines.append( + f"... {len(summary.full_parameter_names) - 10} additional parameters" + ) + return lines + + +def _describe_posterior_filter(settings: DreamRunSettings) -> str: + if settings.posterior_filter_mode == "top_percent_logp": + return ( + f"top_percent_logp " + f"(top {settings.posterior_top_percent:g}% by log-posterior)" + ) + if settings.posterior_filter_mode == "top_n_logp": + return ( + f"top_n_logp " + f"(top {settings.posterior_top_n} samples by log-posterior)" + ) + return "all_post_burnin" + + +def _manifest_payload( + context: DreamModelReportContext, + *, + figure_paths: list[Path], +) -> dict[str, object]: + return { + "report_type": "dream_model_report_pptx", + "generated_at": context.generated_at.isoformat(), + "project_name": context.project_name, + "project_dir": str(context.project_dir), + "report_path": str(context.output_path), + "asset_dir": str(context.asset_dir), + "dream_run_dir": str(context.dream_summary.run_dir), + "user_q_range": context.user_q_range_text, + "supported_q_range": context.supported_q_range_text, + "q_sampling": context.q_sampling_text, + "template_name": context.template_name, + "template_display_name": context.template_display_name, + "template_module_path": ( + None + if context.template_module_path is None + else str(context.template_module_path) + ), + "model_equation": context.model_equation_text, + "model_reference_lines": list(context.model_reference_lines), + "figure_paths": [str(path) for path in figure_paths], + "prior_histograms": [ + { + "title": request.title, + "mode": request.mode, + "secondary_element": request.secondary_element, + "cmap": request.cmap, + "source": str(request.json_path), + } + for request in context.prior_histograms + ], + "prefit_statistics": { + str(key): value + for key, value in context.prefit_statistics.items() + if not isinstance(value, Path) + }, + "dream_settings": context.dream_settings.to_dict(), + "dream_summary": { + "bestfit_method": context.dream_summary.bestfit_method, + "posterior_filter_mode": context.dream_summary.posterior_filter_mode, + "posterior_sample_count": context.dream_summary.posterior_sample_count, + "credible_interval_low": context.dream_summary.credible_interval_low, + "credible_interval_high": context.dream_summary.credible_interval_high, + }, + "dream_filter_views": [ + { + "title": view.title, + "description": view.description, + "filter_mode": view.filter_mode, + "is_active": view.is_active, + "posterior_sample_count": view.summary.posterior_sample_count, + "rmse": view.model_plot.rmse, + "mean_abs_residual": view.model_plot.mean_abs_residual, + "r_squared": view.model_plot.r_squared, + } + for view in context.dream_filter_views + ], + } + + +def _rgb_color(value: str): + from pptx.dml.color import RGBColor + + red, green, blue = tuple( + int(value[index : index + 2], 16) for index in (1, 3, 5) + ) + return RGBColor(red, green, blue) + + +def _slugify(value: str) -> str: + safe = "".join( + character.lower() if character.isalnum() else "_" + for character in value + ) + while "__" in safe: + safe = safe.replace("__", "_") + return safe.strip("_") or "figure" + + +__all__ = [ + "DreamFilterReportView", + "DreamModelReportContext", + "ModelReportExportResult", + "PriorHistogramRequest", + "ReportComponentPlotData", + "ReportComponentSeries", + "export_dream_model_report_pptx", +] diff --git a/src/saxshell/saxs/prefit/__init__.py b/src/saxshell/saxs/prefit/__init__.py index 9b02812..a53b9a9 100644 --- a/src/saxshell/saxs/prefit/__init__.py +++ b/src/saxshell/saxs/prefit/__init__.py @@ -17,6 +17,8 @@ PrefitSavedState, PrefitScaleRecommendation, SAXSPrefitWorkflow, + normalize_prefit_parameter_expression, + resolve_prefit_parameter_entries, ) __all__ = [ @@ -35,5 +37,7 @@ "compute_cluster_geometry_metadata", "copy_cluster_geometry_rows", "load_cluster_geometry_metadata", + "normalize_prefit_parameter_expression", + "resolve_prefit_parameter_entries", "save_cluster_geometry_metadata", ] diff --git a/src/saxshell/saxs/prefit/workflow.py b/src/saxshell/saxs/prefit/workflow.py index e1ad44a..f6247e5 100644 --- a/src/saxshell/saxs/prefit/workflow.py +++ b/src/saxshell/saxs/prefit/workflow.py @@ -29,6 +29,7 @@ validate_positive_cluster_geometry_table, ) from saxshell.saxs.project_manager import ( + ProjectSettings, SAXSProjectManager, build_project_paths, load_built_component_q_range, @@ -67,6 +68,13 @@ def _optional_int(value: object) -> int | None: "solvent_fraction", "solvent_volume_fraction", ) +SOLVENT_WEIGHT_PARAMETER_NAMES = ( + "solv_w", + "solvent_scale", +) +PREFIT_MODEL_POSITIVE_FLOOR_RELATIVE = 1e-9 +PREFIT_MODEL_NEGATIVE_PENALTY_MULTIPLIER = 25.0 +PREFIT_MODEL_NONFINITE_PENALTY_MULTIPLIER = 100.0 def q_range_boundary_tolerance( @@ -117,6 +125,8 @@ class PrefitParameterEntry: minimum: float maximum: float category: str + value_expression: str | None = None + initial_value_expression: str | None = None def to_dict(self) -> dict[str, object]: return asdict(self) @@ -128,6 +138,15 @@ def from_dict(cls, payload: dict[str, object]) -> "PrefitParameterEntry": motif=str(payload.get("motif", "")), name=str(payload.get("name", "")), value=float(payload.get("value", 0.0)), + value_expression=_optional_str( + payload.get("value_expression", payload.get("expression")) + ), + initial_value_expression=_optional_str( + payload.get( + "initial_value_expression", + payload.get("initial_expression"), + ) + ), vary=bool(payload.get("vary", True)), minimum=float(payload.get("minimum", payload.get("min", 0.0))), maximum=float(payload.get("maximum", payload.get("max", 0.0))), @@ -135,14 +154,253 @@ def from_dict(cls, payload: dict[str, object]) -> "PrefitParameterEntry": ) +def _parameter_value_expression( + entry: PrefitParameterEntry, +) -> str | None: + return _optional_str(entry.value_expression) + + +def _parameter_initial_value_expression( + entry: PrefitParameterEntry, +) -> str | None: + return _optional_str(entry.initial_value_expression) + + +def normalize_prefit_parameter_expression(expression: str) -> str: + normalized = str(expression).strip() + if not normalized: + raise ValueError("Linked parameter expressions cannot be empty.") + if normalized[0] in {"*", "/"}: + return f"1{normalized}" + if normalized[0] == "+": + return f"0{normalized}" + return normalized + + +def build_prefit_lmfit_parameters( + entries: list[PrefitParameterEntry], +) -> tuple[Parameters, list[PrefitParameterEntry]]: + working_entries = [ + PrefitParameterEntry.from_dict(entry.to_dict()) for entry in entries + ] + seed_params = Parameters() + runtime_expression_entries: list[PrefitParameterEntry] = [] + seed_expression_entries: list[PrefitParameterEntry] = [] + + for entry in working_entries: + value = float(entry.value) + seed_params.add( + entry.name, + value=value, + vary=False, + min=-np.inf, + max=np.inf, + ) + if _parameter_value_expression(entry) is not None: + runtime_expression_entries.append(entry) + if _parameter_initial_value_expression(entry) is not None: + seed_expression_entries.append(entry) + + for entry in runtime_expression_entries + seed_expression_entries: + raw_expression = _parameter_value_expression( + entry + ) or _parameter_initial_value_expression(entry) + if raw_expression is None: + continue + normalized_expression = normalize_prefit_parameter_expression( + raw_expression + ) + try: + seed_params[entry.name].set( + expr=normalized_expression, + vary=False, + min=-np.inf, + max=np.inf, + ) + except Exception as exc: + raise ValueError( + "Invalid linked parameter expression for " + f"{entry.name}: {raw_expression}" + ) from exc + + for entry in runtime_expression_entries + seed_expression_entries: + raw_expression = _parameter_value_expression( + entry + ) or _parameter_initial_value_expression(entry) + if raw_expression is None: + continue + try: + entry.value = float(seed_params[entry.name].value) + except Exception as exc: + raise ValueError( + "Invalid linked parameter expression for " + f"{entry.name}: {raw_expression}" + ) from exc + + lmfit_params = Parameters() + for entry in working_entries: + lower = float(entry.minimum) + upper = float(entry.maximum) + value = float(entry.value) + if lower > upper: + lower, upper = upper, lower + if value < lower: + lower = value + if value > upper: + upper = value + entry.minimum = lower + entry.maximum = upper + lmfit_params.add( + entry.name, + value=value, + vary=bool(entry.vary), + min=lower, + max=upper, + ) + + for entry in runtime_expression_entries: + raw_expression = _parameter_value_expression(entry) + if raw_expression is None: + continue + normalized_expression = normalize_prefit_parameter_expression( + raw_expression + ) + try: + lmfit_params[entry.name].set( + expr=normalized_expression, + vary=False, + min=-np.inf, + max=np.inf, + ) + except Exception as exc: + raise ValueError( + "Invalid linked parameter expression for " + f"{entry.name}: {raw_expression}" + ) from exc + + for entry in runtime_expression_entries: + raw_expression = _parameter_value_expression(entry) + if raw_expression is None: + continue + try: + entry.value = float(lmfit_params[entry.name].value) + except Exception as exc: + raise ValueError( + "Invalid linked parameter expression for " + f"{entry.name}: {raw_expression}" + ) from exc + entry.vary = False + + values = lmfit_params.valuesdict() + for entry in working_entries: + if _parameter_value_expression(entry) is not None: + entry.value = float(values[entry.name]) + entry.vary = False + continue + parameter = lmfit_params[entry.name] + entry.value = float(parameter.value) + entry.vary = bool(parameter.vary) + entry.minimum = float(parameter.min) + entry.maximum = float(parameter.max) + return lmfit_params, working_entries + + +def resolve_prefit_parameter_entries( + entries: list[PrefitParameterEntry], +) -> list[PrefitParameterEntry]: + _params, resolved_entries = build_prefit_lmfit_parameters(entries) + return resolved_entries + + +def constrained_prefit_residuals( + experimental: np.ndarray, + model: np.ndarray, +) -> np.ndarray: + experimental_values = np.asarray(experimental, dtype=float) + model_values = np.asarray(model, dtype=float) + residuals = model_values - experimental_values + penalty = np.zeros_like(residuals, dtype=float) + + finite_experimental = experimental_values[np.isfinite(experimental_values)] + finite_model = model_values[np.isfinite(model_values)] + penalty_scale_candidates = [1e-12] + if finite_experimental.size: + penalty_scale_candidates.append( + float(np.max(np.abs(finite_experimental))) + ) + if finite_model.size: + penalty_scale_candidates.append(float(np.max(np.abs(finite_model)))) + penalty_scale = max(penalty_scale_candidates) + positive_floor = max( + penalty_scale * PREFIT_MODEL_POSITIVE_FLOOR_RELATIVE, + 1e-15, + ) + + invalid_mask = ~np.isfinite(model_values) + if np.any(invalid_mask): + penalty[invalid_mask] = ( + PREFIT_MODEL_NONFINITE_PENALTY_MULTIPLIER * penalty_scale + ) + + non_positive_mask = np.isfinite(model_values) & ( + model_values <= positive_floor + ) + if np.any(non_positive_mask): + reference = np.maximum( + np.abs(experimental_values[non_positive_mask]), + positive_floor, + ) + deficit = ( + positive_floor - model_values[non_positive_mask] + ) / reference + penalty[non_positive_mask] = ( + PREFIT_MODEL_NEGATIVE_PENALTY_MULTIPLIER + * penalty_scale + * (1.0 + deficit) + ) + + return np.concatenate([residuals, penalty]) + + +def validate_prefit_parameter_identifiability( + entries: list[PrefitParameterEntry], +) -> None: + phi_solute_entry = next( + (entry for entry in entries if entry.name == "phi_solute"), + None, + ) + solvent_entry = next( + ( + entry + for entry in entries + if entry.name in SOLVENT_WEIGHT_PARAMETER_NAMES + ), + None, + ) + if phi_solute_entry is None or solvent_entry is None: + return + if bool(phi_solute_entry.vary) and bool(solvent_entry.vary): + raise ValueError( + "phi_solute and " + f"{solvent_entry.name} cannot both vary during fitting. " + "Their product controls the solvent subtraction term, so they " + "must be fixed by prior estimates or only one may vary at a time." + ) + + @dataclass(slots=True) class PrefitEvaluation: q_values: np.ndarray - experimental_intensities: np.ndarray + experimental_intensities: np.ndarray | None model_intensities: np.ndarray - residuals: np.ndarray + residuals: np.ndarray | None solvent_intensities: np.ndarray | None = None solvent_contribution: np.ndarray | None = None + structure_factor_trace: np.ndarray | None = None + + @property + def is_model_only(self) -> bool: + return self.experimental_intensities is None @dataclass(slots=True) @@ -166,6 +424,10 @@ class PrefitScaleRecommendation: recommended_maximum: float adjustment_factor: float points_used: int + current_offset: float | None = None + recommended_offset: float | None = None + recommended_offset_minimum: float | None = None + recommended_offset_maximum: float | None = None @dataclass(slots=True) @@ -195,9 +457,7 @@ def __init__( self.project_manager = SAXSProjectManager() self.settings = self.project_manager.load_project(project_dir) self.paths = build_project_paths(self.settings.project_dir) - self.experimental_data = self.project_manager.load_experimental_data( - self.settings - ) + self.experimental_data = self._load_experimental_trace() self.template_dir = ( Path(template_dir).expanduser().resolve() if template_dir is not None @@ -230,10 +490,58 @@ def available_templates( ) -> list[TemplateSpec]: return list_template_specs(template_dir) + def has_experimental_data(self) -> bool: + return self.experimental_data is not None + + def can_run_prefit(self) -> bool: + return ( + self.has_experimental_data() and not self.settings.model_only_mode + ) + + def set_model_only_mode(self, enabled: bool) -> None: + self.settings.model_only_mode = bool(enabled) + if self.settings.model_only_mode: + self.settings.use_experimental_grid = False + if self.settings.q_points is None or self.settings.q_points <= 1: + self.settings.q_points = 500 + self.project_manager.save_project(self.settings) + self.experimental_data = self._load_experimental_trace() + self.solvent_data = self._load_solvent_trace() + + def apply_project_settings( + self, + settings: ProjectSettings, + ) -> None: + incoming_settings = ProjectSettings.from_dict(settings.to_dict()) + if incoming_settings.resolved_project_dir != self.paths.project_dir: + raise ValueError( + "Cannot apply project settings from a different SAXS project." + ) + selected_template = ( + str(incoming_settings.selected_model_template or "").strip() + or self.template_spec.name + ) + if selected_template != self.template_spec.name: + raise ValueError( + "The active Prefit template changed. Reload the project " + "workflows instead of applying settings in place." + ) + self.settings = incoming_settings + self.paths = build_project_paths(self.settings.project_dir) + self.component_map_path = self.paths.project_dir / "md_saxs_map.json" + self.prior_weights_path = ( + self.paths.project_dir / "md_prior_weights.json" + ) + self.cluster_geometry_metadata_path = ( + self.paths.cluster_geometry_metadata_file + ) + self.experimental_data = self._load_experimental_trace() + self.solvent_data = self._load_solvent_trace() + def load_parameter_entries(self) -> list[PrefitParameterEntry]: best_entries = self.load_best_prefit_entries() if best_entries is not None: - return best_entries + return self._apply_parameter_constraints(best_entries) state_path = self.paths.prefit_dir / "prefit_state.json" if state_path.is_file(): payload = json.loads(state_path.read_text(encoding="utf-8")) @@ -242,7 +550,7 @@ def load_parameter_entries(self) -> list[PrefitParameterEntry]: PrefitParameterEntry.from_dict(entry) for entry in entries ] if self._has_matching_entry_signature(parsed_entries): - return parsed_entries + return self._apply_parameter_constraints(parsed_entries) if parsed_entries: return self._merge_parameter_entries( parsed_entries, @@ -257,12 +565,19 @@ def evaluate( entries = parameter_entries or self.parameter_entries q_values = self._component_q_values() model_data = self._model_data_for_q_values(q_values) - experimental = np.interp( - q_values, - self.experimental_data.q_values, - self.experimental_data.intensities, + experimental = ( + np.interp( + q_values, + self.experimental_data.q_values, + self.experimental_data.intensities, + ) + if self.experimental_data is not None + else None + ) + _lmfit_params, resolved_entries = build_prefit_lmfit_parameters( + entries ) - params = {entry.name: float(entry.value) for entry in entries} + params = {entry.name: float(entry.value) for entry in resolved_entries} solvent_data = ( self._solvent_trace_for_q_values(q_values) if self.solvent_data is not None @@ -288,14 +603,30 @@ def evaluate( params=params, extra_inputs=extra_inputs, ) - residuals = model_intensities - experimental + structure_factor_trace = self._evaluate_structure_factor_trace( + q_values, + solvent_data=solvent_data, + model_data=model_data, + params=params, + extra_inputs=extra_inputs, + ) + residuals = ( + np.asarray(model_intensities, dtype=float) - experimental + if experimental is not None + else None + ) return PrefitEvaluation( q_values=q_values, experimental_intensities=experimental, model_intensities=np.asarray(model_intensities, dtype=float), - residuals=np.asarray(residuals, dtype=float), + residuals=( + np.asarray(residuals, dtype=float) + if residuals is not None + else None + ), solvent_intensities=solvent_intensities, solvent_contribution=solvent_contribution, + structure_factor_trace=structure_factor_trace, ) def run_fit( @@ -305,11 +636,21 @@ def run_fit( method: str = "leastsq", max_nfev: int = 10000, ) -> PrefitFitResult: - entries = [ - PrefitParameterEntry.from_dict(entry.to_dict()) - for entry in (parameter_entries or self.parameter_entries) - ] - self._ensure_entry_bounds_include_current_values(entries) + if self.settings.model_only_mode: + raise ValueError( + "Prefit is disabled in Model Only Mode. Disable Model Only " + "Mode and load experimental SAXS data to run a fit." + ) + if self.experimental_data is None: + raise ValueError( + "Prefit requires experimental SAXS data before a fit can be run." + ) + validate_prefit_parameter_identifiability( + parameter_entries or self.parameter_entries + ) + lmfit_params, entries = build_prefit_lmfit_parameters( + parameter_entries or self.parameter_entries + ) q_values = self._component_q_values() model_data = self._model_data_for_q_values(q_values) experimental = np.interp( @@ -324,15 +665,6 @@ def run_fit( ) lmfit_model = self._lmfit_model_function() extra_inputs = self._lmfit_extra_inputs() - lmfit_params = Parameters() - for entry in entries: - lmfit_params.add( - entry.name, - value=float(entry.value), - vary=bool(entry.vary), - min=float(entry.minimum), - max=float(entry.maximum), - ) def objective(active_params: Parameters) -> np.ndarray: params = active_params.valuesdict() @@ -343,7 +675,10 @@ def objective(active_params: Parameters) -> np.ndarray: *extra_inputs, **params, ) - return np.asarray(model, dtype=float) - experimental + return constrained_prefit_residuals( + experimental, + np.asarray(model, dtype=float), + ) result = minimize( objective, @@ -355,11 +690,21 @@ def objective(active_params: Parameters) -> np.ndarray: for entry in entries: fitted = result.params[entry.name] entry.value = float(fitted.value) + if _parameter_value_expression(entry) is not None: + entry.vary = False + continue entry.vary = bool(fitted.vary) entry.minimum = float(fitted.min) entry.maximum = float(fitted.max) evaluation = self.evaluate(entries) + if ( + evaluation.residuals is None + or evaluation.experimental_intensities is None + ): + raise ValueError( + "Prefit fit statistics are unavailable without experimental SAXS data." + ) chi_square = float(np.sum(evaluation.residuals**2)) dof = max( len(evaluation.q_values) @@ -472,6 +817,12 @@ def save_fit( "min": entry.minimum, "max": entry.maximum, } + expression = _parameter_value_expression(entry) + if expression is not None: + meta["expression"] = expression + initial_expression = _parameter_initial_value_expression(entry) + if initial_expression is not None: + meta["initial_expression"] = initial_expression if entry.category == "weight": weights_payload.append( { @@ -522,25 +873,38 @@ def save_fit( ) curve_path = self.paths.prefit_dir / "latest_prefit_curve.txt" - curve_data = np.column_stack( - [ - evaluation.q_values, - evaluation.experimental_intensities, - evaluation.model_intensities, - evaluation.residuals, - ] - ) + if ( + evaluation.experimental_intensities is not None + and evaluation.residuals is not None + ): + curve_data = np.column_stack( + [ + evaluation.q_values, + evaluation.experimental_intensities, + evaluation.model_intensities, + evaluation.residuals, + ] + ) + curve_header = "q experimental_intensity model_intensity residual" + else: + curve_data = np.column_stack( + [ + evaluation.q_values, + evaluation.model_intensities, + ] + ) + curve_header = "q model_intensity" np.savetxt( curve_path, curve_data, - header="q experimental_intensity model_intensity residual", + header=curve_header, comments="", ) snapshot_curve_path = snapshot_dir / "prefit_curve.txt" np.savetxt( snapshot_curve_path, curve_data, - header="q experimental_intensity model_intensity residual", + header=curve_header, comments="", ) @@ -582,16 +946,21 @@ def load_template_reset_entries(self) -> list[PrefitParameterEntry]: self.settings.template_reset_parameter_entries, ) if stored_entries is not None: - return stored_entries - return self._copy_entries(self._template_default_entries) + return self._apply_parameter_constraints(stored_entries) + return self._apply_parameter_constraints( + self._copy_entries(self._template_default_entries) + ) def load_best_prefit_entries( self, ) -> list[PrefitParameterEntry] | None: - return self._entries_from_project_payload( + entries = self._entries_from_project_payload( self.settings.best_prefit_template, self.settings.best_prefit_parameter_entries, ) + if entries is None: + return None + return self._apply_parameter_constraints(entries) def save_best_prefit_entries( self, @@ -650,7 +1019,12 @@ def load_saved_state(self, state_name: str) -> PrefitSavedState: path=state_dir, saved_at=str(payload.get("saved_at", state_dir.name)), template_name=str(payload.get("template_name", "")).strip(), - parameter_entries=parameter_entries, + parameter_entries=self._apply_parameter_constraints( + parameter_entries, + default_entries=self._build_default_parameter_entries( + cluster_geometry_table=cluster_geometry_table, + ), + ), cluster_geometry_table=cluster_geometry_table, method=_optional_str(run_settings.get("method")), max_nfev=_optional_int(run_settings.get("max_nfev")), @@ -667,6 +1041,14 @@ def recommend_scale_settings( *, span_factor: float = 10.0, ) -> PrefitScaleRecommendation: + if self.settings.model_only_mode: + raise ValueError( + "Scale recommendations are unavailable in Model Only Mode." + ) + if self.experimental_data is None: + raise ValueError( + "Scale recommendations require experimental SAXS data." + ) entries = parameter_entries or self.parameter_entries scale_entry = next( (entry for entry in entries if entry.name == "scale"), @@ -676,27 +1058,57 @@ def recommend_scale_settings( raise ValueError( "The current SAXS template does not define a scale parameter." ) - evaluation = self.evaluate(entries) + if _parameter_value_expression(scale_entry) is not None: + raise ValueError( + "Scale recommendations are unavailable when scale is linked " + "to another parameter expression." + ) + offset_entry = next( + (entry for entry in entries if entry.name == "offset"), + None, + ) + if ( + offset_entry is not None + and _parameter_value_expression(offset_entry) is not None + ): + raise ValueError( + "Scale recommendations are unavailable when offset is linked " + "to another parameter expression." + ) + current_offset = ( + float(offset_entry.value) if offset_entry is not None else None + ) + recommended_entries = self._copy_entries(entries) + for entry in recommended_entries: + if entry.name == "scale": + entry.value = 1.0 + elif entry.name == "offset": + entry.value = 0.0 + evaluation = self.evaluate(recommended_entries) offset_value = next( ( float(entry.value) - for entry in entries + for entry in recommended_entries if entry.name == "offset" ), 0.0, ) + solvent_contribution = ( + np.asarray(evaluation.solvent_contribution, dtype=float) + if evaluation.solvent_contribution is not None + else np.zeros_like(evaluation.model_intensities, dtype=float) + ) target = np.asarray( - evaluation.experimental_intensities - offset_value, + evaluation.experimental_intensities + - offset_value + - solvent_contribution, dtype=float, ) model = np.asarray( - evaluation.model_intensities - offset_value, + evaluation.model_intensities - offset_value - solvent_contribution, dtype=float, ) - mask = np.isfinite(target) & np.isfinite(model) & (np.abs(model) > 0.0) - positive_mask = mask & (target > 0.0) & (model > 0.0) - if np.count_nonzero(positive_mask) >= 3: - mask = positive_mask + mask = np.isfinite(target) & np.isfinite(model) if not np.any(mask): raise ValueError( "A scale recommendation is not available because the current " @@ -705,49 +1117,112 @@ def recommend_scale_settings( ) masked_target = np.asarray(target[mask], dtype=float) masked_model = np.asarray(model[mask], dtype=float) - centered_target = masked_target - float(np.nanmin(masked_target)) - centered_model = masked_model - float(np.nanmin(masked_model)) - numerator = float(np.dot(centered_model, centered_target)) - denominator = float(np.dot(centered_model, centered_model)) - adjustment_factor = ( - numerator / denominator if denominator > 0.0 else float("nan") - ) - if not np.isfinite(adjustment_factor) or adjustment_factor <= 0.0: - adjustment_factor = float( - np.median( - np.abs(centered_target) - / np.maximum(np.abs(centered_model), 1e-30) - ) + if np.count_nonzero(np.isfinite(masked_model)) < 2: + raise ValueError( + "A scale recommendation requires at least two finite SAXS " + "model points." ) - if not np.isfinite(adjustment_factor) or adjustment_factor <= 0.0: - target_span = float( - np.nanmax(masked_target) - np.nanmin(masked_target) + centered_model = masked_model - float(np.nanmean(masked_model)) + centered_target = masked_target - float(np.nanmean(masked_target)) + recommended_offset: float | None = None + if offset_entry is not None: + design_matrix = np.column_stack( + [ + masked_model, + np.ones_like(masked_model, dtype=float), + ] + ) + coefficients, *_ = np.linalg.lstsq( + design_matrix, + masked_target, + rcond=None, ) - model_span = float( - np.nanmax(masked_model) - np.nanmin(masked_model) + recommended_scale = float(coefficients[0]) + recommended_offset = float(coefficients[1]) + else: + numerator = float(np.dot(masked_model, masked_target)) + denominator = float(np.dot(masked_model, masked_model)) + recommended_scale = ( + numerator / denominator if denominator > 0.0 else float("nan") + ) + if not np.isfinite(recommended_scale) or recommended_scale <= 0.0: + numerator = float(np.dot(centered_model, centered_target)) + denominator = float(np.dot(centered_model, centered_model)) + adjustment_factor = ( + numerator / denominator if denominator > 0.0 else float("nan") ) - if model_span > 0.0 and target_span > 0.0: - adjustment_factor = target_span / model_span - if not np.isfinite(adjustment_factor) or adjustment_factor <= 0.0: + if not np.isfinite(adjustment_factor) or adjustment_factor <= 0.0: + adjustment_factor = float( + np.median( + np.abs(masked_target) + / np.maximum(np.abs(masked_model), 1e-30) + ) + ) + if not np.isfinite(adjustment_factor) or adjustment_factor <= 0.0: + target_span = float( + np.nanmax(masked_target) - np.nanmin(masked_target) + ) + model_span = float( + np.nanmax(masked_model) - np.nanmin(masked_model) + ) + if model_span > 0.0 and target_span > 0.0: + adjustment_factor = target_span / model_span + recommended_scale = float(adjustment_factor) + if offset_entry is not None: + recommended_offset = float( + np.nanmean( + masked_target - recommended_scale * masked_model + ) + ) + if not np.isfinite(recommended_scale) or recommended_scale <= 0.0: raise ValueError( "A positive scale recommendation could not be estimated from " "the current model and experimental traces." ) - current_scale = float(scale_entry.value) + raw_current_scale = float(scale_entry.value) + current_scale = raw_current_scale if current_scale <= 0.0: current_scale = max(float(scale_entry.maximum), 1.0) / span_factor - recommended_scale = max(current_scale * adjustment_factor, 1e-12) + recommended_scale = max(float(recommended_scale), 1e-12) + adjustment_factor = ( + recommended_scale / current_scale + if current_scale > 0.0 + else float("nan") + ) recommended_minimum = max(recommended_scale / span_factor, 1e-12) recommended_maximum = max( recommended_scale * span_factor, recommended_scale * 1.5, float(scale_entry.maximum), ) + recommended_offset_minimum: float | None = None + recommended_offset_maximum: float | None = None + if offset_entry is not None and recommended_offset is not None: + target_span = float( + np.nanmax(masked_target) - np.nanmin(masked_target) + ) + offset_padding = max( + target_span / span_factor, + abs(float(recommended_offset)) / span_factor, + 1e-12, + ) + recommended_offset_minimum = min( + float(offset_entry.minimum), + float(recommended_offset) - offset_padding, + ) + recommended_offset_maximum = max( + float(offset_entry.maximum), + float(recommended_offset) + offset_padding, + ) return PrefitScaleRecommendation( - current_scale=float(scale_entry.value), + current_scale=raw_current_scale, recommended_scale=recommended_scale, recommended_minimum=min(recommended_minimum, recommended_scale), recommended_maximum=max(recommended_maximum, recommended_scale), + current_offset=current_offset, + recommended_offset=recommended_offset, + recommended_offset_minimum=recommended_offset_minimum, + recommended_offset_maximum=recommended_offset_maximum, adjustment_factor=adjustment_factor, points_used=int(np.count_nonzero(mask)), ) @@ -769,6 +1244,17 @@ def volume_fraction_estimator_target(self) -> tuple[str, str] | None: def supports_volume_fraction_estimator(self) -> bool: return self.volume_fraction_estimator_target() is not None + def solvent_weight_estimator_target(self) -> str | None: + parameter_names = { + str(parameter.name).strip() + for parameter in self.template_spec.parameters + if str(parameter.name).strip() + } + for candidate in SOLVENT_WEIGHT_PARAMETER_NAMES: + if candidate in parameter_names: + return candidate + return None + def supports_cluster_geometry_metadata(self) -> bool: return bool(self.template_spec.cluster_geometry_support.supported) @@ -1213,9 +1699,24 @@ def _merge_parameter_entries( minimum=float(existing_entry.minimum), maximum=float(existing_entry.maximum), category=default_entry.category, + value_expression=_parameter_value_expression( + existing_entry + ), + initial_value_expression=( + _parameter_initial_value_expression(existing_entry) + ), ) ) - return merged_entries + defaults_by_name = { + entry.name: entry for entry in default_entries if entry.name + } + return [ + SAXSPrefitWorkflow._apply_parameter_entry_constraints( + entry, + defaults_by_name.get(entry.name), + ) + for entry in merged_entries + ] def _ensure_project_parameter_presets(self) -> None: dirty = False @@ -1255,7 +1756,51 @@ def _entries_from_project_payload( entries = [PrefitParameterEntry.from_dict(entry) for entry in payload] if not self._has_matching_entry_signature(entries): return None - return entries + return self._apply_parameter_constraints(entries) + + def _apply_parameter_constraints( + self, + entries: list[PrefitParameterEntry], + *, + default_entries: list[PrefitParameterEntry] | None = None, + ) -> list[PrefitParameterEntry]: + defaults = default_entries or self._template_default_entries + defaults_by_name = { + entry.name: entry for entry in defaults if entry.name + } + return [ + self._apply_parameter_entry_constraints( + entry, + defaults_by_name.get(entry.name), + ) + for entry in entries + ] + + @staticmethod + def _apply_parameter_entry_constraints( + entry: PrefitParameterEntry, + default_entry: PrefitParameterEntry | None = None, + ) -> PrefitParameterEntry: + constrained = PrefitParameterEntry.from_dict(entry.to_dict()) + if constrained.name not in SOLVENT_WEIGHT_PARAMETER_NAMES: + return constrained + minimum = 0.0 + maximum = 1.0 + if default_entry is not None: + minimum = max(minimum, float(default_entry.minimum)) + maximum = min(maximum, float(default_entry.maximum)) + bounded_minimum = max(float(constrained.minimum), minimum) + bounded_maximum = min(float(constrained.maximum), maximum) + if bounded_maximum < bounded_minimum: + bounded_minimum = minimum + bounded_maximum = maximum + constrained.minimum = bounded_minimum + constrained.maximum = bounded_maximum + constrained.value = min( + max(float(constrained.value), constrained.minimum), + constrained.maximum, + ) + return constrained def _has_matching_entry_signature( self, @@ -1336,7 +1881,17 @@ def _load_components(self) -> list[PrefitComponent]: ) return components + def _load_experimental_trace(self): + if self.settings.model_only_mode: + return None + try: + return self.project_manager.load_experimental_data(self.settings) + except Exception: + return None + def _load_solvent_trace(self) -> np.ndarray | None: + if self.settings.model_only_mode: + return None q_values = self._component_q_values_from_candidates() solvent_summary = self.project_manager.load_solvent_data(self.settings) if solvent_summary is not None: @@ -1363,7 +1918,27 @@ def _component_q_values(self) -> np.ndarray: def _supported_component_q_range(self) -> tuple[float, float]: supported = load_built_component_q_range(self.paths.project_dir) if supported is None: - q_values = np.asarray(self.experimental_data.q_values, dtype=float) + components = getattr(self, "components", []) + if self.experimental_data is not None: + q_values = np.asarray( + self.experimental_data.q_values, dtype=float + ) + elif components: + q_values = np.asarray(components[0].q_values, dtype=float) + else: + source = np.asarray( + [ + value + for value in (self.settings.q_min, self.settings.q_max) + if value is not None + ], + dtype=float, + ) + if source.size == 0: + raise ValueError( + "No q-range is available for the current SAXS project." + ) + q_values = source return (float(np.min(q_values)), float(np.max(q_values))) return supported @@ -1431,10 +2006,16 @@ def _component_q_values_from_candidates( self.paths.scattering_components_dir.glob("*.txt") ) if not component_files: - source_q_values = np.asarray( - self.experimental_data.q_values, - dtype=float, - ) + if self.experimental_data is not None: + source_q_values = np.asarray( + self.experimental_data.q_values, + dtype=float, + ) + else: + raise ValueError( + "No SAXS component q-grid is available yet. Build the " + "SAXS components in Project Setup before previewing the model." + ) else: raw_data = np.loadtxt(component_files[0], comments="#") source_q_values = np.asarray(raw_data[:, 0], dtype=float) @@ -1606,6 +2187,39 @@ def _evaluate_solvent_contribution( ) return np.asarray(contribution, dtype=float) + def _evaluate_structure_factor_trace( + self, + q_values: np.ndarray, + *, + solvent_data: np.ndarray, + model_data: list[np.ndarray], + params: dict[str, float], + extra_inputs: list[np.ndarray], + ) -> np.ndarray | None: + structure_factor_function = getattr( + self.template_module, + "structure_factor_profile", + None, + ) + if structure_factor_function is None: + return None + try: + structure_factor = structure_factor_function( + q_values, + np.asarray(solvent_data, dtype=float), + model_data, + *extra_inputs, + **params, + ) + except Exception: + return None + structure_factor_array = np.asarray(structure_factor, dtype=float) + if structure_factor_array.shape != np.asarray(q_values).shape: + return None + if not np.all(np.isfinite(structure_factor_array)): + return None + return structure_factor_array + def _lmfit_extra_inputs(self) -> list[np.ndarray]: runtime_inputs = self.template_runtime_inputs() return [ @@ -1801,6 +2415,17 @@ def _build_report_text( f" q points: {len(evaluation.q_values)}", ] ) + if fit_result is None and ( + evaluation.experimental_intensities is None + or evaluation.residuals is None + ): + lines.extend( + [ + " mode: model_only", + " experimental_data: unavailable", + " fit_metrics: unavailable", + ] + ) if fit_result is not None: lines.extend( [ diff --git a/src/saxshell/saxs/project_manager/__init__.py b/src/saxshell/saxs/project_manager/__init__.py index e92d106..12799a1 100644 --- a/src/saxshell/saxs/project_manager/__init__.py +++ b/src/saxshell/saxs/project_manager/__init__.py @@ -10,6 +10,7 @@ ClusterImportResult, DreamBestFitSelection, ExperimentalDataSummary, + PowerPointExportSettings, ProjectBuildResult, ProjectComponentEntry, ProjectPaths, @@ -27,6 +28,7 @@ "ExperimentalDataSummary", "ClusterImportResult", "DreamBestFitSelection", + "PowerPointExportSettings", "ProjectBuildResult", "ProjectComponentEntry", "ProjectPaths", diff --git a/src/saxshell/saxs/project_manager/prior_plot.py b/src/saxshell/saxs/project_manager/prior_plot.py index 1b9a1e8..4d5a8de 100644 --- a/src/saxshell/saxs/project_manager/prior_plot.py +++ b/src/saxshell/saxs/project_manager/prior_plot.py @@ -24,12 +24,16 @@ def _natural_sort_key(value: str) -> list[object]: ] -def _load_prior_payload(json_path: str | Path) -> dict[str, object]: +def _load_prior_payload( + json_path: str | Path | dict[str, object], +) -> dict[str, object]: + if isinstance(json_path, dict): + return dict(json_path) return json.loads(Path(json_path).read_text(encoding="utf-8")) def export_prior_plot_data( - json_path: str | Path, + json_path: str | Path | dict[str, object], output_path: str | Path, *, mode: str = "structure_fraction", @@ -48,7 +52,7 @@ def export_prior_plot_data( ) for motif in sorted(motifs, key=_natural_sort_key): motif_payload = motifs[motif] - count = int(motif_payload.get("count", 0)) + count = float(motif_payload.get("count", 0.0) or 0.0) structure_fraction = count / total_files if total_files else 0.0 atom_fraction = ( (count * atom_count) @@ -60,7 +64,7 @@ def export_prior_plot_data( ), 1, ) - * int(other_payload.get("count", 0)) + * float(other_payload.get("count", 0.0) or 0.0) for other_label, motif_dict in structures.items() for other_payload in motif_dict.values() ) @@ -102,7 +106,7 @@ def export_prior_plot_data( def build_prior_histogram_export_payload( - json_path: str | Path, + json_path: str | Path | dict[str, object], *, mode: str = "structure_fraction", value_mode: str = "percent", @@ -132,7 +136,7 @@ def build_prior_histogram_export_payload( total_files = float(payload.get("total_files", 0) or 0.0) atom_weight_total = sum( max(sum(int(token) for token in re.findall(r"(\d+)", label)), 1) - * int(motif_payload.get("count", 0)) + * float(motif_payload.get("count", 0.0) or 0.0) for label, motif_dict in structures.items() for motif_payload in motif_dict.values() ) @@ -171,8 +175,9 @@ def build_prior_histogram_export_payload( int(segment), ) else: - count = int( - structures[label].get(str(segment), {}).get("count", 0) + count = float( + structures[label].get(str(segment), {}).get("count", 0.0) + or 0.0 ) if is_atom_fraction: base_value = float(count * atom_count) @@ -200,7 +205,7 @@ def build_prior_histogram_export_payload( def export_prior_histogram_table( - json_path: str | Path, + json_path: str | Path | dict[str, object], output_path: str | Path, *, mode: str = "structure_fraction", @@ -256,7 +261,7 @@ def export_prior_histogram_table( def export_prior_histogram_npy( - json_path: str | Path, + json_path: str | Path | dict[str, object], output_path: str | Path, *, mode: str = "structure_fraction", @@ -276,7 +281,7 @@ def export_prior_histogram_npy( def plot_md_prior_histogram( - json_path: str | Path, + json_path: str | Path | dict[str, object], *, mode: str = "structure_fraction", secondary_element: str | None = None, @@ -399,7 +404,7 @@ def plot_md_prior_histogram( def list_secondary_filter_elements( - json_path: str | Path, + json_path: str | Path | dict[str, object], ) -> list[str]: payload = _load_prior_payload(json_path) return _payload_secondary_filter_elements(payload) @@ -456,20 +461,20 @@ def _secondary_segment_count( motif_payloads: dict[str, object], secondary_element: str | None, segment_value: int, -) -> int: +) -> float: if secondary_element is None: - return 0 + return 0.0 count_key = str(int(segment_value)) - total = 0 + total = 0.0 for motif_payload in motif_payloads.values(): secondary_distributions = motif_payload.get( "secondary_atom_distributions", {}, ) - total += int( + total += float( secondary_distributions.get(secondary_element, {}).get( count_key, - 0, + 0.0, ) ) return total diff --git a/src/saxshell/saxs/project_manager/project.py b/src/saxshell/saxs/project_manager/project.py index 63e3104..9251b30 100644 --- a/src/saxshell/saxs/project_manager/project.py +++ b/src/saxshell/saxs/project_manager/project.py @@ -200,10 +200,149 @@ def resolved_run_dir(self, project_dir: str | Path) -> Path: ).resolve() +@dataclass(slots=True) +class PowerPointExportSettings: + font_family: str = "Arial" + component_color_map: str = "viridis" + prior_histogram_color_map: str = "viridis" + solvent_sort_histogram_color_map: str = "summer" + text_color: str = "#1f2933" + experimental_trace_color: str = "#111827" + model_trace_color: str = "#86d549" + residual_trace_color: str = "#375a8c" + solvent_trace_color: str = "#20a386" + structure_factor_color: str = "#e5e419" + table_header_fill: str = "#E5E7EB" + table_even_row_fill: str = "#FFFFFF" + table_odd_row_fill: str = "#F3F4F6" + table_rule_color: str = "#4B5563" + include_prior_histograms: bool = True + include_initial_traces: bool = True + include_prefit_model: bool = True + include_prefit_parameters: bool = True + include_geometry_table: bool = True + include_estimator_metrics: bool = True + include_dream_settings: bool = True + include_dream_prior_table: bool = True + include_dream_output_model: bool = True + include_posterior_comparisons: bool = True + include_output_summary: bool = True + include_directory_summary: bool = True + generate_manifest: bool = True + export_figure_assets: bool = True + + def to_dict(self) -> dict[str, object]: + return asdict(self) + + @classmethod + def from_dict(cls, payload: object) -> "PowerPointExportSettings": + if not isinstance(payload, dict): + return cls() + return cls( + font_family=_normalized_nonempty_text( + payload.get("font_family"), + default="Arial", + ), + component_color_map=_normalized_nonempty_text( + payload.get("component_color_map"), + default="viridis", + ), + prior_histogram_color_map=_normalized_nonempty_text( + payload.get("prior_histogram_color_map"), + default="viridis", + ), + solvent_sort_histogram_color_map=_normalized_nonempty_text( + payload.get("solvent_sort_histogram_color_map"), + default="summer", + ), + text_color=_normalized_hex_color( + payload.get("text_color"), + default="#1f2933", + ), + experimental_trace_color=_normalized_hex_color( + payload.get("experimental_trace_color"), + default="#111827", + ), + model_trace_color=_normalized_hex_color( + payload.get("model_trace_color"), + default="#86d549", + ), + residual_trace_color=_normalized_hex_color( + payload.get("residual_trace_color"), + default="#375a8c", + ), + solvent_trace_color=_normalized_hex_color( + payload.get("solvent_trace_color"), + default="#20a386", + ), + structure_factor_color=_normalized_hex_color( + payload.get("structure_factor_color"), + default="#e5e419", + ), + table_header_fill=_normalized_hex_color( + payload.get("table_header_fill"), + default="#E5E7EB", + ), + table_even_row_fill=_normalized_hex_color( + payload.get("table_even_row_fill"), + default="#FFFFFF", + ), + table_odd_row_fill=_normalized_hex_color( + payload.get("table_odd_row_fill"), + default="#F3F4F6", + ), + table_rule_color=_normalized_hex_color( + payload.get("table_rule_color"), + default="#4B5563", + ), + include_prior_histograms=bool( + payload.get("include_prior_histograms", True) + ), + include_initial_traces=bool( + payload.get("include_initial_traces", True) + ), + include_prefit_model=bool( + payload.get("include_prefit_model", True) + ), + include_prefit_parameters=bool( + payload.get("include_prefit_parameters", True) + ), + include_geometry_table=bool( + payload.get("include_geometry_table", True) + ), + include_estimator_metrics=bool( + payload.get("include_estimator_metrics", True) + ), + include_dream_settings=bool( + payload.get("include_dream_settings", True) + ), + include_dream_prior_table=bool( + payload.get("include_dream_prior_table", True) + ), + include_dream_output_model=bool( + payload.get("include_dream_output_model", True) + ), + include_posterior_comparisons=bool( + payload.get("include_posterior_comparisons", True) + ), + include_output_summary=bool( + payload.get("include_output_summary", True) + ), + include_directory_summary=bool( + payload.get("include_directory_summary", True) + ), + generate_manifest=bool(payload.get("generate_manifest", True)), + export_figure_assets=bool( + payload.get("export_figure_assets", True) + ), + ) + + @dataclass(slots=True) class ProjectSettings: project_name: str project_dir: str + model_only_mode: bool = False clusters_dir: str | None = None experimental_data_path: str | None = None copied_experimental_data_file: str | None = None @@ -248,6 +387,9 @@ class ProjectSettings: ) selected_model_template: str | None = None autosave_prefits: bool = False + powerpoint_export_settings: PowerPointExportSettings = field( + default_factory=PowerPointExportSettings + ) @property def resolved_project_dir(self) -> Path: @@ -322,6 +464,9 @@ def to_dict(self) -> dict[str, object]: payload["dream_favorite_history"] = [ entry.to_dict() for entry in self.dream_favorite_history ] + payload["powerpoint_export_settings"] = ( + self.powerpoint_export_settings.to_dict() + ) return payload @classmethod @@ -329,6 +474,7 @@ def from_dict(cls, payload: dict[str, object]) -> "ProjectSettings": return cls( project_name=str(payload.get("project_name", "SAXS Project")), project_dir=str(payload.get("project_dir", "")), + model_only_mode=bool(payload.get("model_only_mode", False)), clusters_dir=_optional_str(payload.get("clusters_dir")), experimental_data_path=_optional_str( payload.get("experimental_data_path") @@ -431,6 +577,9 @@ def from_dict(cls, payload: dict[str, object]) -> "ProjectSettings": payload.get("selected_model_template") ), autosave_prefits=bool(payload.get("autosave_prefits", False)), + powerpoint_export_settings=PowerPointExportSettings.from_dict( + payload.get("powerpoint_export_settings", {}) + ), ) @@ -453,6 +602,20 @@ def _optional_int(value: object) -> int | None: return int(value) +def _normalized_nonempty_text(value: object, *, default: str) -> str: + text = _optional_str(value) + return text if text is not None else default + + +def _normalized_hex_color(value: object, *, default: str) -> str: + text = _optional_str(value) + if text is None: + return default + if re.fullmatch(r"#[0-9a-fA-F]{6}", text): + return text.upper() + return default + + def _normalized_elements(values: object) -> list[str]: if isinstance(values, str): raw_values = [ @@ -718,9 +881,12 @@ def build_scattering_components( ) -> ProjectBuildResult: paths = build_project_paths(settings.project_dir) self.ensure_project_dirs(paths) - staged_data_path = self.stage_experimental_data(settings) - self.stage_solvent_data(settings) - experimental_data = self.load_experimental_data(settings) + staged_data_path: Path | None = None + experimental_data: ExperimentalDataSummary | None = None + if not settings.model_only_mode: + staged_data_path = self.stage_experimental_data(settings) + self.stage_solvent_data(settings) + experimental_data = self.load_experimental_data(settings) q_values = self._build_q_grid(settings, experimental_data) builder = DebyeProfileBuilder( q_values=q_values, @@ -776,9 +942,12 @@ def generate_prior_weights( ) -> ProjectBuildResult: paths = build_project_paths(settings.project_dir) self.ensure_project_dirs(paths) - staged_data_path = self.stage_experimental_data(settings) - self.stage_solvent_data(settings) - experimental_data = self.load_experimental_data(settings) + staged_data_path: Path | None = None + experimental_data: ExperimentalDataSummary | None = None + if not settings.model_only_mode: + staged_data_path = self.stage_experimental_data(settings) + self.stage_solvent_data(settings) + experimental_data = self.load_experimental_data(settings) q_values = self._build_q_grid(settings, experimental_data) clusters_dir = settings.resolved_clusters_dir if clusters_dir is None: @@ -893,8 +1062,11 @@ def _resolve_solvent_source( def _build_q_grid( self, settings: ProjectSettings, - experimental_data: ExperimentalDataSummary, + experimental_data: ExperimentalDataSummary | None, ) -> np.ndarray: + if experimental_data is None: + return self._build_model_only_q_grid(settings) + q_values = experimental_data.q_values q_min = ( settings.q_min @@ -930,6 +1102,43 @@ def _build_q_grid( ) return filtered_q + def _build_model_only_q_grid( + self, + settings: ProjectSettings, + ) -> np.ndarray: + supported_range = load_built_component_q_range(settings.project_dir) + q_min = ( + float(settings.q_min) + if settings.q_min is not None + else ( + float(supported_range[0]) + if supported_range is not None + else None + ) + ) + q_max = ( + float(settings.q_max) + if settings.q_max is not None + else ( + float(supported_range[1]) + if supported_range is not None + else None + ) + ) + if q_min is None or q_max is None: + raise ValueError( + "Model Only Mode requires q min and q max before SAXS " + "components or prior weights can be generated." + ) + if q_min > q_max: + raise ValueError("q min must be less than or equal to q max.") + q_points = ( + int(settings.q_points) + if settings.q_points is not None and settings.q_points > 1 + else 500 + ) + return np.linspace(q_min, q_max, q_points) + def _component_entries_from_clusters( self, clusters_dir: Path, diff --git a/src/saxshell/saxs/solute_volume_fraction.py b/src/saxshell/saxs/solute_volume_fraction.py index 62b7a2a..52b7e2d 100644 --- a/src/saxshell/saxs/solute_volume_fraction.py +++ b/src/saxshell/saxs/solute_volume_fraction.py @@ -53,7 +53,7 @@ class SoluteVolumeFractionEstimate: def summary_text(self) -> str: lines = [ - "Volume fraction estimate", + "Physical solute-associated volume fraction estimate", f"Mode: {self.solution_result.mode}", ( "Solution volume from measured density: " @@ -149,19 +149,28 @@ def summary_text(self) -> str: lines.extend( [ ( - "Estimated solute volume fraction: " + "Physical solute-associated volume fraction: " f"{_format_fraction(self.solute_volume_fraction)}" ), + ( + "Physical solvent-associated volume fraction: " + f"{_format_fraction(self.solvent_volume_fraction)}" + ), "", "Interpretation:", - "This estimate follows the SAXS-style concentration x " - "specific-volume picture for solute occupancy in the measured " - "solution volume:", + "This is the physical occupancy estimate from the bulk " + "density/composition model. SAXSShell keeps reporting it for " + "reference, but the poly-LMA model-facing phi_solute / " + "phi_solvent defaults now come from the SAXS-effective " + "contrast-weighted interaction ratio when that estimate is " + "available.", + "", + "Bulk-density relation:", ( - "phi_solute ~= c_solute * vbar_solute " + "phi_phys ~= c_solute * vbar_solute " "= (m_solute / V_solution) * (1 / rho_solute)." if self.calculation_method == "solute_density" - else "phi_solute ~= V_solute / V_solution, with " + else "phi_phys ~= V_solute / V_solution, with " "V_solute ~= V_solution - (m_solvent / rho_solvent)." ), ] diff --git a/src/saxshell/saxs/solution_scattering_estimator.py b/src/saxshell/saxs/solution_scattering_estimator.py new file mode 100644 index 0000000..f5e7d32 --- /dev/null +++ b/src/saxshell/saxs/solution_scattering_estimator.py @@ -0,0 +1,1270 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +import numpy as np +import xraydb + +from saxshell.fullrmc.solution_properties import ( + SolutionPropertiesResult, + SolutionPropertiesSettings, + calculate_solution_properties, +) +from saxshell.saxs.solute_volume_fraction import ( + SoluteVolumeFractionEstimate, + SoluteVolumeFractionSettings, + calculate_solute_volume_fraction_estimate, +) + +DEFAULT_INCIDENT_ENERGY_KEV = 17.0 +DEFAULT_CAPILLARY_SIZE_MM = 1.0 +DEFAULT_BEAM_FOOTPRINT_WIDTH_MM = 0.4 +DEFAULT_BEAM_FOOTPRINT_HEIGHT_MM = 0.4 +DEFAULT_CAPILLARY_GEOMETRY = "cylindrical" +DEFAULT_BEAM_PROFILE = "uniform" +PATH_SAMPLE_COUNT = 801 + +CAPILLARY_GEOMETRY_ITEMS = ( + ("Cylindrical", "cylindrical"), + ("Flat plate", "flat_plate"), +) +BEAM_PROFILE_ITEMS = (("Uniform", "uniform"),) +EDGE_LINE_FAMILIES = { + "K": ("Ka", "Kb"), + "L3": ("La", "Lb"), + "L2": ("Lb", "Lg"), + "L1": ("Lb", "Lg"), + "M5": ("Ma",), + "M4": ("Mb",), +} +MINIMUM_LINE_CONTRIBUTION = 1e-15 +ENERGY_WAVELENGTH_KEV_ANGSTROM = 12.398419843320026 +AVOGADRO_NUMBER = 6.02214076e23 + + +def _format_number(value: float, digits: int = 6) -> str: + return f"{float(value):.{digits}g}" + + +def _format_fraction(value: float) -> str: + return f"{float(value):.6f}" + + +def _format_scattering_density(value: float) -> str: + return f"{float(value):.6e}" + + +def wavelength_angstrom_from_energy_kev(energy_kev: float) -> float: + validated_energy = _validate_positive(energy_kev, "Incident energy") + return float(ENERGY_WAVELENGTH_KEV_ANGSTROM / validated_energy) + + +@dataclass(slots=True) +class BeamGeometrySettings: + incident_energy_kev: float = DEFAULT_INCIDENT_ENERGY_KEV + capillary_size_mm: float = DEFAULT_CAPILLARY_SIZE_MM + capillary_geometry: str = DEFAULT_CAPILLARY_GEOMETRY + beam_profile: str = DEFAULT_BEAM_PROFILE + beam_footprint_width_mm: float = DEFAULT_BEAM_FOOTPRINT_WIDTH_MM + beam_footprint_height_mm: float = DEFAULT_BEAM_FOOTPRINT_HEIGHT_MM + + +@dataclass(slots=True) +class SolutionScatteringEstimatorSettings: + solution: SolutionPropertiesSettings + solute_density_g_per_ml: float | None = 1.0 + solvent_density_g_per_ml: float | None = 1.0 + calculate_number_density: bool = True + calculate_solute_volume_fraction: bool = True + calculate_solvent_scattering_contribution: bool = True + calculate_sample_fluorescence_yield: bool = False + beam: BeamGeometrySettings = field(default_factory=BeamGeometrySettings) + + +@dataclass(slots=True) +class NumberDensityEstimate: + number_density_cm3: float + number_density_a3: float + total_atoms: float + element_ratio_string: str + + def summary_text(self) -> str: + lines = [ + "Number density estimate", + ( + "Atomic number density: " + f"{_format_number(self.number_density_a3)} atoms/A^3" + ), + ( + "Atomic number density: " + f"{_format_number(self.number_density_cm3)} atoms/cm^3" + ), + ( + "Total atoms in source solution: " + f"{_format_number(self.total_atoms)}" + ), + ] + if self.element_ratio_string: + lines.append( + "Integer ratio of elements: " f"{self.element_ratio_string}" + ) + return "\n".join(lines) + + +@dataclass(slots=True) +class SAXSInteractionEstimate: + incident_energy_kev: float + wavelength_angstrom: float + solute_formula: str + solvent_formula: str + solute_density_g_per_cm3: float + solvent_density_g_per_cm3: float + solute_effective_scattering_density_electrons_per_cm3: float + solvent_effective_scattering_density_electrons_per_cm3: float + contrast_scattering_density_electrons_per_cm3: float + physical_solute_associated_volume_cm3: float + physical_solvent_associated_volume_cm3: float + physical_solute_associated_volume_fraction: float + physical_solvent_associated_volume_fraction: float + contrast_weight_factor: float + effective_solute_interaction_volume_cm3: float + effective_solvent_background_volume_cm3: float + saxs_effective_solute_interaction_ratio: float + saxs_effective_solvent_background_ratio: float + + def summary_text(self) -> str: + lines = [ + "SAXS-effective interaction contrast estimate", + f"Incident energy: {self.incident_energy_kev:.6g} keV", + f"Wavelength: {self.wavelength_angstrom:.6g} A", + ( + "Solute effective scattering density: " + f"{_format_scattering_density(self.solute_effective_scattering_density_electrons_per_cm3)} " + "electrons/cm^3" + ), + ( + "Solvent effective scattering density: " + f"{_format_scattering_density(self.solvent_effective_scattering_density_electrons_per_cm3)} " + "electrons/cm^3" + ), + ( + "Energy-dependent solute-solvent contrast: " + f"{_format_scattering_density(self.contrast_scattering_density_electrons_per_cm3)} " + "electrons/cm^3" + ), + ( + "Physical solute-associated volume fraction: " + f"{_format_fraction(self.physical_solute_associated_volume_fraction)}" + ), + ( + "Physical solvent-associated volume fraction: " + f"{_format_fraction(self.physical_solvent_associated_volume_fraction)}" + ), + ( + "Contrast weight factor: " + f"{_format_number(self.contrast_weight_factor)}" + ), + ( + "Effective solute interaction volume: " + f"{self.effective_solute_interaction_volume_cm3:.6f} cm^3" + ), + ( + "SAXS-effective solute interaction ratio: " + f"{_format_fraction(self.saxs_effective_solute_interaction_ratio)}" + ), + ( + "SAXS-effective solvent background ratio: " + f"{_format_fraction(self.saxs_effective_solvent_background_ratio)}" + ), + "", + "Interpretation:", + "The physical fraction above is still reported for transparency, " + "but the model-facing phi_solute / phi_solvent default is based " + "on the contrast-weighted SAXS interaction ratio at the selected " + "energy.", + "SAXS-effective ratio model:", + ("rho_eff(E) = rho_mass * N_A / M * sum_i n_i [Z_i + f'_i(E)]"), + ( + "C(E) = ((rho_eff,solute(E) - rho_eff,solvent(E)) / " + "rho_eff,solvent(E))^2" + ), + "V_eff,solute(E) = C(E) * V_solute,phys", + ( + "R_saxs(E) = V_eff,solute(E) / " + "(V_eff,solute(E) + V_solvent,phys)" + ), + "Prefit uses R_saxs(E) as the default model solute fraction when " + "the active template exposes phi_solute / phi_solvent.", + ] + return "\n".join(lines) + + +@dataclass(slots=True) +class AttenuationEstimate: + incident_energy_kev: float + capillary_geometry: str + sample_average_path_length_mm: float + sample_linear_attenuation_inv_cm: float + solute_linear_attenuation_inv_cm: float + sample_solvent_linear_attenuation_inv_cm: float + neat_solvent_linear_attenuation_inv_cm: float + sample_transmission: float + solute_only_transmission: float + sample_solvent_only_transmission: float + neat_solvent_transmission: float + sample_scattering_weight: float + neat_solvent_scattering_weight: float + solvent_mass_concentration_g_per_cm3: float + neat_solvent_density_g_per_cm3: float + solvent_scattering_scale_factor: float + neat_solvent_to_sample_ratio: float | None + + def summary_text(self) -> str: + lines = [ + "Attenuation and solvent contribution estimate", + f"Incident energy: {self.incident_energy_kev:.6g} keV", + f"Capillary geometry: {self.capillary_geometry}", + ( + "Average illuminated path length: " + f"{self.sample_average_path_length_mm:.4f} mm" + ), + ( + "Sample total linear attenuation: " + f"{_format_number(self.sample_linear_attenuation_inv_cm)} 1/cm" + ), + ( + "Solute-only linear attenuation in sample: " + f"{_format_number(self.solute_linear_attenuation_inv_cm)} 1/cm" + ), + ( + "Solvent-only linear attenuation in sample: " + f"{_format_number(self.sample_solvent_linear_attenuation_inv_cm)} 1/cm" + ), + ( + "Neat-solvent linear attenuation: " + f"{_format_number(self.neat_solvent_linear_attenuation_inv_cm)} 1/cm" + ), + ( + "Sample total transmission: " + f"{_format_fraction(self.sample_transmission)}" + ), + ( + "Solute-only transmission in sample: " + f"{_format_fraction(self.solute_only_transmission)}" + ), + ( + "Solvent-only transmission in sample: " + f"{_format_fraction(self.sample_solvent_only_transmission)}" + ), + ( + "Neat-solvent transmission: " + f"{_format_fraction(self.neat_solvent_transmission)}" + ), + ( + "Sample solvent mass concentration: " + f"{_format_number(self.solvent_mass_concentration_g_per_cm3)} g/cm^3" + ), + ( + "Neat-solvent density: " + f"{_format_number(self.neat_solvent_density_g_per_cm3)} g/cm^3" + ), + ( + "Recommended solvent scattering scale factor: " + f"{_format_fraction(self.solvent_scattering_scale_factor)}" + ), + ] + if self.neat_solvent_to_sample_ratio is not None: + lines.append( + "Neat-solvent / sample-solvent intensity ratio: " + f"{_format_number(self.neat_solvent_to_sample_ratio)}" + ) + lines.extend( + [ + "", + "Interpretation:", + "The solvent scale factor is the estimated proportionality " + "needed to reduce the neat-solvent trace to the solvent " + "contribution present inside the measured sample.", + ] + ) + return "\n".join(lines) + + +@dataclass(slots=True) +class FluorescenceLineEstimate: + element: str + edge: str + family: str + line_energy_ev: float + primary_detected_yield: float + secondary_detected_yield: float + + @property + def total_detected_yield(self) -> float: + return float( + self.primary_detected_yield + self.secondary_detected_yield + ) + + +@dataclass(slots=True) +class FluorescenceEstimate: + incident_energy_kev: float + capillary_geometry: str + total_primary_detected_yield: float + total_secondary_detected_yield: float + line_estimates: list[FluorescenceLineEstimate] + + def summary_text(self) -> str: + lines = [ + "Fluorescence yield estimate", + f"Incident energy: {self.incident_energy_kev:.6g} keV", + f"Capillary geometry: {self.capillary_geometry}", + ( + "Detected primary fluorescence proxy: " + f"{_format_number(self.total_primary_detected_yield)}" + ), + ( + "Detected secondary fluorescence proxy: " + f"{_format_number(self.total_secondary_detected_yield)}" + ), + ( + "Detected total fluorescence proxy: " + f"{_format_number(self.total_primary_detected_yield + self.total_secondary_detected_yield)}" + ), + "", + "Strongest line families:", + ] + for estimate in sorted( + self.line_estimates, + key=lambda item: item.total_detected_yield, + reverse=True, + )[:12]: + lines.append( + f" {estimate.element} {estimate.family} " + f"({estimate.edge}, {estimate.line_energy_ev:.1f} eV): " + f"primary={_format_number(estimate.primary_detected_yield)}, " + f"secondary={_format_number(estimate.secondary_detected_yield)}, " + f"total={_format_number(estimate.total_detected_yield)}" + ) + lines.extend( + [ + "", + "Interpretation:", + "This is a first-order fluorescence-background proxy. The " + "primary term uses edge jump-ratio partitioning together " + "with Elam-style fluorescence yields and line branching. The " + "secondary term is a single re-absorption and re-emission " + "pass, not a full Monte Carlo transport model.", + ] + ) + return "\n".join(lines) + + +@dataclass(slots=True) +class SolutionScatteringEstimate: + settings: SolutionScatteringEstimatorSettings + solution_result: SolutionPropertiesResult + number_density_estimate: NumberDensityEstimate | None = None + volume_fraction_estimate: SoluteVolumeFractionEstimate | None = None + interaction_contrast_estimate: SAXSInteractionEstimate | None = None + attenuation_estimate: AttenuationEstimate | None = None + fluorescence_estimate: FluorescenceEstimate | None = None + + def summary_text(self) -> str: + sections = ["Solution scattering estimator"] + if self.number_density_estimate is not None: + sections.append(self.number_density_estimate.summary_text()) + if self.volume_fraction_estimate is not None: + sections.append(self.volume_fraction_estimate.summary_text()) + if self.interaction_contrast_estimate is not None: + sections.append(self.interaction_contrast_estimate.summary_text()) + if self.attenuation_estimate is not None: + sections.append(self.attenuation_estimate.summary_text()) + if self.interaction_contrast_estimate is not None: + sections.append( + "\n".join( + [ + "Model-facing solvent defaults", + ( + "Split-fraction templates " + "(phi_solute / phi_solvent + solvent_scale): " + f"phi ratio = {_format_fraction(self.interaction_contrast_estimate.saxs_effective_solute_interaction_ratio)}, " + f"solvent_scale = {_format_fraction(self.attenuation_estimate.solvent_scattering_scale_factor)}" + ), + ( + "Single-solvent-weight templates " + "(solv_w only): " + f"solvent multiplier = {_format_fraction(self.attenuation_estimate.solvent_scattering_scale_factor * self.interaction_contrast_estimate.saxs_effective_solvent_background_ratio)}" + ), + ] + ) + ) + if self.fluorescence_estimate is not None: + sections.append(self.fluorescence_estimate.summary_text()) + return "\n\n".join(section.strip() for section in sections if section) + + +def _build_number_density_estimate( + solution_result: SolutionPropertiesResult, +) -> NumberDensityEstimate: + return NumberDensityEstimate( + number_density_cm3=float(solution_result.number_density_cm3), + number_density_a3=float(solution_result.number_density_a3), + total_atoms=float(solution_result.total_atoms), + element_ratio_string=str(solution_result.element_ratio_string).strip(), + ) + + +def _validate_positive(value: float, label: str) -> float: + number = float(value) + if number <= 0.0: + raise ValueError(f"{label} must be greater than zero.") + return number + + +def _validate_supported_formula(formula: str, label: str) -> str: + text = str(formula or "").strip() + if not text: + raise ValueError(f"{label} formula is required.") + if not xraydb.validate_formula(text): + raise ValueError( + f"{label} formula {text!r} is not recognized by the X-ray " + "attenuation database. Enter an empirical formula such as H2O " + "or C3H7NO." + ) + return text + + +def _effective_scattering_density( + formula: str, + energy_ev: float, + density_g_per_cm3: float, +) -> float: + if density_g_per_cm3 <= 0.0: + raise ValueError("Density must be greater than zero.") + try: + composition = xraydb.chemparse(formula) + except Exception as exc: + raise ValueError( + f"Unable to parse formula {formula!r} for scattering-density " + "estimation." + ) from exc + if not composition: + raise ValueError( + f"Formula {formula!r} could not be parsed for scattering-density " + "estimation." + ) + formula_mass = 0.0 + effective_electrons = 0.0 + for element, abundance in composition.items(): + amount = float(abundance) + formula_mass += amount * float(xraydb.atomic_mass(element)) + effective_electrons += amount * float( + xraydb.atomic_number(element) + + xraydb.f1_chantler(element, energy_ev) + ) + if formula_mass <= 0.0: + raise ValueError( + f"Formula {formula!r} produced a non-positive molar mass." + ) + return float( + float(density_g_per_cm3) + * AVOGADRO_NUMBER + * effective_electrons + / formula_mass + ) + + +def _normalize_geometry_name(value: str) -> str: + normalized = str(value or DEFAULT_CAPILLARY_GEOMETRY).strip().lower() + if normalized not in {"cylindrical", "flat_plate"}: + raise ValueError( + "Capillary geometry must be 'cylindrical' or 'flat_plate'." + ) + return normalized + + +def _normalize_profile_name(value: str) -> str: + normalized = str(value or DEFAULT_BEAM_PROFILE).strip().lower() + if normalized != "uniform": + raise ValueError( + "Only a uniform beam profile is currently implemented." + ) + return normalized + + +def _path_lengths_cm(settings: BeamGeometrySettings) -> np.ndarray: + profile = _normalize_profile_name(settings.beam_profile) + del profile + geometry = _normalize_geometry_name(settings.capillary_geometry) + capillary_size_cm = ( + _validate_positive( + settings.capillary_size_mm, + "Capillary size", + ) + * 0.1 + ) + beam_width_cm = ( + _validate_positive( + settings.beam_footprint_width_mm, + "Beam footprint width", + ) + * 0.1 + ) + _validate_positive( + settings.beam_footprint_height_mm, + "Beam footprint height", + ) + if geometry == "flat_plate": + return np.full(PATH_SAMPLE_COUNT, capillary_size_cm, dtype=float) + + radius_cm = 0.5 * capillary_size_cm + x_values = np.linspace( + -0.5 * beam_width_cm, + 0.5 * beam_width_cm, + PATH_SAMPLE_COUNT, + dtype=float, + ) + path_lengths = np.zeros_like(x_values) + inside = np.abs(x_values) <= radius_cm + path_lengths[inside] = 2.0 * np.sqrt( + np.clip(radius_cm**2 - x_values[inside] ** 2, 0.0, None) + ) + return path_lengths + + +def _average_transmission( + path_lengths_cm: np.ndarray, mu_inv_cm: float +) -> float: + return float(np.mean(np.exp(-float(mu_inv_cm) * path_lengths_cm))) + + +def _average_weighted_path( + path_lengths_cm: np.ndarray, + mu_inv_cm: float, +) -> float: + mu_inv_cm = float(mu_inv_cm) + return float( + np.mean(path_lengths_cm * np.exp(-mu_inv_cm * path_lengths_cm)) + ) + + +def _average_source_integral( + path_lengths_cm: np.ndarray, + mu_inv_cm: float, +) -> float: + mu_inv_cm = float(mu_inv_cm) + if abs(mu_inv_cm) <= 1e-15: + return float(np.mean(path_lengths_cm)) + return float( + np.mean((1.0 - np.exp(-mu_inv_cm * path_lengths_cm)) / mu_inv_cm) + ) + + +def _average_detected_integral( + path_lengths_cm: np.ndarray, + mu_in_inv_cm: float, + mu_out_inv_cm: float, +) -> float: + mu_in_inv_cm = float(mu_in_inv_cm) + mu_out_inv_cm = float(mu_out_inv_cm) + if abs(mu_in_inv_cm - mu_out_inv_cm) <= 1e-15: + return float( + np.mean(path_lengths_cm * np.exp(-mu_in_inv_cm * path_lengths_cm)) + ) + return float( + np.mean( + ( + np.exp(-mu_out_inv_cm * path_lengths_cm) + - np.exp(-mu_in_inv_cm * path_lengths_cm) + ) + / (mu_in_inv_cm - mu_out_inv_cm) + ) + ) + + +def _average_half_path_escape_factor( + path_lengths_cm: np.ndarray, + mu_inv_cm: float, +) -> float: + return float(np.mean(np.exp(-0.5 * float(mu_inv_cm) * path_lengths_cm))) + + +def _average_half_path_absorption_fraction( + path_lengths_cm: np.ndarray, + mu_inv_cm: float, +) -> float: + return float( + np.mean(1.0 - np.exp(-0.5 * float(mu_inv_cm) * path_lengths_cm)) + ) + + +def _material_mu( + formula: str, + energy_ev: float, + density_g_per_cm3: float, + *, + kind: str = "total", +) -> float: + if density_g_per_cm3 <= 0.0: + return 0.0 + try: + return float( + xraydb.material_mu( + formula, + float(energy_ev), + density=float(density_g_per_cm3), + kind=kind, + ) + ) + except Warning as exc: + raise ValueError(str(exc)) from exc + + +def _material_element_coefficients( + formula: str, + energy_ev: float, + density_g_per_cm3: float, + *, + kind: str = "total", +) -> dict[str, float]: + if density_g_per_cm3 <= 0.0: + return {} + try: + payload = xraydb.material_mu_components( + formula, + float(energy_ev), + density=float(density_g_per_cm3), + kind=kind, + ) + except Warning as exc: + raise ValueError(str(exc)) from exc + coefficients: dict[str, float] = {} + for element in payload.get("elements", []): + value = payload.get(str(element)) + if not isinstance(value, tuple) or len(value) < 3: + continue + coefficients[str(element)] = float(value[2]) + return coefficients + + +def _merged_element_coefficients( + contributions: list[dict[str, float]], +) -> dict[str, float]: + merged: dict[str, float] = {} + for contribution in contributions: + for element, value in contribution.items(): + merged[element] = merged.get(element, 0.0) + float(value) + return merged + + +def _accessible_edge_shares( + element: str, + incident_energy_ev: float, +) -> dict[str, float]: + raw_shares: dict[str, float] = {} + for edge_name, edge_data in xraydb.xray_edges(element).items(): + if edge_name not in EDGE_LINE_FAMILIES: + continue + if float(edge_data.energy) >= float(incident_energy_ev): + continue + jump_ratio = float(edge_data.jump_ratio) + if jump_ratio <= 1.0: + continue + raw_shares[str(edge_name)] = (jump_ratio - 1.0) / jump_ratio + if not raw_shares: + return {} + total = sum(raw_shares.values()) + return { + edge_name: value / total + for edge_name, value in raw_shares.items() + if value > 0.0 + } + + +def _solution_component_concentrations( + solution_result: SolutionPropertiesResult, +) -> tuple[float, float]: + volume_solution_cm3 = float(solution_result.volume_solution_cm3) + if volume_solution_cm3 <= 0.0: + raise ValueError("Solution volume must be greater than zero.") + return ( + float(solution_result.mass_solute) / volume_solution_cm3, + float(solution_result.mass_solvent) / volume_solution_cm3, + ) + + +def _calculate_saxs_interaction_estimate( + settings: SolutionScatteringEstimatorSettings, + volume_fraction_estimate: SoluteVolumeFractionEstimate, +) -> SAXSInteractionEstimate: + incident_energy_kev = _validate_positive( + settings.beam.incident_energy_kev, + "Incident energy", + ) + incident_energy_ev = incident_energy_kev * 1000.0 + solute_formula = _validate_supported_formula( + settings.solution.solute_stoich, + "Solute", + ) + solvent_formula = _validate_supported_formula( + settings.solution.solvent_stoich, + "Solvent", + ) + if settings.solute_density_g_per_ml is None: + specific_volume = float( + volume_fraction_estimate.approximate_solute_specific_volume_cm3_per_g + ) + if specific_volume <= 0.0: + raise ValueError( + "The SAXS-effective interaction estimate needs a positive " + "solute specific volume or solute density." + ) + solute_density = 1.0 / specific_volume + else: + solute_density = _validate_positive( + settings.solute_density_g_per_ml, + "Solute density", + ) + if settings.solvent_density_g_per_ml is not None: + solvent_density = _validate_positive( + settings.solvent_density_g_per_ml, + "Solvent density", + ) + else: + solvent_volume = float( + volume_fraction_estimate.solution_result.volume_solution_cm3 + - volume_fraction_estimate.solute_volume_cm3 + ) + if solvent_volume <= 0.0: + raise ValueError( + "The SAXS-effective interaction estimate needs a positive " + "solvent-associated volume." + ) + solvent_density = ( + float(volume_fraction_estimate.solution_result.mass_solvent) + / solvent_volume + ) + solute_scattering_density = _effective_scattering_density( + solute_formula, + incident_energy_ev, + solute_density, + ) + solvent_scattering_density = _effective_scattering_density( + solvent_formula, + incident_energy_ev, + solvent_density, + ) + if abs(solvent_scattering_density) <= 1e-30: + raise ValueError( + "The solvent effective scattering density is too small to form a " + "contrast-weighted SAXS interaction ratio." + ) + contrast_scattering_density = ( + solute_scattering_density - solvent_scattering_density + ) + contrast_weight_factor = ( + contrast_scattering_density / solvent_scattering_density + ) ** 2 + physical_solute_volume_cm3 = float( + volume_fraction_estimate.solute_volume_cm3 + ) + physical_solvent_volume_cm3 = float( + volume_fraction_estimate.solution_result.volume_solution_cm3 + - volume_fraction_estimate.solute_volume_cm3 + ) + if physical_solvent_volume_cm3 < 0.0: + physical_solvent_volume_cm3 = 0.0 + effective_solute_interaction_volume_cm3 = ( + physical_solute_volume_cm3 * contrast_weight_factor + ) + effective_solvent_background_volume_cm3 = physical_solvent_volume_cm3 + total_effective_volume = ( + effective_solute_interaction_volume_cm3 + + effective_solvent_background_volume_cm3 + ) + if total_effective_volume <= 0.0: + raise ValueError( + "The contrast-weighted interaction volume is non-positive." + ) + return SAXSInteractionEstimate( + incident_energy_kev=incident_energy_kev, + wavelength_angstrom=wavelength_angstrom_from_energy_kev( + incident_energy_kev + ), + solute_formula=solute_formula, + solvent_formula=solvent_formula, + solute_density_g_per_cm3=float(solute_density), + solvent_density_g_per_cm3=float(solvent_density), + solute_effective_scattering_density_electrons_per_cm3=( + solute_scattering_density + ), + solvent_effective_scattering_density_electrons_per_cm3=( + solvent_scattering_density + ), + contrast_scattering_density_electrons_per_cm3=( + contrast_scattering_density + ), + physical_solute_associated_volume_cm3=physical_solute_volume_cm3, + physical_solvent_associated_volume_cm3=physical_solvent_volume_cm3, + physical_solute_associated_volume_fraction=float( + volume_fraction_estimate.solute_volume_fraction + ), + physical_solvent_associated_volume_fraction=float( + volume_fraction_estimate.solvent_volume_fraction + ), + contrast_weight_factor=float(contrast_weight_factor), + effective_solute_interaction_volume_cm3=float( + effective_solute_interaction_volume_cm3 + ), + effective_solvent_background_volume_cm3=float( + effective_solvent_background_volume_cm3 + ), + saxs_effective_solute_interaction_ratio=float( + effective_solute_interaction_volume_cm3 / total_effective_volume + ), + saxs_effective_solvent_background_ratio=float( + effective_solvent_background_volume_cm3 / total_effective_volume + ), + ) + + +def _calculate_attenuation_estimate( + settings: SolutionScatteringEstimatorSettings, + solution_result: SolutionPropertiesResult, + path_lengths_cm: np.ndarray, +) -> AttenuationEstimate: + solvent_density = settings.solvent_density_g_per_ml + if solvent_density is None: + raise ValueError( + "Solvent density is required for the attenuation estimate." + ) + solvent_density = _validate_positive(solvent_density, "Solvent density") + incident_energy_ev = ( + _validate_positive( + settings.beam.incident_energy_kev, + "Incident energy", + ) + * 1000.0 + ) + solute_formula = _validate_supported_formula( + settings.solution.solute_stoich, + "Solute", + ) + solvent_formula = _validate_supported_formula( + settings.solution.solvent_stoich, + "Solvent", + ) + solute_concentration, solvent_concentration = ( + _solution_component_concentrations(solution_result) + ) + + mu_solute = _material_mu( + solute_formula, + incident_energy_ev, + solute_concentration, + kind="total", + ) + mu_sample_solvent = _material_mu( + solvent_formula, + incident_energy_ev, + solvent_concentration, + kind="total", + ) + mu_neat_solvent = _material_mu( + solvent_formula, + incident_energy_ev, + solvent_density, + kind="total", + ) + mu_sample = mu_solute + mu_sample_solvent + sample_scattering_weight = _average_weighted_path( + path_lengths_cm, + mu_sample, + ) + neat_solvent_scattering_weight = _average_weighted_path( + path_lengths_cm, + mu_neat_solvent, + ) + if neat_solvent_scattering_weight <= 0.0: + solvent_scale_factor = 0.0 + neat_to_sample_ratio = None + else: + solvent_scale_factor = ( + solvent_concentration * sample_scattering_weight + ) / (solvent_density * neat_solvent_scattering_weight) + neat_to_sample_ratio = ( + None if solvent_scale_factor <= 0.0 else 1.0 / solvent_scale_factor + ) + return AttenuationEstimate( + incident_energy_kev=float(settings.beam.incident_energy_kev), + capillary_geometry=_normalize_geometry_name( + settings.beam.capillary_geometry + ), + sample_average_path_length_mm=(float(np.mean(path_lengths_cm)) * 10.0), + sample_linear_attenuation_inv_cm=mu_sample, + solute_linear_attenuation_inv_cm=mu_solute, + sample_solvent_linear_attenuation_inv_cm=mu_sample_solvent, + neat_solvent_linear_attenuation_inv_cm=mu_neat_solvent, + sample_transmission=_average_transmission(path_lengths_cm, mu_sample), + solute_only_transmission=_average_transmission( + path_lengths_cm, + mu_solute, + ), + sample_solvent_only_transmission=_average_transmission( + path_lengths_cm, + mu_sample_solvent, + ), + neat_solvent_transmission=_average_transmission( + path_lengths_cm, + mu_neat_solvent, + ), + sample_scattering_weight=sample_scattering_weight, + neat_solvent_scattering_weight=neat_solvent_scattering_weight, + solvent_mass_concentration_g_per_cm3=solvent_concentration, + neat_solvent_density_g_per_cm3=solvent_density, + solvent_scattering_scale_factor=solvent_scale_factor, + neat_solvent_to_sample_ratio=neat_to_sample_ratio, + ) + + +def _calculate_fluorescence_estimate( + settings: SolutionScatteringEstimatorSettings, + solution_result: SolutionPropertiesResult, + path_lengths_cm: np.ndarray, +) -> FluorescenceEstimate: + incident_energy_ev = ( + _validate_positive( + settings.beam.incident_energy_kev, + "Incident energy", + ) + * 1000.0 + ) + solute_formula = _validate_supported_formula( + settings.solution.solute_stoich, + "Solute", + ) + solvent_formula = _validate_supported_formula( + settings.solution.solvent_stoich, + "Solvent", + ) + solute_concentration, solvent_concentration = ( + _solution_component_concentrations(solution_result) + ) + sample_total_mu_cache: dict[float, float] = {} + sample_photo_mu_cache: dict[float, dict[str, float]] = {} + + def sample_total_mu(energy_ev: float) -> float: + rounded = round(float(energy_ev), 6) + if rounded not in sample_total_mu_cache: + sample_total_mu_cache[rounded] = _material_mu( + solute_formula, + rounded, + solute_concentration, + kind="total", + ) + _material_mu( + solvent_formula, + rounded, + solvent_concentration, + kind="total", + ) + return sample_total_mu_cache[rounded] + + def sample_photo_element_mus(energy_ev: float) -> dict[str, float]: + rounded = round(float(energy_ev), 6) + if rounded not in sample_photo_mu_cache: + sample_photo_mu_cache[rounded] = _merged_element_coefficients( + [ + _material_element_coefficients( + solute_formula, + rounded, + solute_concentration, + kind="photo", + ), + _material_element_coefficients( + solvent_formula, + rounded, + solvent_concentration, + kind="photo", + ), + ] + ) + return sample_photo_mu_cache[rounded] + + mu_incident_total = sample_total_mu(incident_energy_ev) + source_integral = _average_source_integral( + path_lengths_cm, + mu_incident_total, + ) + line_lookup: dict[tuple[str, str, str], FluorescenceLineEstimate] = {} + primary_generated_records: list[tuple[str, float, float]] = [] + + for element, photo_mu in sample_photo_element_mus( + incident_energy_ev + ).items(): + if photo_mu <= 0.0: + continue + edge_shares = _accessible_edge_shares(element, incident_energy_ev) + if not edge_shares: + continue + for edge_name, edge_share in edge_shares.items(): + edge_photo_mu = float(photo_mu) * float(edge_share) + for family in EDGE_LINE_FAMILIES.get(edge_name, ()): + try: + fyield, line_energy_ev, line_probability = ( + xraydb.fluor_yield( + element, + edge_name, + family, + incident_energy_ev, + ) + ) + except Exception: + continue + line_branch = float(fyield) * float(line_probability) + if ( + line_energy_ev <= 0.0 + or line_branch <= MINIMUM_LINE_CONTRIBUTION + ): + continue + mu_line_total = sample_total_mu(line_energy_ev) + primary_generated = ( + edge_photo_mu * line_branch * source_integral + ) + primary_detected = ( + edge_photo_mu + * line_branch + * _average_detected_integral( + path_lengths_cm, + mu_incident_total, + mu_line_total, + ) + ) + if ( + primary_generated <= MINIMUM_LINE_CONTRIBUTION + and primary_detected <= MINIMUM_LINE_CONTRIBUTION + ): + continue + key = (str(element), str(edge_name), str(family)) + line_lookup[key] = FluorescenceLineEstimate( + element=str(element), + edge=str(edge_name), + family=str(family), + line_energy_ev=float(line_energy_ev), + primary_detected_yield=float(primary_detected), + secondary_detected_yield=0.0, + ) + primary_generated_records.append( + ( + str(element), + float(line_energy_ev), + float(primary_generated), + ) + ) + + for ( + source_element, + line_energy_ev, + primary_generated, + ) in primary_generated_records: + if primary_generated <= MINIMUM_LINE_CONTRIBUTION: + continue + mu_line_total = sample_total_mu(line_energy_ev) + reabsorbed_fraction = _average_half_path_absorption_fraction( + path_lengths_cm, + mu_line_total, + ) + if reabsorbed_fraction <= MINIMUM_LINE_CONTRIBUTION: + continue + photo_element_mus = sample_photo_element_mus(line_energy_ev) + total_photo_mu = sum(photo_element_mus.values()) + if total_photo_mu <= 0.0: + continue + absorbed_primary = primary_generated * reabsorbed_fraction + for target_element, target_photo_mu in photo_element_mus.items(): + if target_element == source_element or target_photo_mu <= 0.0: + continue + element_absorption_share = float(target_photo_mu) / float( + total_photo_mu + ) + edge_shares = _accessible_edge_shares( + target_element, line_energy_ev + ) + if not edge_shares: + continue + for edge_name, edge_share in edge_shares.items(): + for family in EDGE_LINE_FAMILIES.get(edge_name, ()): + try: + fyield, emitted_energy_ev, line_probability = ( + xraydb.fluor_yield( + target_element, + edge_name, + family, + line_energy_ev, + ) + ) + except Exception: + continue + line_branch = float(fyield) * float(line_probability) + if ( + emitted_energy_ev <= 0.0 + or line_branch <= MINIMUM_LINE_CONTRIBUTION + ): + continue + secondary_generated = ( + absorbed_primary + * element_absorption_share + * float(edge_share) + * line_branch + ) + if secondary_generated <= MINIMUM_LINE_CONTRIBUTION: + continue + mu_secondary_total = sample_total_mu(emitted_energy_ev) + secondary_detected = secondary_generated * ( + _average_half_path_escape_factor( + path_lengths_cm, + mu_secondary_total, + ) + ) + if secondary_detected <= MINIMUM_LINE_CONTRIBUTION: + continue + key = (str(target_element), str(edge_name), str(family)) + existing = line_lookup.get(key) + if existing is None: + line_lookup[key] = FluorescenceLineEstimate( + element=str(target_element), + edge=str(edge_name), + family=str(family), + line_energy_ev=float(emitted_energy_ev), + primary_detected_yield=0.0, + secondary_detected_yield=float(secondary_detected), + ) + else: + existing.secondary_detected_yield = float( + existing.secondary_detected_yield + + secondary_detected + ) + + line_estimates = list(line_lookup.values()) + return FluorescenceEstimate( + incident_energy_kev=float(settings.beam.incident_energy_kev), + capillary_geometry=_normalize_geometry_name( + settings.beam.capillary_geometry + ), + total_primary_detected_yield=float( + sum(item.primary_detected_yield for item in line_estimates) + ), + total_secondary_detected_yield=float( + sum(item.secondary_detected_yield for item in line_estimates) + ), + line_estimates=line_estimates, + ) + + +def calculate_solution_scattering_estimate( + settings: SolutionScatteringEstimatorSettings, +) -> SolutionScatteringEstimate: + path_lengths_cm = _path_lengths_cm(settings.beam) + solution_result = calculate_solution_properties(settings.solution) + + number_density_estimate = None + if settings.calculate_number_density: + number_density_estimate = _build_number_density_estimate( + solution_result + ) + + volume_fraction_estimate = None + if settings.calculate_solute_volume_fraction: + volume_fraction_estimate = calculate_solute_volume_fraction_estimate( + SoluteVolumeFractionSettings( + solution=SolutionPropertiesSettings.from_dict( + settings.solution.to_dict() + ), + solute_density_g_per_ml=settings.solute_density_g_per_ml, + solvent_density_g_per_ml=settings.solvent_density_g_per_ml, + ) + ) + + interaction_contrast_estimate = None + if volume_fraction_estimate is not None: + interaction_contrast_estimate = _calculate_saxs_interaction_estimate( + settings, + volume_fraction_estimate, + ) + + attenuation_estimate = None + if settings.calculate_solvent_scattering_contribution: + attenuation_estimate = _calculate_attenuation_estimate( + settings, + solution_result, + path_lengths_cm, + ) + + fluorescence_estimate = None + if settings.calculate_sample_fluorescence_yield: + fluorescence_estimate = _calculate_fluorescence_estimate( + settings, + solution_result, + path_lengths_cm, + ) + + return SolutionScatteringEstimate( + settings=SolutionScatteringEstimatorSettings( + solution=SolutionPropertiesSettings.from_dict( + settings.solution.to_dict() + ), + solute_density_g_per_ml=settings.solute_density_g_per_ml, + solvent_density_g_per_ml=settings.solvent_density_g_per_ml, + calculate_number_density=bool(settings.calculate_number_density), + calculate_solute_volume_fraction=bool( + settings.calculate_solute_volume_fraction + ), + calculate_solvent_scattering_contribution=bool( + settings.calculate_solvent_scattering_contribution + ), + calculate_sample_fluorescence_yield=bool( + settings.calculate_sample_fluorescence_yield + ), + beam=BeamGeometrySettings( + incident_energy_kev=float(settings.beam.incident_energy_kev), + capillary_size_mm=float(settings.beam.capillary_size_mm), + capillary_geometry=str(settings.beam.capillary_geometry), + beam_profile=str(settings.beam.beam_profile), + beam_footprint_width_mm=float( + settings.beam.beam_footprint_width_mm + ), + beam_footprint_height_mm=float( + settings.beam.beam_footprint_height_mm + ), + ), + ), + solution_result=solution_result, + number_density_estimate=number_density_estimate, + volume_fraction_estimate=volume_fraction_estimate, + interaction_contrast_estimate=interaction_contrast_estimate, + attenuation_estimate=attenuation_estimate, + fluorescence_estimate=fluorescence_estimate, + ) + + +__all__ = [ + "BEAM_PROFILE_ITEMS", + "CAPILLARY_GEOMETRY_ITEMS", + "DEFAULT_BEAM_FOOTPRINT_HEIGHT_MM", + "DEFAULT_BEAM_FOOTPRINT_WIDTH_MM", + "DEFAULT_BEAM_PROFILE", + "DEFAULT_CAPILLARY_GEOMETRY", + "DEFAULT_CAPILLARY_SIZE_MM", + "DEFAULT_INCIDENT_ENERGY_KEV", + "ENERGY_WAVELENGTH_KEV_ANGSTROM", + "AttenuationEstimate", + "BeamGeometrySettings", + "FluorescenceEstimate", + "FluorescenceLineEstimate", + "NumberDensityEstimate", + "SAXSInteractionEstimate", + "SolutionScatteringEstimate", + "SolutionScatteringEstimatorSettings", + "calculate_solution_scattering_estimate", + "wavelength_angstrom_from_energy_kev", +] diff --git a/src/saxshell/saxs/ui/branding.py b/src/saxshell/saxs/ui/branding.py new file mode 100644 index 0000000..bf2bb8e --- /dev/null +++ b/src/saxshell/saxs/ui/branding.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import sys +from functools import lru_cache +from pathlib import Path + +from PySide6.QtCore import QCoreApplication, QEvent, Qt +from PySide6.QtGui import QColor, QFont, QIcon, QPainter, QPen, QPixmap +from PySide6.QtWidgets import ( + QApplication, + QHBoxLayout, + QLabel, + QSizePolicy, + QSplashScreen, + QVBoxLayout, + QWidget, +) + +SAXSHELL_APPLICATION_NAME = "SAXSShell" +BRAND_PRIMARY_HEX = "#0f4aa6" +BRAND_SECONDARY_HEX = "#4f6074" +BRAND_ICON_MIN_SIZE = 34 +BRAND_ICON_MAX_SIZE = 34 +BRAND_TITLE_MAX_POINT_SIZE = 13.5 + + +def saxshell_icon_path() -> Path: + return ( + Path(__file__).resolve().parents[1] + / "_ui_assets" + / ("saxshell_icon.svg") + ) + + +@lru_cache(maxsize=1) +def load_saxshell_icon() -> QIcon: + return QIcon(str(saxshell_icon_path())) + + +class SAXShellBrandWidget(QWidget): + """Top-left application branding that tracks UI font scaling.""" + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.setObjectName("saxshellBrandWidget") + self.setSizePolicy( + QSizePolicy.Policy.Minimum, + QSizePolicy.Policy.Fixed, + ) + self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents) + + self._layout = QHBoxLayout(self) + self._layout.setContentsMargins(14, 2, 16, 2) + self._layout.setSpacing(10) + + self._icon_label = QLabel(self) + self._icon_label.setSizePolicy( + QSizePolicy.Policy.Fixed, + QSizePolicy.Policy.Fixed, + ) + self._layout.addWidget( + self._icon_label, + alignment=Qt.AlignmentFlag.AlignVCenter, + ) + + text_column = QVBoxLayout() + text_column.setContentsMargins(0, 0, 0, 0) + text_column.setSpacing(0) + + self._title_label = QLabel("SAXSShell", self) + self._title_label.setSizePolicy( + QSizePolicy.Policy.Minimum, + QSizePolicy.Policy.Fixed, + ) + self._title_label.setStyleSheet(f"color: {BRAND_PRIMARY_HEX};") + text_column.addWidget(self._title_label) + + self._layout.addLayout(text_column) + self._sync_brand_metrics() + + def changeEvent(self, event: QEvent) -> None: + super().changeEvent(event) + if event.type() in ( + QEvent.Type.FontChange, + QEvent.Type.StyleChange, + ): + self._sync_brand_metrics() + + def _sync_brand_metrics(self) -> None: + base_font = QFont(self.font()) + base_point_size = base_font.pointSizeF() + if base_point_size <= 0: + app_font = QApplication.font(self) + base_point_size = app_font.pointSizeF() + base_font = QFont(app_font) + if base_point_size <= 0: + base_point_size = 10.0 + + title_font = QFont(base_font) + title_font.setBold(True) + title_font.setPointSizeF( + min( + max(base_point_size * 1.18, base_point_size + 1.5), + BRAND_TITLE_MAX_POINT_SIZE, + ) + ) + self._title_label.setFont(title_font) + + icon_size = round(title_font.pointSizeF() * 2.35) + icon_size = max( + BRAND_ICON_MIN_SIZE, min(BRAND_ICON_MAX_SIZE, icon_size) + ) + self._icon_label.setPixmap( + load_saxshell_icon().pixmap(icon_size, icon_size) + ) + self._icon_label.setFixedSize(icon_size, icon_size) + + layout_size = self._layout.sizeHint() + self.setMinimumWidth(layout_size.width()) + self.setFixedHeight(layout_size.height()) + self.updateGeometry() + + +def build_saxshell_brand_widget(parent: QWidget | None = None) -> QWidget: + return SAXShellBrandWidget(parent) + + +def _configure_macos_application_identity() -> None: + if sys.platform != "darwin": + return + try: + from Foundation import NSBundle, NSProcessInfo + except Exception: + return + + NSProcessInfo.processInfo().setProcessName_(SAXSHELL_APPLICATION_NAME) + info = NSBundle.mainBundle().infoDictionary() + if info is not None: + info["CFBundleName"] = SAXSHELL_APPLICATION_NAME + info["CFBundleDisplayName"] = SAXSHELL_APPLICATION_NAME + + +def prepare_saxshell_application_identity() -> None: + QCoreApplication.setApplicationName(SAXSHELL_APPLICATION_NAME) + QApplication.setApplicationDisplayName(SAXSHELL_APPLICATION_NAME) + QApplication.setDesktopFileName(SAXSHELL_APPLICATION_NAME) + _configure_macos_application_identity() + + +def configure_saxshell_application(app: QApplication) -> None: + prepare_saxshell_application_identity() + app.setApplicationName(SAXSHELL_APPLICATION_NAME) + app.setApplicationDisplayName(SAXSHELL_APPLICATION_NAME) + app.setDesktopFileName(SAXSHELL_APPLICATION_NAME) + app.setWindowIcon(load_saxshell_icon()) + + +def create_saxshell_startup_splash() -> QSplashScreen: + pixmap = QPixmap(420, 210) + pixmap.fill(Qt.GlobalColor.transparent) + + painter = QPainter(pixmap) + painter.setRenderHint(QPainter.RenderHint.Antialiasing, True) + painter.setPen(Qt.PenStyle.NoPen) + painter.setBrush(QColor("#f7f9fc")) + painter.drawRoundedRect(10, 10, 400, 190, 18, 18) + + border_pen = QPen(QColor("#d7e1f0")) + border_pen.setWidth(2) + painter.setPen(border_pen) + painter.setBrush(Qt.BrushStyle.NoBrush) + painter.drawRoundedRect(10, 10, 400, 190, 18, 18) + + icon_pixmap = load_saxshell_icon().pixmap(92, 92) + painter.drawPixmap(26, 48, icon_pixmap) + + title_font = QFont() + title_font.setBold(True) + title_font.setPointSize(19) + painter.setFont(title_font) + painter.setPen(QColor(BRAND_PRIMARY_HEX)) + painter.drawText(136, 84, "SAXSShell") + + subtitle_font = QFont() + subtitle_font.setPointSize(10) + painter.setFont(subtitle_font) + painter.setPen(QColor(BRAND_SECONDARY_HEX)) + painter.drawText(138, 113, "Loading SAXS workflow...") + painter.drawText(138, 136, "Initializing interface and project state") + painter.end() + + splash = QSplashScreen( + pixmap, + Qt.WindowType.SplashScreen | Qt.WindowType.FramelessWindowHint, + ) + splash.setWindowIcon(load_saxshell_icon()) + splash.showMessage( + "Starting SAXSShell", + Qt.AlignmentFlag.AlignHCenter | Qt.AlignmentFlag.AlignBottom, + QColor(BRAND_PRIMARY_HEX), + ) + return splash diff --git a/src/saxshell/saxs/ui/distribution_window.py b/src/saxshell/saxs/ui/distribution_window.py index c5a6084..41b0847 100644 --- a/src/saxshell/saxs/ui/distribution_window.py +++ b/src/saxshell/saxs/ui/distribution_window.py @@ -4,6 +4,7 @@ import json import math import re +from dataclasses import dataclass import numpy as np from matplotlib.backends.backend_qtagg import ( @@ -11,7 +12,7 @@ NavigationToolbar2QT, ) from matplotlib.figure import Figure -from PySide6.QtCore import Signal +from PySide6.QtCore import Qt, Signal from PySide6.QtWidgets import ( QAbstractItemView, QCheckBox, @@ -21,6 +22,8 @@ QMainWindow, QMessageBox, QPushButton, + QScrollArea, + QSplitter, QTableWidget, QTableWidgetItem, QTextEdit, @@ -49,8 +52,8 @@ ("Very Lenient", "very_lenient"), ) SMART_PRIOR_APPLY_SCOPE_ITEMS: tuple[tuple[str, str], ...] = ( - ("All Structures", "all"), - ("Selected Structures", "selected"), + ("All Parameters", "all"), + ("Selected Parameters", "selected"), ) SMART_PRIOR_SPREAD_FACTORS: dict[str, float] = { "very_strict": 0.4, @@ -62,58 +65,214 @@ SMART_PRIOR_STATUS_LABELS: dict[str, str] = { value: label for label, value in SMART_PRIOR_INDIVIDUAL_STATUS_ITEMS } +GUIDE_INTERVAL_LOWER_Q = float(stats.norm.cdf(-3.0)) +GUIDE_INTERVAL_UPPER_Q = float(stats.norm.cdf(3.0)) +GUIDE_LOW_COLUMN = 9 +GUIDE_HIGH_COLUMN = 10 +RESET_COLUMN = 11 +PLOT_DOMAIN_LOWER_Q = 1e-6 +PLOT_DOMAIN_UPPER_Q = 1.0 - PLOT_DOMAIN_LOWER_Q +PLOT_RELATIVE_DENSITY_THRESHOLD = 0.01 +PLOT_PADDING_FRACTION = 0.08 +PLOT_WINDOW_MARGIN_FRACTION = 0.12 +PLOT_SAMPLE_COUNT = 600 +PLOT_LOG_SCALE_SPAN_RATIO = 25.0 +GUIDE_INTERVAL_SIGMA = 3.0 +INTERACTIVE_PARAMETER_EPSILON = 1e-6 +INTERACTIVE_MAX_LOGNORM_SHAPE = 5.0 +INTERACTIVE_PEAK_DRAG_SENSITIVITY = 3.0 +INTERACTIVE_PEAK_HANDLE_SIZE = 90 +INTERACTIVE_CENTER_HANDLE_SIZE = 78 +INTERACTIVE_WIDTH_HANDLE_SIZE = 70 +INTERACTIVE_CENTER_HANDLE_Y_FRACTION = 0.88 +INTERACTIVE_WIDTH_HANDLE_Y_FRACTION = 0.18 + + +@dataclass(slots=True) +class _InteractiveHandleArtists: + axis: object + peak: object | None = None + center: object | None = None + left_width: object | None = None + right_width: object | None = None + + +@dataclass(slots=True) +class _InteractiveDragState: + row: int + kind: str + start_entry: DreamParameterEntry + preview_entry: DreamParameterEntry + start_y: float + x_limits: tuple[float, float] + y_limits: tuple[float, float] + x_scale: str + + +@dataclass(slots=True) +class _PlotWindowState: + row: int + x_limits: tuple[float, float] + y_limits: tuple[float, float] + x_scale: str class WeightDistributionPreviewWindow(QMainWindow): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) + self._entries: list[DreamParameterEntry] = [] + self._parameter_checkboxes: list[ + tuple[DreamParameterEntry, QCheckBox] + ] = [] self._build_ui() def _build_ui(self) -> None: - self.setWindowTitle("DREAM Weight Prior Preview") - self.resize(900, 620) + self.setWindowTitle("DREAM Prior Preview") + self.resize(1080, 640) central = QWidget() - layout = QVBoxLayout(central) + root = QHBoxLayout(central) + + controls = QWidget() + controls.setMinimumWidth(240) + controls_layout = QVBoxLayout(controls) + controls_layout.addWidget(QLabel("Visible parameters")) + helper = QLabel("Only w<##> priors are enabled by default.") + helper.setWordWrap(True) + controls_layout.addWidget(helper) + self.parameter_scroll = QScrollArea() + self.parameter_scroll.setWidgetResizable(True) + self.parameter_scroll.setHorizontalScrollBarPolicy( + Qt.ScrollBarPolicy.ScrollBarAlwaysOff + ) + self.parameter_scroll.setMinimumWidth(240) + self.parameter_checkbox_container = QWidget() + self.parameter_checkbox_layout = QVBoxLayout( + self.parameter_checkbox_container + ) + self.parameter_checkbox_layout.setContentsMargins(0, 0, 0, 0) + self.parameter_checkbox_layout.setSpacing(4) + self.parameter_checkbox_layout.addStretch(1) + self.parameter_scroll.setWidget(self.parameter_checkbox_container) + controls_layout.addWidget(self.parameter_scroll) + + plot_panel = QWidget() + layout = QVBoxLayout(plot_panel) self.figure = Figure(figsize=(8, 5)) self.canvas = FigureCanvasQTAgg(self.figure) self.toolbar = NavigationToolbar2QT(self.canvas, self) layout.addWidget(self.toolbar) layout.addWidget(self.canvas) + root.addWidget(controls, stretch=1) + root.addWidget(plot_panel, stretch=3) self.setCentralWidget(central) def plot_entries(self, entries: list[DreamParameterEntry]) -> None: + self._entries = list(entries) + self._rebuild_parameter_checkboxes() + self._refresh_plot() + + @staticmethod + def _is_weight_parameter(param_name: str) -> bool: + return re.fullmatch(r"w\d+", param_name.strip()) is not None + + @staticmethod + def _entry_label(entry: DreamParameterEntry) -> str: + legend_label = entry.param.strip() or "Unnamed parameter" + if entry.structure.strip(): + legend_label = f"{legend_label} ({entry.structure.strip()})" + return legend_label + + @staticmethod + def _entry_toggle_key(entry: DreamParameterEntry) -> tuple[str, ...]: + return ( + entry.structure.strip(), + entry.motif.strip(), + entry.param_type.strip(), + entry.param.strip(), + ) + + def _rebuild_parameter_checkboxes(self) -> None: + prior_states = { + self._entry_toggle_key(entry): checkbox.isChecked() + for entry, checkbox in self._parameter_checkboxes + } + self._parameter_checkboxes = [] + while self.parameter_checkbox_layout.count(): + item = self.parameter_checkbox_layout.takeAt(0) + widget = item.widget() + if widget is not None: + widget.deleteLater() + if not self._entries: + empty_label = QLabel("No prior distributions are available.") + empty_label.setWordWrap(True) + self.parameter_checkbox_layout.addWidget(empty_label) + self.parameter_checkbox_layout.addStretch(1) + return + for entry in self._entries: + checkbox = QCheckBox(self._entry_label(entry)) + checkbox.setChecked( + prior_states.get( + self._entry_toggle_key(entry), + self._is_weight_parameter(entry.param), + ) + ) + tooltip_parts = [] + if entry.structure.strip(): + tooltip_parts.append(f"Structure: {entry.structure.strip()}") + if entry.motif.strip(): + tooltip_parts.append(f"Motif: {entry.motif.strip()}") + if tooltip_parts: + checkbox.setToolTip("\n".join(tooltip_parts)) + checkbox.toggled.connect( + lambda _checked=False: self._refresh_plot() + ) + self.parameter_checkbox_layout.addWidget(checkbox) + self._parameter_checkboxes.append((entry, checkbox)) + self.parameter_checkbox_layout.addStretch(1) + + def _refresh_plot(self) -> None: self.figure.clear() axis = self.figure.add_subplot(111) x_limits: list[tuple[float, float]] = [] plotted = 0 + selected_entries = [ + entry + for entry, checkbox in self._parameter_checkboxes + if checkbox.isChecked() + ] - for entry in entries: - distribution = getattr(stats, entry.distribution) - x_min, x_max = _distribution_domain(entry) - if not np.isfinite(x_min) or not np.isfinite(x_max): + for entry in selected_entries: + try: + distribution = getattr(stats, entry.distribution) + x_min, x_max = _distribution_domain(entry) + if not np.isfinite(x_min) or not np.isfinite(x_max): + continue + x_values = np.linspace(x_min, x_max, 300) + y_values = distribution.pdf(x_values, **entry.dist_params) + except Exception: continue - x_values = np.linspace(x_min, x_max, 300) - y_values = distribution.pdf(x_values, **entry.dist_params) if not np.all(np.isfinite(y_values)): continue - legend_label = entry.param - if entry.structure.strip(): - legend_label = f"{entry.param} ({entry.structure})" axis.plot( x_values, y_values, linewidth=1.6, - label=legend_label, + label=self._entry_label(entry), ) x_limits.append((x_min, x_max)) plotted += 1 if plotted == 0: + message = ( + "No prior distributions are currently enabled in the preview." + if not selected_entries + else "No valid prior distributions are available for the selected parameters." + ) axis.text( 0.5, 0.5, - "No valid w<##> prior distributions are available to preview.", + message, ha="center", va="center", ) @@ -121,7 +280,7 @@ def plot_entries(self, entries: list[DreamParameterEntry]) -> None: else: axis.set_xlabel("Value") axis.set_ylabel("Density") - axis.set_title("Weight prior distributions") + axis.set_title("Prior distributions") axis.legend(loc="best", fontsize="small") axis.set_xlim( min(limit[0] for limit in x_limits), @@ -146,9 +305,13 @@ def __init__( self._was_saved = False self._suppress_vary_warning = False self._suppress_status_change = False + self._reset_entries: list[DreamParameterEntry] = [] self._weight_preview_window: WeightDistributionPreviewWindow | None = ( None ) + self._interactive_handles: _InteractiveHandleArtists | None = None + self._drag_state: _InteractiveDragState | None = None + self._plot_window_state: _PlotWindowState | None = None self._build_ui() self.load_entries(entries) @@ -157,17 +320,16 @@ def _build_ui(self) -> None: self.resize(1200, 720) central = QWidget() - root = QHBoxLayout(central) + root = QVBoxLayout(central) - left = QWidget() - left_layout = QVBoxLayout(left) + self._left_panel = QWidget() + left_layout = QVBoxLayout(self._left_panel) + left_layout.setContentsMargins(0, 0, 0, 0) top_row = QHBoxLayout() - self.preview_weight_priors_button = QPushButton( - "Preview Weight Priors" - ) + self.preview_weight_priors_button = QPushButton("Preview Priors") self.preview_weight_priors_button.setToolTip( - "Open a shared density plot of all current w<##> prior " - "distributions from the table." + "Open a shared density plot of the current prior distributions. " + "Only w<##> parameters are enabled by default in the preview." ) self.preview_weight_priors_button.clicked.connect( self._show_weight_prior_preview @@ -189,9 +351,9 @@ def _build_ui(self) -> None: self.smart_prior_apply_scope_combo.addItem(label, userData=value) self.smart_prior_apply_scope_combo.setToolTip( "Choose whether the selected preset should affect every " - "structure in the table or only the currently selected " - "structure rows. Size-aware mixed presets always apply to all " - "structures so their relative ranking remains meaningful." + "parameter row in the table or only the currently selected " + "parameter rows. Size-aware mixed presets always apply across " + "the full table so their relative ranking remains meaningful." ) top_row.addWidget(self.smart_prior_apply_scope_combo) self.apply_smart_prior_preset_button = QPushButton( @@ -208,7 +370,7 @@ def _build_ui(self) -> None: top_row.addWidget(self.apply_smart_prior_preset_button) top_row.addStretch(1) left_layout.addLayout(top_row) - self.table = QTableWidget(0, 9) + self.table = QTableWidget(0, 12) self.table.setHorizontalHeaderLabels( [ "Structure", @@ -220,8 +382,22 @@ def _build_ui(self) -> None: "Distribution", "Distribution Params", "Smart Preset Status", + "Guide Low", + "Guide High", + "Reset", ] ) + guide_tooltip = ( + "Practical prior bounds for the current distribution. " + "Bounded priors use their exact support, while unbounded priors " + "use a central 99.73% interval (Gaussian 3sigma equivalent)." + ) + low_header = self.table.horizontalHeaderItem(GUIDE_LOW_COLUMN) + if low_header is not None: + low_header.setToolTip(guide_tooltip) + high_header = self.table.horizontalHeaderItem(GUIDE_HIGH_COLUMN) + if high_header is not None: + high_header.setToolTip(guide_tooltip) self.table.setSelectionBehavior( QAbstractItemView.SelectionBehavior.SelectRows ) @@ -229,8 +405,8 @@ def _build_ui(self) -> None: QAbstractItemView.SelectionMode.ExtendedSelection ) self.table.cellClicked.connect(self._on_row_selected) + self.table.currentCellChanged.connect(self._on_current_cell_changed) self.table.cellChanged.connect(self._on_table_changed) - left_layout.addWidget(self.table) button_row = QHBoxLayout() self.select_recommended_vary_button = QPushButton( @@ -258,21 +434,81 @@ def _build_ui(self) -> None: save_button.clicked.connect(self._emit_saved) button_row.addWidget(save_button) button_row.addStretch(1) - left_layout.addLayout(button_row) self.console = QTextEdit() self.console.setReadOnly(True) self.console.setMinimumHeight(160) - left_layout.addWidget(self.console) + self._editor_panel = QWidget() + editor_layout = QVBoxLayout(self._editor_panel) + editor_layout.setContentsMargins(0, 0, 0, 0) + editor_layout.addWidget(self.table, stretch=1) + editor_layout.addLayout(button_row) + + self._left_splitter = QSplitter(Qt.Orientation.Vertical) + self._left_splitter.setChildrenCollapsible(False) + self._left_splitter.setHandleWidth(10) + self._left_splitter.addWidget(self._editor_panel) + self._left_splitter.addWidget(self.console) + self._left_splitter.setStretchFactor(0, 5) + self._left_splitter.setStretchFactor(1, 2) + self._left_splitter.setSizes([520, 180]) - right = QWidget() - right_layout = QVBoxLayout(right) - self.figure = Figure(figsize=(6, 5)) + left_layout.addWidget(self._left_splitter, stretch=1) + + self._plot_panel = QWidget() + right_layout = QVBoxLayout(self._plot_panel) + right_layout.setContentsMargins(0, 0, 0, 0) + interaction_row = QHBoxLayout() + interaction_row.addWidget(QLabel("Interactive plot editing")) + self.rescale_axes_button = QPushButton("Rescale Axes") + self.rescale_axes_button.setToolTip( + "Refit the x- and y-axis limits to the currently selected prior." + ) + self.rescale_axes_button.clicked.connect(self._rescale_current_plot) + interaction_row.addWidget(self.rescale_axes_button) + interaction_row.addStretch(1) + self.lock_center_checkbox = QCheckBox("Lock center") + self.lock_center_checkbox.setChecked(True) + self.lock_center_checkbox.setToolTip( + "Keep the prior center fixed while dragging width and peak " + "handles. Uncheck this to enable dragging the red center " + "marker." + ) + self.lock_center_checkbox.toggled.connect(self._on_center_lock_toggled) + interaction_row.addWidget(self.lock_center_checkbox) + right_layout.addLayout(interaction_row) + self.interactive_hint_label = QLabel() + self.interactive_hint_label.setWordWrap(True) + right_layout.addWidget(self.interactive_hint_label) + self._refresh_interactive_hint() + self.figure = Figure(figsize=(6, 6)) self.canvas = FigureCanvasQTAgg(self.figure) + self.canvas.mpl_connect( + "button_press_event", self._on_plot_mouse_press + ) + self.canvas.mpl_connect( + "motion_notify_event", self._on_plot_mouse_move + ) + self.canvas.mpl_connect( + "button_release_event", self._on_plot_mouse_release + ) + self.canvas.mpl_connect( + "figure_leave_event", self._on_plot_mouse_leave + ) + self.toolbar = NavigationToolbar2QT(self.canvas, self) + right_layout.addWidget(self.toolbar) right_layout.addWidget(self.canvas) - root.addWidget(left, stretch=3) - root.addWidget(right, stretch=2) + self._main_splitter = QSplitter(Qt.Orientation.Horizontal) + self._main_splitter.setChildrenCollapsible(False) + self._main_splitter.setHandleWidth(10) + self._main_splitter.addWidget(self._left_panel) + self._main_splitter.addWidget(self._plot_panel) + self._main_splitter.setStretchFactor(0, 4) + self._main_splitter.setStretchFactor(1, 3) + self._main_splitter.setSizes([720, 560]) + + root.addWidget(self._main_splitter, stretch=1) self.setCentralWidget(central) def load_entries( @@ -280,18 +516,24 @@ def load_entries( entries: list[DreamParameterEntry], *, has_existing_parameter_map: bool | None = None, + update_reset_entries: bool = True, ) -> None: if has_existing_parameter_map is not None: self._has_existing_parameter_map = bool(has_existing_parameter_map) self._was_saved = False + normalized_entries = [ + self._normalized_entry_copy(entry) for entry in entries + ] + self._plot_window_state = None + if update_reset_entries: + self._reset_entries = [ + self._normalized_entry_copy(entry) + for entry in normalized_entries + ] self.table.blockSignals(True) - self.table.setRowCount(len(entries)) - for row, entry in enumerate(entries): - params = self._normalize_distribution_params( - entry.distribution, - entry.dist_params, - entry.value, - ) + self.table.setRowCount(len(normalized_entries)) + for row, entry in enumerate(normalized_entries): + params = dict(entry.dist_params) self.table.setItem(row, 0, QTableWidgetItem(entry.structure)) self.table.setItem(row, 1, QTableWidgetItem(entry.motif)) self.table.setItem(row, 2, QTableWidgetItem(entry.param_type)) @@ -344,41 +586,103 @@ def load_entries( ) ) self.table.setCellWidget(row, 8, status_combo) + self._refresh_distribution_guides_for_row( + row, + entry=entry, + ) + reset_button = QPushButton("Reset") + reset_button.setToolTip( + "Reset this prior row to the most recently loaded or saved " + "parameter-map values." + ) + reset_button.clicked.connect( + lambda _checked=False, selected_row=row: ( + self._reset_row_to_baseline(selected_row) + ) + ) + self.table.setCellWidget(row, RESET_COLUMN, reset_button) self.table.blockSignals(False) self.table.resizeColumnsToContents() - self._entries = entries - if entries: - self._plot_entry(entries[0]) + self._entries = normalized_entries + if normalized_entries: + self.table.setCurrentCell(0, 0) + self._plot_entry( + normalized_entries[0], + row=0, + force_rescale=True, + ) + return + self.figure.clear() + self._interactive_handles = None + self.canvas.draw() + + def _entry_from_row(self, row: int) -> DreamParameterEntry: + distribution_widget = self.table.cellWidget(row, 6) + vary_widget = self.table.cellWidget(row, 5) + distribution = ( + distribution_widget.currentText() + if isinstance(distribution_widget, QComboBox) + else "lognorm" + ) + value_item = self.table.item(row, 4) + params_item = self.table.item(row, 7) + status_value = self._row_smart_status(row) + value = float(value_item.text()) if value_item is not None else 0.0 + params = self._normalize_distribution_params( + distribution, + self._parse_params(params_item.text() if params_item else "{}"), + value, + ) + return DreamParameterEntry( + structure=self.table.item(row, 0).text(), + motif=self.table.item(row, 1).text(), + param_type=self.table.item(row, 2).text(), + param=self.table.item(row, 3).text(), + value=value, + vary=( + vary_widget.isChecked() + if isinstance(vary_widget, QCheckBox) + else False + ), + distribution=distribution, + dist_params=params, + smart_preset_status=status_value, + ) def current_entries(self) -> list[DreamParameterEntry]: - entries: list[DreamParameterEntry] = [] - for row in range(self.table.rowCount()): - distribution = self.table.cellWidget(row, 6).currentText() - value = float(self.table.item(row, 4).text()) - params = self._normalize_distribution_params( - distribution, - self._parse_params(self.table.item(row, 7).text()), - value, - ) - entries.append( - DreamParameterEntry( - structure=self.table.item(row, 0).text(), - motif=self.table.item(row, 1).text(), - param_type=self.table.item(row, 2).text(), - param=self.table.item(row, 3).text(), - value=value, - vary=self.table.cellWidget(row, 5).isChecked(), - distribution=distribution, - dist_params=params, - smart_preset_status=self._row_smart_status(row), - ) - ) - return entries + return [ + self._entry_from_row(row) for row in range(self.table.rowCount()) + ] + + def _normalized_entry_copy( + self, + entry: DreamParameterEntry, + ) -> DreamParameterEntry: + return DreamParameterEntry( + structure=str(entry.structure), + motif=str(entry.motif), + param_type=str(entry.param_type), + param=str(entry.param), + value=float(entry.value), + vary=bool(entry.vary), + distribution=str(entry.distribution), + dist_params=self._normalize_distribution_params( + str(entry.distribution), + dict(entry.dist_params), + float(entry.value), + ), + smart_preset_status=self._normalized_smart_status( + getattr(entry, "smart_preset_status", "custom") + ), + ) def _emit_saved(self) -> None: entries = self.current_entries() self._was_saved = True self._has_existing_parameter_map = True + self._reset_entries = [ + self._normalized_entry_copy(entry) for entry in entries + ] self.saved.emit(entries) self.console.append("Saved current DREAM parameter map.") QMessageBox.information( @@ -408,40 +712,149 @@ def closeEvent(self, event) -> None: # type: ignore[override] event.ignore() def _on_row_selected(self, row: int, _column: int) -> None: - self._plot_entry(self.current_entries()[row]) + self._plot_entry( + self._entry_from_row(row), + row=row, + force_rescale=True, + ) - def _on_distribution_changed(self, row: int) -> None: - distribution = self.table.cellWidget(row, 6).currentText() - value = float(self.table.item(row, 4).text()) - params = self._normalize_distribution_params( - distribution, - self._parse_params(self.table.item(row, 7).text()), - value, + def _on_current_cell_changed( + self, + current_row: int, + _current_column: int, + _previous_row: int, + _previous_column: int, + ) -> None: + if current_row < 0 or current_row >= self.table.rowCount(): + return + self._plot_entry( + self._entry_from_row(current_row), + row=current_row, + force_rescale=True, ) + + def _on_distribution_changed(self, row: int) -> None: + entry = self._entry_from_row(row) self.table.blockSignals(True) - self.table.item(row, 7).setText(json.dumps(params, sort_keys=True)) + self.table.item(row, 7).setText( + json.dumps(entry.dist_params, sort_keys=True) + ) self.table.blockSignals(False) - entry = self.current_entries()[row] + self._refresh_distribution_guides_for_row(row, entry=entry) self.console.append( f"Distribution for {entry.param} set to {entry.distribution}." ) self._set_group_status_for_row(row, "custom") - self._plot_entry(entry) + self._plot_entry(entry, row=row, force_rescale=True) def _on_table_changed(self, row: int, column: int) -> None: if column == 7: try: - entry = self.current_entries()[row] + entry = self._entry_from_row(row) except Exception as exc: + self._refresh_distribution_guides_for_row(row, entry=None) self.console.append( f"Invalid distribution parameter JSON: {exc}" ) return + self._refresh_distribution_guides_for_row(row, entry=entry) self._set_group_status_for_row(row, "custom") - self._plot_entry(entry) + if row == self.table.currentRow(): + self._plot_entry(entry, row=row) return if column == 4: + try: + entry = self._entry_from_row(row) + except Exception: + entry = None + self._refresh_distribution_guides_for_row(row, entry=entry) self._set_group_status_for_row(row, "custom") + if entry is not None and row == self.table.currentRow(): + self._plot_entry(entry, row=row) + + def _reset_row_to_baseline(self, row: int) -> None: + if row < 0 or row >= len(self._reset_entries): + QMessageBox.warning( + self, + "Reset prior failed", + "No baseline prior entry is available for the selected row.", + ) + return + self._apply_entry_to_row(row, self._reset_entries[row]) + self.console.append( + "Reset prior row to the loaded/saved baseline: " + f"{self._row_status_label(row)}." + ) + + def _apply_entry_to_row( + self, + row: int, + entry: DreamParameterEntry, + ) -> None: + if row < 0 or row >= self.table.rowCount(): + return + normalized_entry = self._normalized_entry_copy(entry) + distribution_combo = self.table.cellWidget(row, 6) + vary_box = self.table.cellWidget(row, 5) + status_combo = self.table.cellWidget(row, 8) + was_table_blocked = self.table.blockSignals(True) + prior_vary_suppressed = self._suppress_vary_warning + prior_status_suppressed = self._suppress_status_change + self._suppress_vary_warning = True + self._suppress_status_change = True + try: + self._set_editable_table_item(row, 0, normalized_entry.structure) + self._set_editable_table_item(row, 1, normalized_entry.motif) + self._set_editable_table_item(row, 2, normalized_entry.param_type) + self._set_editable_table_item(row, 3, normalized_entry.param) + self._set_editable_table_item( + row, + 4, + f"{normalized_entry.value:.6g}", + ) + if isinstance(vary_box, QCheckBox): + vary_box.blockSignals(True) + vary_box.setChecked(normalized_entry.vary) + vary_box.blockSignals(False) + if isinstance(distribution_combo, QComboBox): + distribution_combo.blockSignals(True) + distribution_combo.setCurrentText( + normalized_entry.distribution + ) + distribution_combo.blockSignals(False) + self._set_editable_table_item( + row, + 7, + json.dumps(normalized_entry.dist_params, sort_keys=True), + ) + if isinstance(status_combo, QComboBox): + status_index = status_combo.findData( + normalized_entry.smart_preset_status + ) + if status_index < 0: + status_index = status_combo.findData("custom") + status_combo.blockSignals(True) + status_combo.setCurrentIndex(max(status_index, 0)) + status_combo.blockSignals(False) + finally: + self._suppress_vary_warning = prior_vary_suppressed + self._suppress_status_change = prior_status_suppressed + self.table.blockSignals(was_table_blocked) + self._refresh_distribution_guides_for_row(row, entry=normalized_entry) + if row == self.table.currentRow(): + self._plot_entry(normalized_entry, row=row) + + def _set_editable_table_item( + self, + row: int, + column: int, + text: str, + ) -> None: + item = self.table.item(row, column) + if item is None: + item = QTableWidgetItem() + self.table.setItem(row, column, item) + item.setText(text) def _set_all_vary(self, enabled: bool) -> None: self._set_vary_state_for_rows(lambda _row_index, _param_name: enabled) @@ -492,6 +905,7 @@ def _apply_smart_prior_preset(self) -> None: self.load_entries( updated_entries, has_existing_parameter_map=self._has_existing_parameter_map, + update_reset_entries=False, ) preset_label = str( self.smart_prior_preset_combo.currentText() or preset_mode @@ -503,7 +917,7 @@ def _apply_smart_prior_preset(self) -> None: "strict_small_lenient_large", "lenient_small_strict_large", }: - scope_label = "All Structures" + scope_label = "All Parameters" self.console.append( f"Applied smart prior preset: {preset_label} ({scope_label})." ) @@ -558,42 +972,358 @@ def _on_vary_toggled(self, row: int, checked: bool) -> None: ) def _show_weight_prior_preview(self) -> None: - weight_entries = [ - entry - for entry in self.current_entries() - if re.fullmatch(r"w\d+", entry.param.strip()) - ] - if not weight_entries: + entries = self.current_entries() + if not entries: QMessageBox.information( self, - "No weight priors available", - "No w<##> prior distributions are currently available in the table.", + "No priors available", + "No prior distributions are currently available in the table.", ) return if self._weight_preview_window is None: self._weight_preview_window = WeightDistributionPreviewWindow() - self._weight_preview_window.plot_entries(weight_entries) + self._weight_preview_window.plot_entries(entries) self._weight_preview_window.show() self._weight_preview_window.raise_() self._weight_preview_window.activateWindow() self.console.append( - "Opened shared preview for all current w<##> prior distributions." + "Opened shared prior preview with only w<##> parameters enabled by default." + ) + + def _refresh_interactive_hint(self) -> None: + if self.lock_center_checkbox.isChecked(): + message = ( + "Drag the orange peak handle to sharpen or broaden the " + "distribution, or drag the blue side handles to resize the " + "width. The center is currently locked. The gray dashed " + "curve shows the reset baseline." + ) + else: + message = ( + "Drag the orange peak handle to adjust peak height and " + "width, drag the blue side handles to resize the width, or " + "drag the red center handle to reposition the prior. The " + "gray dashed curve shows the reset baseline." + ) + self.interactive_hint_label.setText(message) + + def _on_center_lock_toggled(self, _checked: bool) -> None: + if ( + self._drag_state is not None + and self._drag_state.kind == "center" + and self.lock_center_checkbox.isChecked() + ): + self._drag_state = None + self._refresh_interactive_hint() + current_row = self.table.currentRow() + if current_row < 0 or current_row >= self.table.rowCount(): + return + try: + self._plot_entry( + self._entry_from_row(current_row), + row=current_row, + ) + except Exception: + return + + def _rescale_current_plot(self) -> None: + current_row = self.table.currentRow() + if current_row < 0 or current_row >= self.table.rowCount(): + return + try: + self._plot_entry( + self._entry_from_row(current_row), + row=current_row, + force_rescale=True, + ) + except Exception: + return + + def _on_plot_mouse_press(self, event) -> None: + if event is None or event.button != 1 or event.inaxes is None: + return + if self._interactive_handles is None: + return + if event.inaxes is not self._interactive_handles.axis: + return + handle_kind = self._interactive_handle_kind_at_event(event) + if handle_kind is None: + return + row = self.table.currentRow() + if row < 0 or row >= self.table.rowCount(): + return + try: + start_entry = self._entry_from_row(row) + except Exception: + return + y_limits = event.inaxes.get_ylim() + self._drag_state = _InteractiveDragState( + row=row, + kind=handle_kind, + start_entry=DreamParameterEntry.from_dict(start_entry.to_dict()), + preview_entry=DreamParameterEntry.from_dict(start_entry.to_dict()), + start_y=self._clamped_event_y( + event.ydata, + y_limits=y_limits, + ), + x_limits=tuple(float(limit) for limit in event.inaxes.get_xlim()), + y_limits=tuple(float(limit) for limit in y_limits), + x_scale=str(event.inaxes.get_xscale() or "linear"), + ) + + def _on_plot_mouse_move(self, event) -> None: + if self._drag_state is None or event is None: + return + if event.inaxes is None: + return + preview_entry = self._interactive_drag_preview_entry( + self._drag_state, + event, + ) + if preview_entry is None: + return + self._drag_state.preview_entry = preview_entry + self._plot_entry(preview_entry, row=self._drag_state.row) + + def _on_plot_mouse_release(self, _event) -> None: + if self._drag_state is None: + return + drag_state = self._drag_state + self._drag_state = None + if self._entries_match( + drag_state.start_entry, + drag_state.preview_entry, + ): + self._plot_entry(drag_state.start_entry, row=drag_state.row) + return + self._apply_entry_to_row(drag_state.row, drag_state.preview_entry) + self._set_group_status_for_row(drag_state.row, "custom") + self.console.append( + self._interactive_drag_message( + drag_state.start_entry, + drag_state.preview_entry, + handle_kind=drag_state.kind, + ) + ) + + def _on_plot_mouse_leave(self, _event) -> None: + self._on_plot_mouse_release(None) + + def _interactive_handle_kind_at_event(self, event) -> str | None: + if self._interactive_handles is None: + return None + for handle_kind, artist in ( + ("peak", self._interactive_handles.peak), + ("center", self._interactive_handles.center), + ("left_width", self._interactive_handles.left_width), + ("right_width", self._interactive_handles.right_width), + ): + if artist is None: + continue + contains, _details = artist.contains(event) + if contains: + return handle_kind + return None + + def _interactive_drag_preview_entry( + self, + drag_state: _InteractiveDragState, + event, + ) -> DreamParameterEntry | None: + if drag_state.kind == "peak": + target_y = self._clamped_event_y( + event.ydata, + y_limits=drag_state.y_limits, + ) + return self._peak_drag_adjusted_entry( + drag_state.start_entry, + start_y=drag_state.start_y, + target_y=target_y, + y_limits=drag_state.y_limits, + ) + target_x = self._clamped_event_x( + event.xdata, + x_limits=drag_state.x_limits, + x_scale=drag_state.x_scale, ) + if target_x is None: + return None + if drag_state.kind == "center": + if self.lock_center_checkbox.isChecked(): + return None + return self._center_drag_adjusted_entry( + drag_state.start_entry, + target_center=target_x, + ) + if drag_state.kind in {"left_width", "right_width"}: + return self._width_drag_adjusted_entry( + drag_state.start_entry, + handle_kind=drag_state.kind, + target_x=target_x, + ) + return None + + def _baseline_entry_for_row( + self, + row: int, + ) -> DreamParameterEntry | None: + if row < 0 or row >= len(self._reset_entries): + return None + return self._normalized_entry_copy(self._reset_entries[row]) - def _plot_entry(self, entry: DreamParameterEntry) -> None: + def _current_plot_window_state( + self, + row: int, + ) -> _PlotWindowState | None: + if ( + row < 0 + or self._plot_window_state is None + or self._plot_window_state.row != row + ): + return None + if self.figure.axes: + axis = self.figure.axes[0] + return _PlotWindowState( + row=row, + x_limits=tuple(float(limit) for limit in axis.get_xlim()), + y_limits=tuple(float(limit) for limit in axis.get_ylim()), + x_scale=str(axis.get_xscale() or "linear"), + ) + return self._plot_window_state + + @staticmethod + def _plot_window_requires_rescale( + current_state: _PlotWindowState | None, + *, + preferred_x_limits: tuple[float, float], + preferred_y_limits: tuple[float, float], + x_scale: str, + ) -> bool: + if current_state is None or current_state.x_scale != x_scale: + return True + current_x_low, current_x_high = current_state.x_limits + preferred_x_low, preferred_x_high = preferred_x_limits + current_y_low, current_y_high = current_state.y_limits + preferred_y_low, preferred_y_high = preferred_y_limits + x_span = max(abs(current_x_high - current_x_low), 1.0) + y_span = max(abs(current_y_high - current_y_low), 1.0) + x_tolerance = max(x_span * 1e-9, 1e-9) + y_tolerance = max(y_span * 1e-9, 1e-9) + return bool( + preferred_x_low < current_x_low - x_tolerance + or preferred_x_high > current_x_high + x_tolerance + or preferred_y_low < current_y_low - y_tolerance + or preferred_y_high > current_y_high + y_tolerance + ) + + def _plot_entry( + self, + entry: DreamParameterEntry, + *, + row: int | None = None, + force_rescale: bool = False, + ) -> None: + plot_row = row if row is not None else int(self.table.currentRow()) + current_window = self._current_plot_window_state(plot_row) self.figure.clear() + self._interactive_handles = None axis = self.figure.add_subplot(111) try: - distribution = getattr(stats, entry.distribution) - x_min, x_max = _distribution_domain(entry) - x_values = np.linspace(x_min, x_max, 200) - y_values = distribution.pdf(x_values, **entry.dist_params) - axis.plot(x_values, y_values, color="black") + required_x_limits, required_y_limits, x_scale = ( + _distribution_plot_bounds(entry) + ) + preferred_x_limits, preferred_y_limits, _preferred_scale = ( + _distribution_plot_window(entry) + ) + if force_rescale or self._plot_window_requires_rescale( + current_window, + preferred_x_limits=required_x_limits, + preferred_y_limits=required_y_limits, + x_scale=x_scale, + ): + plot_window = _PlotWindowState( + row=plot_row, + x_limits=preferred_x_limits, + y_limits=preferred_y_limits, + x_scale=x_scale, + ) + else: + plot_window = current_window + assert plot_window is not None + x_values, y_values = _distribution_plot_curve( + entry, + x_limits=plot_window.x_limits, + x_scale=plot_window.x_scale, + ) + baseline_entry = self._baseline_entry_for_row(plot_row) + baseline_series: tuple[np.ndarray, np.ndarray] | None = None + if baseline_entry is not None: + try: + baseline_series = _distribution_plot_curve( + baseline_entry, + x_limits=plot_window.x_limits, + x_scale=plot_window.x_scale, + ) + except Exception: + baseline_series = None + axis.set_box_aspect(1.0) + axis.set_xscale(plot_window.x_scale) + if baseline_series is not None: + (baseline_line,) = axis.plot( + baseline_series[0], + baseline_series[1], + color="tab:gray", + linestyle="--", + linewidth=1.4, + alpha=0.45, + zorder=1, + ) + baseline_line.set_gid("reset-baseline") + (current_line,) = axis.plot( + x_values, + y_values, + color="black", + linewidth=1.8, + zorder=3, + ) + current_line.set_gid("current-distribution") axis.axvline(entry.value, color="tab:red", linestyle="--") + guide_low, guide_high = self._interactive_width_handle_positions( + entry + ) + if guide_low is not None: + axis.axvline( + guide_low, + color="tab:gray", + linestyle=":", + linewidth=1.0, + ) + if guide_high is not None and not np.isclose( + guide_high, + guide_low, + ): + axis.axvline( + guide_high, + color="tab:gray", + linestyle=":", + linewidth=1.0, + ) axis.set_title(f"{entry.param}: {entry.distribution}") axis.set_xlabel("Value") axis.set_ylabel("Density") + axis.set_xlim(*plot_window.x_limits) + axis.set_ylim(*plot_window.y_limits) + self._interactive_handles = self._draw_interactive_handles( + axis, + entry, + x_values=x_values, + y_values=y_values, + x_scale=plot_window.x_scale, + y_limits=plot_window.y_limits, + ) + self._plot_window_state = plot_window except Exception as exc: + self._plot_window_state = None axis.text( 0.5, 0.5, @@ -603,7 +1333,154 @@ def _plot_entry(self, entry: DreamParameterEntry) -> None: ) axis.set_axis_off() self.figure.tight_layout() - self.canvas.draw() + if self._drag_state is None: + self.canvas.draw() + else: + self.canvas.draw_idle() + + def _draw_interactive_handles( + self, + axis, + entry: DreamParameterEntry, + *, + x_values: np.ndarray, + y_values: np.ndarray, + x_scale: str, + y_limits: tuple[float, float], + ) -> _InteractiveHandleArtists: + top_y = float(y_limits[1]) + width_y = max(top_y * INTERACTIVE_WIDTH_HANDLE_Y_FRACTION, 1e-9) + center_y = max(top_y * INTERACTIVE_CENTER_HANDLE_Y_FRACTION, 1e-9) + peak_index = int(np.argmax(y_values)) + peak_artist = axis.scatter( + [float(x_values[peak_index])], + [float(y_values[peak_index])], + s=INTERACTIVE_PEAK_HANDLE_SIZE, + marker="^", + facecolor="tab:orange", + edgecolor="black", + zorder=6, + ) + + left_width_artist = None + right_width_artist = None + width_low, width_high = self._interactive_width_handle_positions(entry) + if _x_coordinate_is_valid_for_scale(width_low, x_scale): + left_width_artist = axis.scatter( + [float(width_low)], + [width_y], + s=INTERACTIVE_WIDTH_HANDLE_SIZE, + marker="s", + facecolor="tab:blue", + edgecolor="white", + zorder=6, + ) + if _x_coordinate_is_valid_for_scale(width_high, x_scale): + right_width_artist = axis.scatter( + [float(width_high)], + [width_y], + s=INTERACTIVE_WIDTH_HANDLE_SIZE, + marker="s", + facecolor="tab:blue", + edgecolor="white", + zorder=6, + ) + + center_artist = None + if _x_coordinate_is_valid_for_scale(float(entry.value), x_scale): + if self.lock_center_checkbox.isChecked(): + axis.scatter( + [float(entry.value)], + [center_y], + s=INTERACTIVE_CENTER_HANDLE_SIZE, + marker="o", + facecolor="white", + edgecolor="tab:red", + alpha=0.55, + zorder=6, + ) + else: + center_artist = axis.scatter( + [float(entry.value)], + [center_y], + s=INTERACTIVE_CENTER_HANDLE_SIZE, + marker="o", + facecolor="tab:red", + edgecolor="white", + zorder=6, + ) + return _InteractiveHandleArtists( + axis=axis, + peak=peak_artist, + center=center_artist, + left_width=left_width_artist, + right_width=right_width_artist, + ) + + @staticmethod + def _clamped_event_y( + y_value: float | None, + *, + y_limits: tuple[float, float], + ) -> float: + lower, upper = sorted(float(limit) for limit in y_limits) + if y_value is None or not np.isfinite(y_value): + return lower + return min(max(float(y_value), lower), upper) + + @staticmethod + def _clamped_event_x( + x_value: float | None, + *, + x_limits: tuple[float, float], + x_scale: str, + ) -> float | None: + if x_value is None or not np.isfinite(x_value): + return None + lower, upper = sorted(float(limit) for limit in x_limits) + bounded_value = min(max(float(x_value), lower), upper) + if x_scale == "log": + bounded_value = max(bounded_value, np.finfo(float).tiny) + return bounded_value + + @staticmethod + def _entries_match( + previous: DreamParameterEntry, + current: DreamParameterEntry, + ) -> bool: + return ( + math.isclose(float(previous.value), float(current.value)) + and str(previous.distribution) == str(current.distribution) + and previous.vary == current.vary + and all( + math.isclose( + float(previous.dist_params.get(key, float("nan"))), + float(current.dist_params.get(key, float("nan"))), + rel_tol=1e-9, + abs_tol=1e-12, + ) + for key in set(previous.dist_params) | set(current.dist_params) + ) + ) + + def _interactive_drag_message( + self, + previous: DreamParameterEntry, + current: DreamParameterEntry, + *, + handle_kind: str, + ) -> str: + del previous + action = { + "peak": "Adjusted prior peak height and width", + "center": "Moved prior center", + "left_width": "Adjusted prior width from the low-side handle", + "right_width": "Adjusted prior width from the high-side handle", + }.get(handle_kind, "Adjusted prior") + return ( + f"{action}: {current.param} -> value={current.value:.6g}, " + f"params={json.dumps(current.dist_params, sort_keys=True)}" + ) @staticmethod def _parse_params(text: str) -> dict[str, float]: @@ -627,6 +1504,208 @@ def _normalize_distribution_params( params[key] = float(raw_params[key]) return params + @classmethod + def _center_drag_adjusted_entry( + cls, + entry: DreamParameterEntry, + *, + target_center: float, + ) -> DreamParameterEntry: + updated_entry = DreamParameterEntry.from_dict(entry.to_dict()) + params = copy.deepcopy(dict(updated_entry.dist_params)) + center_value = float(target_center) + updated_entry.value = center_value + if updated_entry.distribution == "norm": + params["loc"] = center_value + params["scale"] = max( + float(params.get("scale", INTERACTIVE_PARAMETER_EPSILON)), + INTERACTIVE_PARAMETER_EPSILON, + ) + elif updated_entry.distribution == "uniform": + width = max( + float(params.get("scale", INTERACTIVE_PARAMETER_EPSILON)), + INTERACTIVE_PARAMETER_EPSILON, + ) + params["scale"] = width + params["loc"] = center_value - width / 2.0 + elif updated_entry.distribution == "lognorm": + scale_value = max( + float(params.get("scale", INTERACTIVE_PARAMETER_EPSILON)), + INTERACTIVE_PARAMETER_EPSILON, + ) + params["scale"] = scale_value + params["s"] = max( + float(params.get("s", INTERACTIVE_PARAMETER_EPSILON)), + INTERACTIVE_PARAMETER_EPSILON, + ) + params["loc"] = center_value - scale_value + updated_entry.dist_params = params + return updated_entry + + @classmethod + def _width_drag_adjusted_entry( + cls, + entry: DreamParameterEntry, + *, + handle_kind: str, + target_x: float, + ) -> DreamParameterEntry: + updated_entry = DreamParameterEntry.from_dict(entry.to_dict()) + params = copy.deepcopy(dict(updated_entry.dist_params)) + center_value = float(updated_entry.value) + bounded_target = float(target_x) + if updated_entry.distribution == "norm": + sigma = max( + abs(bounded_target - center_value) / GUIDE_INTERVAL_SIGMA, + INTERACTIVE_PARAMETER_EPSILON, + ) + params["loc"] = center_value + params["scale"] = sigma + elif updated_entry.distribution == "uniform": + width = max( + 2.0 * abs(bounded_target - center_value), + INTERACTIVE_PARAMETER_EPSILON, + ) + params["scale"] = width + params["loc"] = center_value - width / 2.0 + elif updated_entry.distribution == "lognorm": + scale_value = max( + float(params.get("scale", INTERACTIVE_PARAMETER_EPSILON)), + INTERACTIVE_PARAMETER_EPSILON, + ) + if handle_kind == "left_width": + normalized = max( + 1.0 + + (min(bounded_target, center_value) - center_value) + / scale_value, + INTERACTIVE_PARAMETER_EPSILON, + ) + shape_value = max( + -math.log(normalized) / GUIDE_INTERVAL_SIGMA, + INTERACTIVE_PARAMETER_EPSILON, + ) + else: + normalized = max( + 1.0 + + (max(bounded_target, center_value) - center_value) + / scale_value, + 1.0 + INTERACTIVE_PARAMETER_EPSILON, + ) + shape_value = max( + math.log(normalized) / GUIDE_INTERVAL_SIGMA, + INTERACTIVE_PARAMETER_EPSILON, + ) + params["scale"] = scale_value + params["loc"] = center_value - scale_value + params["s"] = min( + shape_value, + INTERACTIVE_MAX_LOGNORM_SHAPE, + ) + updated_entry.dist_params = params + return updated_entry + + @classmethod + def _peak_drag_adjusted_entry( + cls, + entry: DreamParameterEntry, + *, + start_y: float, + target_y: float, + y_limits: tuple[float, float], + ) -> DreamParameterEntry: + updated_entry = DreamParameterEntry.from_dict(entry.to_dict()) + params = copy.deepcopy(dict(updated_entry.dist_params)) + span = max(float(y_limits[1]) - float(y_limits[0]), 1e-9) + delta_fraction = (float(target_y) - float(start_y)) / span + width_factor = math.exp( + -INTERACTIVE_PEAK_DRAG_SENSITIVITY * delta_fraction + ) + center_value = float(updated_entry.value) + if updated_entry.distribution == "norm": + sigma = max( + float(params.get("scale", INTERACTIVE_PARAMETER_EPSILON)) + * width_factor, + INTERACTIVE_PARAMETER_EPSILON, + ) + params["loc"] = center_value + params["scale"] = sigma + elif updated_entry.distribution == "uniform": + width = max( + float(params.get("scale", INTERACTIVE_PARAMETER_EPSILON)) + * width_factor, + INTERACTIVE_PARAMETER_EPSILON, + ) + params["scale"] = width + params["loc"] = center_value - width / 2.0 + elif updated_entry.distribution == "lognorm": + scale_value = max( + float(params.get("scale", INTERACTIVE_PARAMETER_EPSILON)), + INTERACTIVE_PARAMETER_EPSILON, + ) + params["scale"] = scale_value + params["loc"] = center_value - scale_value + params["s"] = min( + max( + float(params.get("s", INTERACTIVE_PARAMETER_EPSILON)) + * width_factor, + INTERACTIVE_PARAMETER_EPSILON, + ), + INTERACTIVE_MAX_LOGNORM_SHAPE, + ) + updated_entry.dist_params = params + return updated_entry + + @staticmethod + def _interactive_width_handle_positions( + entry: DreamParameterEntry, + ) -> tuple[float | None, float | None]: + center_value = float(entry.value) + if entry.distribution == "norm": + sigma = max( + float( + entry.dist_params.get( + "scale", INTERACTIVE_PARAMETER_EPSILON + ) + ), + INTERACTIVE_PARAMETER_EPSILON, + ) + spread = GUIDE_INTERVAL_SIGMA * sigma + return center_value - spread, center_value + spread + if entry.distribution == "uniform": + width = max( + float( + entry.dist_params.get( + "scale", INTERACTIVE_PARAMETER_EPSILON + ) + ), + INTERACTIVE_PARAMETER_EPSILON, + ) + return center_value - width / 2.0, center_value + width / 2.0 + if entry.distribution == "lognorm": + scale_value = max( + float( + entry.dist_params.get( + "scale", INTERACTIVE_PARAMETER_EPSILON + ) + ), + INTERACTIVE_PARAMETER_EPSILON, + ) + shape_value = max( + float( + entry.dist_params.get("s", INTERACTIVE_PARAMETER_EPSILON) + ), + INTERACTIVE_PARAMETER_EPSILON, + ) + return ( + center_value + + scale_value + * (math.exp(-GUIDE_INTERVAL_SIGMA * shape_value) - 1.0), + center_value + + scale_value + * (math.exp(GUIDE_INTERVAL_SIGMA * shape_value) - 1.0), + ) + return None, None + def _smart_adjusted_entries( self, entries: list[DreamParameterEntry], @@ -751,6 +1830,7 @@ def _on_smart_status_changed( self.load_entries( updated_entries, has_existing_parameter_map=self._has_existing_parameter_map, + update_reset_entries=False, ) self.console.append( "Applied row smart prior preset: " @@ -784,8 +1864,8 @@ def _target_group_keys( raise ValueError(f"Unknown smart prior apply scope: {apply_scope}") if not selected_rows: raise ValueError( - "Select one or more structure rows before applying a smart " - "prior preset to selected structures." + "Select one or more parameter rows before applying a smart " + "prior preset to selected parameters." ) return { row_groups[row] @@ -965,16 +2045,319 @@ def _is_effective_radius_parameter(param_name: str) -> bool: or name.startswith("c_eff_") ) + def _refresh_distribution_guides_for_row( + self, + row: int, + *, + entry: DreamParameterEntry | None, + ) -> None: + guide_low_text = "n/a" + guide_high_text = "n/a" + guide_tooltip = ( + "Practical prior bound is unavailable until the distribution " + "parameters are valid." + ) + if entry is not None: + guide_low, guide_high, guide_kind = _distribution_guide_bounds( + entry + ) + if guide_low is not None and guide_high is not None: + guide_low_text = _format_distribution_guide_value(guide_low) + guide_high_text = _format_distribution_guide_value(guide_high) + guide_tooltip = ( + f"{guide_kind} for the current {entry.distribution} prior." + ) + self._set_read_only_table_item( + row, + GUIDE_LOW_COLUMN, + guide_low_text, + tooltip=guide_tooltip, + ) + self._set_read_only_table_item( + row, + GUIDE_HIGH_COLUMN, + guide_high_text, + tooltip=guide_tooltip, + ) + + def _set_read_only_table_item( + self, + row: int, + column: int, + text: str, + *, + tooltip: str, + ) -> None: + was_blocked = self.table.blockSignals(True) + try: + item = self.table.item(row, column) + if item is None: + item = QTableWidgetItem() + item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable) + item.setTextAlignment( + Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter + ) + self.table.setItem(row, column, item) + item.setText(text) + item.setToolTip(tooltip) + finally: + self.table.blockSignals(was_blocked) + + +def _x_coordinate_is_valid_for_scale( + x_value: float | None, + x_scale: str, +) -> bool: + if x_value is None or not np.isfinite(x_value): + return False + if x_scale == "log" and float(x_value) <= 0.0: + return False + return True + def _distribution_domain(entry: DreamParameterEntry) -> tuple[float, float]: + return _distribution_domain_quantiles( + entry, + lower_q=0.001, + upper_q=0.999, + ) + + +def _distribution_domain_quantiles( + entry: DreamParameterEntry, + *, + lower_q: float, + upper_q: float, +) -> tuple[float, float]: distribution = getattr(stats, entry.distribution) - x_min = distribution.ppf(0.001, **entry.dist_params) - x_max = distribution.ppf(0.999, **entry.dist_params) + x_min = distribution.ppf(lower_q, **entry.dist_params) + x_max = distribution.ppf(upper_q, **entry.dist_params) if not np.isfinite(x_min) or not np.isfinite(x_max) or x_min == x_max: return entry.value - 1.0, entry.value + 1.0 return float(x_min), float(x_max) +def _distribution_plot_bounds( + entry: DreamParameterEntry, +) -> tuple[tuple[float, float], tuple[float, float], str]: + distribution = getattr(stats, entry.distribution) + sample_low, sample_high = _distribution_domain_quantiles( + entry, + lower_q=PLOT_DOMAIN_LOWER_Q, + upper_q=PLOT_DOMAIN_UPPER_Q, + ) + if not np.isfinite(sample_low) or not np.isfinite(sample_high): + raise ValueError("Plot domain is not finite.") + if sample_low == sample_high: + span = max(abs(float(entry.value)) * 0.5, 1.0) + sample_low = float(entry.value) - span + sample_high = float(entry.value) + span + if sample_high < sample_low: + sample_low, sample_high = sample_high, sample_low + + use_log_scale = ( + entry.distribution == "lognorm" + and sample_low > 0.0 + and sample_high / sample_low >= PLOT_LOG_SCALE_SPAN_RATIO + ) + sample_x = _distribution_sample_grid( + sample_low, + sample_high, + count=PLOT_SAMPLE_COUNT, + x_scale="log" if use_log_scale else "linear", + ) + sample_y = distribution.pdf(sample_x, **entry.dist_params) + finite_mask = np.isfinite(sample_x) & np.isfinite(sample_y) + if not np.any(finite_mask): + raise ValueError("Distribution density is not finite.") + sample_x = sample_x[finite_mask] + sample_y = sample_y[finite_mask] + peak_density = float(np.max(sample_y)) + if not np.isfinite(peak_density) or peak_density <= 0.0: + raise ValueError("Distribution density could not be evaluated.") + + focus_mask = sample_y >= peak_density * PLOT_RELATIVE_DENSITY_THRESHOLD + if not np.any(focus_mask): + focus_mask = np.ones_like(sample_y, dtype=bool) + focus_low = float(sample_x[focus_mask][0]) + focus_high = float(sample_x[focus_mask][-1]) + if np.isfinite(entry.value): + focus_low = min(focus_low, float(entry.value)) + focus_high = max(focus_high, float(entry.value)) + + focus_low, focus_high = _expand_plot_limits( + focus_low, + focus_high, + x_scale="log" if use_log_scale else "linear", + padding_fraction=PLOT_PADDING_FRACTION, + ) + focus_x = _distribution_sample_grid( + focus_low, + focus_high, + count=PLOT_SAMPLE_COUNT, + x_scale="log" if use_log_scale else "linear", + ) + focus_y = distribution.pdf(focus_x, **entry.dist_params) + finite_focus_mask = np.isfinite(focus_x) & np.isfinite(focus_y) + if not np.any(finite_focus_mask): + raise ValueError("Focused plot density is not finite.") + focus_x = focus_x[finite_focus_mask] + focus_y = focus_y[finite_focus_mask] + y_peak = float(np.max(focus_y)) + if not np.isfinite(y_peak) or y_peak <= 0.0: + raise ValueError("Focused plot density could not be evaluated.") + y_padding = max(y_peak * PLOT_PADDING_FRACTION, 1e-12) + return ( + (float(focus_x[0]), float(focus_x[-1])), + (0.0, max(y_peak + y_padding, 1e-12)), + "log" if use_log_scale else "linear", + ) + + +def _distribution_plot_window( + entry: DreamParameterEntry, +) -> tuple[tuple[float, float], tuple[float, float], str]: + required_x_limits, required_y_limits, x_scale = _distribution_plot_bounds( + entry + ) + window_low, window_high = _expand_plot_limits( + required_x_limits[0], + required_x_limits[1], + x_scale=x_scale, + padding_fraction=PLOT_WINDOW_MARGIN_FRACTION, + ) + return ( + (window_low, window_high), + ( + required_y_limits[0], + max(required_y_limits[1], 1e-12) + * (1.0 + PLOT_WINDOW_MARGIN_FRACTION), + ), + x_scale, + ) + + +def _distribution_plot_curve( + entry: DreamParameterEntry, + *, + x_limits: tuple[float, float], + x_scale: str, +) -> tuple[np.ndarray, np.ndarray]: + distribution = getattr(stats, entry.distribution) + plot_x = _distribution_sample_grid( + x_limits[0], + x_limits[1], + count=PLOT_SAMPLE_COUNT, + x_scale=x_scale, + ) + plot_y = distribution.pdf(plot_x, **entry.dist_params) + finite_plot_mask = np.isfinite(plot_x) & np.isfinite(plot_y) + if not np.any(finite_plot_mask): + raise ValueError("Focused plot density is not finite.") + return plot_x[finite_plot_mask], plot_y[finite_plot_mask] + + +def _distribution_sample_grid( + lower: float, + upper: float, + *, + count: int, + x_scale: str, +) -> np.ndarray: + if x_scale == "log": + bounded_low = max(float(lower), np.finfo(float).tiny) + bounded_high = max(float(upper), bounded_low * (1.0 + 1e-9)) + return np.geomspace(bounded_low, bounded_high, count) + bounded_low = float(lower) + bounded_high = float(upper) + if bounded_high == bounded_low: + bounded_high = bounded_low + 1.0 + return np.linspace(bounded_low, bounded_high, count) + + +def _expand_plot_limits( + lower: float, + upper: float, + *, + x_scale: str, + padding_fraction: float = PLOT_PADDING_FRACTION, +) -> tuple[float, float]: + if x_scale == "log": + bounded_low = max(float(lower), np.finfo(float).tiny) + bounded_high = max(float(upper), bounded_low * (1.0 + 1e-9)) + ratio = bounded_high / bounded_low + padding_ratio = max(ratio**padding_fraction, 1.05) + return bounded_low / padding_ratio, bounded_high * padding_ratio + bounded_low = float(lower) + bounded_high = float(upper) + span = bounded_high - bounded_low + if not np.isfinite(span) or span <= 0.0: + span = max(abs(bounded_low), abs(bounded_high), 1.0) + padding = max(span * padding_fraction, 1e-9) + return bounded_low - padding, bounded_high + padding + + +def _distribution_guide_bounds( + entry: DreamParameterEntry, +) -> tuple[float | None, float | None, str]: + distribution = getattr(stats, entry.distribution) + try: + support_low, support_high = distribution.support(**entry.dist_params) + support_low = float(support_low) + support_high = float(support_high) + except Exception: + support_low = float("nan") + support_high = float("nan") + + if np.isfinite(support_low) and np.isfinite(support_high): + if support_low <= support_high: + return support_low, support_high, "Exact support" + + try: + guide_low = float( + distribution.ppf(GUIDE_INTERVAL_LOWER_Q, **entry.dist_params) + ) + guide_high = float( + distribution.ppf(GUIDE_INTERVAL_UPPER_Q, **entry.dist_params) + ) + if np.isfinite(support_low): + guide_low = ( + max(guide_low, support_low) + if np.isfinite(guide_low) + else support_low + ) + if np.isfinite(support_high): + guide_high = ( + min(guide_high, support_high) + if np.isfinite(guide_high) + else support_high + ) + if np.isfinite(guide_low) and np.isfinite(guide_high): + if guide_low <= guide_high: + return ( + guide_low, + guide_high, + "Central 99.73% interval (3sigma equivalent)", + ) + except Exception: + pass + + try: + domain_low, domain_high = _distribution_domain(entry) + except Exception: + return None, None, "Unavailable" + if np.isfinite(domain_low) and np.isfinite(domain_high): + if domain_low <= domain_high: + return domain_low, domain_high, "Preview domain fallback" + return None, None, "Unavailable" + + +def _format_distribution_guide_value(value: float) -> str: + if not np.isfinite(value): + return "n/a" + return f"{float(value):.6g}" + + def _distribution_defaults_for_value( distribution: str, value: float, diff --git a/src/saxshell/saxs/ui/dream_tab.py b/src/saxshell/saxs/ui/dream_tab.py index 5a0c812..5299d9b 100644 --- a/src/saxshell/saxs/ui/dream_tab.py +++ b/src/saxshell/saxs/ui/dream_tab.py @@ -9,7 +9,7 @@ from matplotlib.colors import to_hex from matplotlib.figure import Figure from PySide6.QtCore import Qt, QTimer, Signal -from PySide6.QtGui import QColor, QTextOption, QValidator +from PySide6.QtGui import QColor, QTextCursor, QTextOption, QValidator from PySide6.QtWidgets import ( QAbstractItemView, QAbstractScrollArea, @@ -120,6 +120,8 @@ def validate( class DreamTab(QWidget): ACTIVE_SETTINGS_LABEL = "Active project settings" NO_SAVED_RUNS_LABEL = "No saved DREAM runs" + RUNTIME_OUTPUT_FLUSH_INTERVAL_MS = 300 + MAX_VIOLIN_PLOT_SAMPLES = 4096 edit_parameter_map_requested = Signal() save_settings_requested = Signal() @@ -128,26 +130,45 @@ class DreamTab(QWidget): run_dream_requested = Signal() load_results_requested = Signal() save_report_requested = Signal() + recycle_output_requested = Signal() + export_model_report_requested = Signal() save_model_fit_requested = Signal() save_violin_data_requested = Signal() settings_preset_changed = Signal(str) visualization_settings_changed = Signal() + results_settings_changed = Signal() + summary_settings_changed = Signal() + violin_data_settings_changed = Signal() + violin_style_settings_changed = Signal() def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) + self._console_autoscroll_enabled = True self._summary_text = "" self._base_log_text = "" self._history_messages: list[str] = [] self._live_output_history_index: int | None = None + self._pending_runtime_output_lines: list[str] = [] self._applying_search_filter_preset = False self._current_model_plot_data: DreamModelPlotData | None = None + self._current_summary: DreamSummary | None = None + self._current_violin_plot_data: DreamViolinPlotData | None = None self._model_legend_line_map: dict[object, object] = {} self._model_legend_handle_lookup: dict[str, object] = {} + self._suspend_visualization_notifications = False self._blink_timer = QTimer(self) self._blink_timer.setInterval(180) self._blink_timer.timeout.connect(self._advance_button_blink) self._blink_remaining = 0 self._blink_target_button: QPushButton | None = None + self._runtime_output_flush_timer = QTimer(self) + self._runtime_output_flush_timer.setSingleShot(True) + self._runtime_output_flush_timer.setInterval( + self.RUNTIME_OUTPUT_FLUSH_INTERVAL_MS + ) + self._runtime_output_flush_timer.timeout.connect( + self._flush_pending_runtime_output + ) self._build_ui() self.set_settings(DreamRunSettings(), preset_name=None) @@ -521,7 +542,7 @@ def _build_settings_group(self) -> QGroupBox: self.bestfit_method_combo.addItem("Chain Mean MAP", "chain_mean") self.bestfit_method_combo.addItem("Median", "median") self.bestfit_method_combo.currentIndexChanged.connect( - lambda _index: self.visualization_settings_changed.emit() + self._on_bestfit_method_changed ) self.violin_mode_combo = QComboBox() self.violin_mode_combo.addItem( @@ -530,8 +551,16 @@ def _build_settings_group(self) -> QGroupBox: self.violin_mode_combo.addItem("All Parameters", "all_parameters") self.violin_mode_combo.addItem("Weights Only", "weights_only") self.violin_mode_combo.addItem("Fit Parameters", "fit_parameters") + self.violin_mode_combo.addItem( + "Effective Radii Only", + "effective_radii_only", + ) + self.violin_mode_combo.addItem( + "Additional Parameters Only", + "additional_parameters_only", + ) self.violin_mode_combo.currentIndexChanged.connect( - lambda _index: self.visualization_settings_changed.emit() + self._on_violin_mode_changed ) bestfit_tip = ( "Choose how posterior samples are reduced to one best-fit " @@ -545,7 +574,8 @@ def _build_settings_group(self) -> QGroupBox: ) violin_tip = ( "Choose which subset of posterior parameters is shown in the " - "violin plot." + "violin plot, including weight-only, effective-radius-only, " + "and additional-parameter-only views." ) violin_label = QLabel("Violin data") self._set_widget_tooltip( @@ -563,7 +593,7 @@ def _build_settings_group(self) -> QGroupBox: self.weight_order_combo.addItem("Weight Index", "weight_index") self.weight_order_combo.addItem("Structure Order", "structure_order") self.weight_order_combo.currentIndexChanged.connect( - lambda _index: self.visualization_settings_changed.emit() + self._on_weight_order_changed ) weight_order_tip = ( "Choose whether weight parameters stay in their original w-index " @@ -590,13 +620,22 @@ def _build_settings_group(self) -> QGroupBox: self.violin_value_scale_combo.addItem( "Normalized 0-1 (All)", "normalized_all" ) + self.violin_value_scale_combo.addItem( + "Effective Radii Only", + "effective_radii_only", + ) + self.violin_value_scale_combo.addItem( + "Additional Parameters Only", + "additional_parameters_only", + ) self.violin_value_scale_combo.currentIndexChanged.connect( - lambda _index: self.visualization_settings_changed.emit() + self._on_violin_value_scale_changed ) value_scale_tip = ( "Choose whether the posterior violin plot uses native parameter " - "values, only the weight parameters on a 0-1 fraction scale, or " - "all parameters normalized independently onto a 0-1 axis." + "values, only the weight parameters on a 0-1 fraction scale, " + "all parameters normalized independently onto a 0-1 axis, or " + "effective-radius-only / additional-parameter-only views." ) value_scale_label = QLabel("Y-axis scale") self._set_widget_tooltip( @@ -747,7 +786,7 @@ def _build_settings_group(self) -> QGroupBox: self.violin_outline_width_spin.setSingleStep(0.1) self.violin_outline_width_spin.setValue(0.8) self.violin_outline_width_spin.valueChanged.connect( - lambda _value: self.visualization_settings_changed.emit() + self._on_violin_outline_width_changed ) outline_width_label = QLabel("Width") self._set_widget_tooltip( @@ -832,6 +871,24 @@ def _build_settings_group(self) -> QGroupBox: "Save a text report of the currently loaded DREAM posterior summary." ) self.report_button.clicked.connect(self.save_report_requested.emit) + self.recycle_button = QPushButton("Recycle") + self.recycle_button.setToolTip( + "Copy the currently selected DREAM best-fit parameter values into " + "the Prefit parameter table so you can manually refine them " + "before running DREAM again." + ) + self.recycle_button.clicked.connect(self.recycle_output_requested.emit) + self.export_model_report_button = QPushButton( + "Export Model Report (PPTX)" + ) + self.export_model_report_button.setToolTip( + "Create a multi-slide PowerPoint report for the current DREAM " + "fit, including project information, prior histograms, prefit " + "state, posterior filtering comparisons, and output paths." + ) + self.export_model_report_button.clicked.connect( + self.export_model_report_requested.emit + ) row += 1 action_sections = QWidget() @@ -854,6 +911,14 @@ def _build_settings_group(self) -> QGroupBox: analysis_layout.addWidget(self.saved_runs_combo, 0, 1, 1, 2) analysis_layout.addWidget(self.load_button, 1, 0) analysis_layout.addWidget(self.report_button, 1, 1) + analysis_layout.addWidget(self.recycle_button, 1, 2) + analysis_layout.addWidget( + self.export_model_report_button, + 2, + 0, + 1, + 3, + ) action_sections_layout.addWidget(self.setup_actions_group) action_sections_layout.addWidget(self.analysis_actions_group) layout.addWidget(action_sections, row, 0, 1, 4) @@ -904,7 +969,7 @@ def _build_posterior_filter_group(self) -> QGroupBox: lambda _index: self._mark_search_filter_preset_custom() ) self.violin_sample_source_combo.currentIndexChanged.connect( - lambda _index: self.visualization_settings_changed.emit() + self._on_violin_sample_source_changed ) violin_source_tip = ( "Choose whether the violin plot shows the full filtered " @@ -931,7 +996,7 @@ def _build_posterior_filter_group(self) -> QGroupBox: lambda _value: self._mark_search_filter_preset_custom() ) self.posterior_top_percent_spin.valueChanged.connect( - lambda _value: self.visualization_settings_changed.emit() + self._on_posterior_top_percent_changed ) top_percent_tip = ( "Default percent used whenever Top % log-posterior screening is " @@ -952,7 +1017,7 @@ def _build_posterior_filter_group(self) -> QGroupBox: lambda _value: self._mark_search_filter_preset_custom() ) self.posterior_top_n_spin.valueChanged.connect( - lambda _value: self.visualization_settings_changed.emit() + self._on_posterior_top_n_changed ) top_n_tip = ( "Default count used whenever Top N log-posterior screening is " @@ -988,7 +1053,7 @@ def _build_posterior_filter_group(self) -> QGroupBox: self.credible_interval_low_spin.setSingleStep(1.0) self.credible_interval_low_spin.setValue(16.0) self.credible_interval_low_spin.valueChanged.connect( - lambda _value: self.visualization_settings_changed.emit() + self._on_credible_interval_low_changed ) interval_low_tip = ( "Lower percentile used for the posterior interval bars and " @@ -1007,7 +1072,7 @@ def _build_posterior_filter_group(self) -> QGroupBox: self.credible_interval_high_spin.setSingleStep(1.0) self.credible_interval_high_spin.setValue(84.0) self.credible_interval_high_spin.valueChanged.connect( - lambda _value: self.visualization_settings_changed.emit() + self._on_credible_interval_high_changed ) interval_high_tip = ( "Upper percentile used for the posterior interval bars and " @@ -1096,6 +1161,13 @@ def _build_model_plot_group(self) -> QGroupBox: self.show_solvent_trace_checkbox.toggled.connect( self._redraw_current_model_plot ) + self.show_structure_factor_trace_checkbox = QCheckBox( + "Structure factor" + ) + self.show_structure_factor_trace_checkbox.setChecked(False) + self.show_structure_factor_trace_checkbox.toggled.connect( + self._redraw_current_model_plot + ) self.model_log_x_checkbox = QCheckBox("Log X") self.model_log_x_checkbox.setChecked(True) self.model_log_x_checkbox.toggled.connect( @@ -1109,6 +1181,7 @@ def _build_model_plot_group(self) -> QGroupBox: controls.addWidget(self.show_experimental_trace_checkbox) controls.addWidget(self.show_model_trace_checkbox) controls.addWidget(self.show_solvent_trace_checkbox) + controls.addWidget(self.show_structure_factor_trace_checkbox) controls.addWidget(self.model_log_x_checkbox) controls.addWidget(self.model_log_y_checkbox) controls.addStretch(1) @@ -1269,108 +1342,112 @@ def set_settings( *, preset_name: str | None = None, ) -> None: + self._suspend_visualization_notifications = True self._applying_search_filter_preset = True - self.settings_preset_combo.blockSignals(True) - self.settings_preset_combo.setCurrentText( - preset_name or self.ACTIVE_SETTINGS_LABEL - ) - self.settings_preset_combo.blockSignals(False) - self.model_name_edit.setText(settings.model_name or "") - self.chains_spin.setValue(settings.nchains) - self.iterations_spin.setValue(settings.niterations) - self.burnin_spin.setValue(settings.burnin_percent) - self.history_thin_spin.setValue(settings.history_thin) - self.nseedchains_spin.setValue(settings.nseedchains) - self.crossover_burnin_spin.setValue(settings.crossover_burnin) - self.lambda_spin.setValue(settings.lamb) - self.zeta_spin.setValue(settings.zeta) - self.snooker_spin.setValue(settings.snooker) - self.p_gamma_unity_spin.setValue(settings.p_gamma_unity) - self.history_file_edit.setText(settings.history_file or "") - self.verbose_checkbox.setChecked(settings.verbose) - self.parallel_checkbox.setChecked(settings.parallel) - self.adapt_checkbox.setChecked(settings.adapt_crossover) - self.restart_checkbox.setChecked(settings.restart) - self._set_combo_data( - self.search_filter_preset_combo, - settings.search_filter_preset or "custom", - ) - self._set_combo_data( - self.bestfit_method_combo, settings.bestfit_method - ) - self.verbose_interval_spin.setValue( - settings.verbose_output_interval_seconds - ) - self._set_combo_data( - self.posterior_filter_combo, - settings.posterior_filter_mode, - ) - self.posterior_top_percent_spin.setValue( - settings.posterior_top_percent - ) - self.posterior_top_n_spin.setValue(settings.posterior_top_n) - self.auto_filter_assessment_checkbox.setChecked( - settings.auto_select_best_posterior_filter - ) - self.credible_interval_low_spin.setValue( - settings.credible_interval_low - ) - self.credible_interval_high_spin.setValue( - settings.credible_interval_high - ) - self._set_combo_data( - self.violin_mode_combo, - settings.violin_parameter_mode, - ) - self._set_combo_data( - self.violin_sample_source_combo, - settings.violin_sample_source, - ) - self._set_combo_data( - self.weight_order_combo, - settings.violin_weight_order, - ) - self._set_combo_data( - self.violin_value_scale_combo, - settings.violin_value_scale_mode, - ) - self._configure_plot_color_button( - self.violin_custom_color_button, - settings.violin_custom_color, - label="Custom violin", - ) - if not self._set_combo_data( - self.violin_palette_combo, - settings.violin_palette, - ): + try: + self.settings_preset_combo.blockSignals(True) + self.settings_preset_combo.setCurrentText( + preset_name or self.ACTIVE_SETTINGS_LABEL + ) + self.settings_preset_combo.blockSignals(False) + self.model_name_edit.setText(settings.model_name or "") + self.chains_spin.setValue(settings.nchains) + self.iterations_spin.setValue(settings.niterations) + self.burnin_spin.setValue(settings.burnin_percent) + self.history_thin_spin.setValue(settings.history_thin) + self.nseedchains_spin.setValue(settings.nseedchains) + self.crossover_burnin_spin.setValue(settings.crossover_burnin) + self.lambda_spin.setValue(settings.lamb) + self.zeta_spin.setValue(settings.zeta) + self.snooker_spin.setValue(settings.snooker) + self.p_gamma_unity_spin.setValue(settings.p_gamma_unity) + self.history_file_edit.setText(settings.history_file or "") + self.verbose_checkbox.setChecked(settings.verbose) + self.parallel_checkbox.setChecked(settings.parallel) + self.adapt_checkbox.setChecked(settings.adapt_crossover) + self.restart_checkbox.setChecked(settings.restart) + self._set_combo_data( + self.search_filter_preset_combo, + settings.search_filter_preset or "custom", + ) + self._set_combo_data( + self.bestfit_method_combo, settings.bestfit_method + ) + self.verbose_interval_spin.setValue( + settings.verbose_output_interval_seconds + ) + self._set_combo_data( + self.posterior_filter_combo, + settings.posterior_filter_mode, + ) + self.posterior_top_percent_spin.setValue( + settings.posterior_top_percent + ) + self.posterior_top_n_spin.setValue(settings.posterior_top_n) + self.auto_filter_assessment_checkbox.setChecked( + settings.auto_select_best_posterior_filter + ) + self.credible_interval_low_spin.setValue( + settings.credible_interval_low + ) + self.credible_interval_high_spin.setValue( + settings.credible_interval_high + ) self._set_combo_data( + self.violin_mode_combo, + settings.violin_parameter_mode, + ) + self._set_combo_data( + self.violin_sample_source_combo, + settings.violin_sample_source, + ) + self._set_combo_data( + self.weight_order_combo, + settings.violin_weight_order, + ) + self._set_combo_data( + self.violin_value_scale_combo, + settings.violin_value_scale_mode, + ) + self._configure_plot_color_button( + self.violin_custom_color_button, + settings.violin_custom_color, + label="Custom violin", + ) + if not self._set_combo_data( self.violin_palette_combo, - "custom_solid", + settings.violin_palette, + ): + self._set_combo_data( + self.violin_palette_combo, + "custom_solid", + ) + self._configure_plot_color_button( + self.violin_point_color_button, + settings.violin_point_color, + label="Point", ) - self._configure_plot_color_button( - self.violin_point_color_button, - settings.violin_point_color, - label="Point", - ) - self._configure_plot_color_button( - self.interval_color_button, - settings.violin_interval_color, - label="Interval", - ) - self._configure_plot_color_button( - self.median_color_button, - settings.violin_median_color, - label="Median", - ) - self._configure_plot_color_button( - self.violin_outline_color_button, - settings.violin_outline_color, - label="Outline", - ) - self.violin_outline_width_spin.setValue( - float(settings.violin_outline_width) - ) - self._applying_search_filter_preset = False + self._configure_plot_color_button( + self.interval_color_button, + settings.violin_interval_color, + label="Interval", + ) + self._configure_plot_color_button( + self.median_color_button, + settings.violin_median_color, + label="Median", + ) + self._configure_plot_color_button( + self.violin_outline_color_button, + settings.violin_outline_color, + label="Outline", + ) + self.violin_outline_width_spin.setValue( + float(settings.violin_outline_width) + ) + finally: + self._applying_search_filter_preset = False + self._suspend_visualization_notifications = False self._update_violin_style_controls() self._update_posterior_filter_controls() self._update_verbose_output_controls() @@ -1539,6 +1616,7 @@ def _update_verbose_output_controls(self) -> None: ) def append_log(self, message: str) -> None: + self._flush_pending_runtime_output() stripped = message.strip() if stripped: self._history_messages.append(stripped) @@ -1546,6 +1624,8 @@ def append_log(self, message: str) -> None: self._render_output(scroll_to_end=True) def set_log_text(self, text: str) -> None: + self._runtime_output_flush_timer.stop() + self._pending_runtime_output_lines = [] self._base_log_text = text.strip() self._history_messages = [] self._live_output_history_index = None @@ -1555,20 +1635,36 @@ def set_summary_text(self, text: str) -> None: self._summary_text = text.strip() self._render_output() + def set_console_autoscroll_enabled(self, enabled: bool) -> None: + self._console_autoscroll_enabled = bool(enabled) + if self._console_autoscroll_enabled: + self._scroll_output_to_end() + def append_runtime_output(self, message: str) -> None: stripped = message.rstrip() if not stripped: return + self._pending_runtime_output_lines.append(stripped) + if not self._runtime_output_flush_timer.isActive(): + self._runtime_output_flush_timer.start() + + def _flush_pending_runtime_output(self) -> None: + if not self._pending_runtime_output_lines: + return + chunk = "\n".join(self._pending_runtime_output_lines) + self._pending_runtime_output_lines = [] if self._live_output_history_index is None: - self._history_messages.append("DREAM Runtime Output\n" + stripped) + self._history_messages.append("DREAM Runtime Output\n" + chunk) self._live_output_history_index = len(self._history_messages) - 1 else: self._history_messages[self._live_output_history_index] += ( - "\n" + stripped + "\n" + chunk ) self._render_output(scroll_to_end=True) def finish_runtime_output(self) -> None: + self._runtime_output_flush_timer.stop() + self._flush_pending_runtime_output() self._live_output_history_index = None def clear_plots(self) -> None: @@ -1581,11 +1677,47 @@ def start_progress(self, message: str) -> None: self.progress_bar.setValue(0) self.progress_bar.setFormat("") - def finish_progress(self, message: str) -> None: + def begin_progress( + self, + total: int, + message: str, + *, + unit_label: str = "steps", + ) -> None: + total = max(int(total), 1) self.progress_label.setText(message) - self.progress_bar.setRange(0, 1) - self.progress_bar.setValue(1) - self.progress_bar.setFormat("%v / %m runs") + self.progress_bar.setRange(0, total) + self.progress_bar.setValue(0) + self.progress_bar.setFormat(f"%v / %m {unit_label}") + + def update_progress( + self, + processed: int, + total: int, + message: str, + *, + unit_label: str = "steps", + ) -> None: + total = max(int(total), 1) + processed = max(0, min(int(processed), total)) + self.progress_label.setText(message) + self.progress_bar.setRange(0, total) + self.progress_bar.setValue(processed) + self.progress_bar.setFormat(f"%v / %m {unit_label}") + + def finish_progress( + self, + message: str, + *, + total: int | None = None, + unit_label: str = "runs", + ) -> None: + if total is None: + total = 1 + self.progress_label.setText(message) + self.progress_bar.setRange(0, total) + self.progress_bar.setValue(total) + self.progress_bar.setFormat(f"%v / %m {unit_label}") def reset_progress(self) -> None: self.progress_label.setText("Progress: idle") @@ -1612,7 +1744,7 @@ def plot_model_fit(self, plot_data: DreamModelPlotData | None) -> None: va="center", ) axis.set_axis_off() - self.model_canvas.draw() + self.model_canvas.draw_idle() return grid = self.model_figure.add_gridspec(2, 1, height_ratios=[3, 1]) @@ -1622,6 +1754,7 @@ def plot_model_fit(self, plot_data: DreamModelPlotData | None) -> None: ) plotted_lines: list[object] = [] + structure_axis = None if self.show_experimental_trace_checkbox.isChecked(): experimental_artist = top_axis.scatter( plot_data.q_values, @@ -1652,6 +1785,36 @@ def plot_model_fit(self, plot_data: DreamModelPlotData | None) -> None: label="Solvent contribution", ) plotted_lines.append(solvent_line) + if ( + self.show_structure_factor_trace_checkbox.isChecked() + and plot_data.structure_factor_trace is not None + ): + structure_values = np.asarray( + plot_data.structure_factor_trace, + dtype=float, + ) + structure_mask = np.isfinite(structure_values) + if np.any(structure_mask): + structure_axis = top_axis.twinx() + structure_axis.set_xscale( + "log" + if self.model_log_x_checkbox.isChecked() + else "linear" + ) + (structure_line,) = structure_axis.plot( + np.asarray(plot_data.q_values, dtype=float)[ + structure_mask + ], + structure_values[structure_mask], + color="tab:purple", + linestyle="--", + linewidth=1.5, + label="Structure factor S(q)", + ) + structure_axis.set_ylabel("S(q)", color="tab:purple") + structure_axis.tick_params(axis="y", colors="tab:purple") + structure_axis.spines["right"].set_color("tab:purple") + plotted_lines.append(structure_line) if self.show_model_trace_checkbox.isChecked(): (model_line,) = top_axis.plot( plot_data.q_values, @@ -1707,7 +1870,7 @@ def plot_model_fit(self, plot_data: DreamModelPlotData | None) -> None: if plotted_lines: self._build_interactive_model_legend(top_axis, plotted_lines) self.model_figure.tight_layout() - self.model_canvas.draw() + self.model_canvas.draw_idle() def _redraw_current_model_plot(self) -> None: self.plot_model_fit(self._current_model_plot_data) @@ -1720,14 +1883,20 @@ def _update_model_trace_toggle_state( has_solvent = ( has_plot_data and plot_data.solvent_contribution is not None ) + has_structure_factor = ( + has_plot_data and plot_data.structure_factor_trace is not None + ) self.show_experimental_trace_checkbox.setEnabled(has_plot_data) self.show_model_trace_checkbox.setEnabled(has_plot_data) self.show_solvent_trace_checkbox.setEnabled(bool(has_solvent)) + self.show_structure_factor_trace_checkbox.setEnabled( + bool(has_structure_factor) + ) def _build_interactive_model_legend( self, axis, lines: list[object] ) -> None: - legend = axis.legend(loc="best") + legend = axis.legend(handles=lines, loc="best") if legend is None: return legend_handles = getattr(legend, "legend_handles", None) @@ -1766,6 +1935,8 @@ def plot_violin_plot( summary: DreamSummary | None, violin_data: DreamViolinPlotData | None, ) -> None: + self._current_summary = summary + self._current_violin_plot_data = violin_data self.violin_figure.clear() axis = self.violin_figure.add_subplot(111) if ( @@ -1782,13 +1953,14 @@ def plot_violin_plot( va="center", ) axis.set_axis_off() - self.violin_canvas.draw() + self.violin_canvas.draw_idle() return payload = self.prepare_violin_plot_payload(summary, violin_data) + display_samples = self._display_violin_samples(payload["samples"]) positions = np.arange(1, len(payload["display_names"]) + 1) violin_parts = axis.violinplot( - payload["samples"], + display_samples, positions=positions, showmedians=True, ) @@ -1842,7 +2014,19 @@ def plot_violin_plot( axis.set_ylim(*payload["y_limits"]) axis.legend(loc="best") self.violin_figure.tight_layout() - self.violin_canvas.draw() + self.violin_canvas.draw_idle() + + def redraw_current_violin_plot(self) -> None: + self.plot_violin_plot( + self._current_summary, + self._current_violin_plot_data, + ) + + def current_summary(self) -> DreamSummary | None: + return self._current_summary + + def current_violin_plot_data(self) -> DreamViolinPlotData | None: + return self._current_violin_plot_data def prepare_violin_plot_payload( self, @@ -2028,6 +2212,7 @@ def _set_button_blink_highlight( button.setGraphicsEffect(None) def _render_output(self, *, scroll_to_end: bool = False) -> None: + del scroll_to_end sections: list[str] = [] if self._summary_text: sections.append("DREAM Summary\n" + self._summary_text) @@ -2038,10 +2223,39 @@ def _render_output(self, *, scroll_to_end: bool = False) -> None: ] if history_parts: sections.append("DREAM Console\n" + "\n\n".join(history_parts)) + scrollbar = self.output_box.verticalScrollBar() + previous_value = scrollbar.value() + previous_maximum = max(scrollbar.maximum(), 1) self.output_box.setPlainText("\n\n".join(sections).strip()) - if scroll_to_end: - scrollbar = self.output_box.verticalScrollBar() - scrollbar.setValue(scrollbar.maximum()) + if self._console_autoscroll_enabled: + self._scroll_output_to_end() + return + updated_scrollbar = self.output_box.verticalScrollBar() + if updated_scrollbar.maximum() > 0: + position_fraction = previous_value / previous_maximum + updated_scrollbar.setValue( + int(round(position_fraction * updated_scrollbar.maximum())) + ) + + def _scroll_output_to_end(self) -> None: + cursor = self.output_box.textCursor() + cursor.movePosition(QTextCursor.MoveOperation.End) + self.output_box.setTextCursor(cursor) + self.output_box.ensureCursorVisible() + scrollbar = self.output_box.verticalScrollBar() + scrollbar.setValue(scrollbar.maximum()) + QTimer.singleShot( + 0, + self._scroll_output_to_end_once, + ) + + def _scroll_output_to_end_once(self) -> None: + cursor = self.output_box.textCursor() + cursor.movePosition(QTextCursor.MoveOperation.End) + self.output_box.setTextCursor(cursor) + self.output_box.ensureCursorVisible() + scrollbar = self.output_box.verticalScrollBar() + scrollbar.setValue(scrollbar.maximum()) def _browse_history_file(self) -> None: selected, _filter = QFileDialog.getOpenFileName( @@ -2055,7 +2269,7 @@ def _browse_history_file(self) -> None: def _on_violin_palette_changed(self, _index: int) -> None: self._update_violin_style_controls() - self.visualization_settings_changed.emit() + self._notify_violin_style_settings_changed() def _update_violin_style_controls(self) -> None: self.violin_custom_color_button.setEnabled(True) @@ -2126,12 +2340,12 @@ def _choose_plot_color( chosen.name(), label=label, ) - self.visualization_settings_changed.emit() + self._notify_violin_style_settings_changed() def _on_posterior_filter_mode_changed(self, _index: int) -> None: self._mark_search_filter_preset_custom() self._update_posterior_filter_controls() - self.visualization_settings_changed.emit() + self._notify_results_settings_changed() def _on_search_filter_preset_changed(self, _index: int) -> None: if self._applying_search_filter_preset: @@ -2146,6 +2360,7 @@ def _apply_search_filter_preset(self, preset_name: str) -> None: if preset is None: return self._applying_search_filter_preset = True + self._suspend_visualization_notifications = True try: self.chains_spin.setValue(int(preset["nchains"])) self.iterations_spin.setValue(int(preset["niterations"])) @@ -2170,7 +2385,8 @@ def _apply_search_filter_preset(self, preset_name: str) -> None: self._update_posterior_filter_controls() finally: self._applying_search_filter_preset = False - self.visualization_settings_changed.emit() + self._suspend_visualization_notifications = False + self._notify_results_settings_changed() def _mark_search_filter_preset_custom(self) -> None: if self._applying_search_filter_preset: @@ -2185,6 +2401,76 @@ def _update_posterior_filter_controls(self) -> None: self.posterior_top_percent_spin.setEnabled(True) self.posterior_top_n_spin.setEnabled(True) + def _notify_results_settings_changed(self) -> None: + if self._suspend_visualization_notifications: + return + self.results_settings_changed.emit() + self.visualization_settings_changed.emit() + + def _notify_summary_settings_changed(self) -> None: + if self._suspend_visualization_notifications: + return + self.summary_settings_changed.emit() + self.visualization_settings_changed.emit() + + def _notify_violin_data_settings_changed(self) -> None: + if self._suspend_visualization_notifications: + return + self.violin_data_settings_changed.emit() + self.visualization_settings_changed.emit() + + def _notify_violin_style_settings_changed(self) -> None: + if self._suspend_visualization_notifications: + return + self.violin_style_settings_changed.emit() + self.visualization_settings_changed.emit() + + def _on_bestfit_method_changed(self, _index: int) -> None: + self._notify_results_settings_changed() + + def _on_violin_mode_changed(self, _index: int) -> None: + self._notify_violin_data_settings_changed() + + def _on_violin_sample_source_changed(self, _index: int) -> None: + self._notify_violin_data_settings_changed() + + def _on_weight_order_changed(self, _index: int) -> None: + self._notify_violin_data_settings_changed() + + def _on_violin_value_scale_changed(self, _index: int) -> None: + self._notify_violin_data_settings_changed() + + def _on_violin_outline_width_changed(self, _value: float) -> None: + self._notify_violin_style_settings_changed() + + def _on_posterior_top_percent_changed(self, _value: float) -> None: + self._notify_results_settings_changed() + + def _on_posterior_top_n_changed(self, _value: int) -> None: + self._notify_results_settings_changed() + + def _on_credible_interval_low_changed(self, _value: float) -> None: + self._notify_summary_settings_changed() + + def _on_credible_interval_high_changed(self, _value: float) -> None: + self._notify_summary_settings_changed() + + @classmethod + def _display_violin_samples(cls, samples: object) -> np.ndarray: + display_samples = np.asarray(samples, dtype=float) + if display_samples.ndim == 1: + display_samples = display_samples.reshape(-1, 1) + max_samples = max(int(cls.MAX_VIOLIN_PLOT_SAMPLES), 1) + if display_samples.shape[0] <= max_samples: + return display_samples + sample_indices = np.linspace( + 0, + display_samples.shape[0] - 1, + max_samples, + dtype=int, + ) + return display_samples[sample_indices] + @staticmethod def _build_action_group( title: str, diff --git a/src/saxshell/saxs/ui/main_window.py b/src/saxshell/saxs/ui/main_window.py index 41175b9..8ec60c6 100644 --- a/src/saxshell/saxs/ui/main_window.py +++ b/src/saxshell/saxs/ui/main_window.py @@ -10,33 +10,51 @@ from dataclasses import dataclass from datetime import datetime from pathlib import Path +from typing import cast import numpy as np +from matplotlib import colormaps +from matplotlib.colors import to_hex from PySide6.QtCore import ( QObject, QSettings, QSize, Qt, QThread, + QTimer, QUrl, Signal, Slot, ) -from PySide6.QtGui import QAction, QDesktopServices, QKeySequence, QShortcut +from PySide6.QtGui import ( + QAction, + QColor, + QDesktopServices, + QFont, + QKeySequence, + QShortcut, +) from PySide6.QtWidgets import ( QApplication, QCheckBox, + QColorDialog, + QComboBox, QDialog, QDialogButtonBox, QDoubleSpinBox, QFileDialog, + QFontComboBox, QFormLayout, + QGridLayout, + QGroupBox, QHBoxLayout, QInputDialog, + QLabel, QLineEdit, QMainWindow, QMessageBox, QPushButton, + QScrollArea, QSplitter, QTabWidget, QTextEdit, @@ -57,30 +75,53 @@ load_dream_settings, load_parameter_map, ) -from saxshell.saxs.prefit import PrefitScaleRecommendation, SAXSPrefitWorkflow +from saxshell.saxs.model_report import ( + DreamFilterReportView, + DreamModelReportContext, + PriorHistogramRequest, + ReportComponentPlotData, + ReportComponentSeries, + export_dream_model_report_pptx, +) +from saxshell.saxs.prefit import ( + PrefitParameterEntry, + PrefitScaleRecommendation, + SAXSPrefitWorkflow, +) from saxshell.saxs.prefit.workflow import ( SOLUTE_VOLUME_FRACTION_PARAMETER_NAMES, SOLVENT_VOLUME_FRACTION_PARAMETER_NAMES, + SOLVENT_WEIGHT_PARAMETER_NAMES, normalize_requested_q_range_to_supported, q_range_boundary_tolerance, ) from saxshell.saxs.project_manager import ( ClusterImportResult, + ExperimentalDataSummary, + PowerPointExportSettings, ProjectSettings, SAXSProjectManager, build_project_paths, export_prior_histogram_npy, export_prior_histogram_table, load_built_component_q_range, + load_experimental_data_file, ) -from saxshell.saxs.solute_volume_fraction import ( - DISPLAY_FRACTION_DECIMALS, - SoluteVolumeFractionEstimate, +from saxshell.saxs.solute_volume_fraction import DISPLAY_FRACTION_DECIMALS +from saxshell.saxs.solution_scattering_estimator import ( + SolutionScatteringEstimate, ) from saxshell.saxs.template_installation import ( format_validation_report, install_template_candidate, ) +from saxshell.saxs.ui.branding import ( + build_saxshell_brand_widget, + configure_saxshell_application, + create_saxshell_startup_splash, + load_saxshell_icon, + prepare_saxshell_application_identity, +) from saxshell.saxs.ui.distribution_window import DistributionSetupWindow from saxshell.saxs.ui.dream_tab import DreamTab from saxshell.saxs.ui.dream_violin_export_dialog import DreamViolinExportDialog @@ -89,6 +130,9 @@ from saxshell.saxs.ui.progress_dialog import SAXSProgressDialog from saxshell.saxs.ui.project_setup_tab import ProjectSetupTab from saxshell.saxs.ui.solute_volume_fraction_widget import ( + AttenuationEstimateToolWindow, + FluorescenceEstimateToolWindow, + NumberDensityEstimateToolWindow, SoluteVolumeFractionToolWindow, ) from saxshell.version import __version__ @@ -96,12 +140,21 @@ GITHUB_REPOSITORY_URL = "https://github.com/kewh5868/SAXSShell" CONTACT_EMAIL = "keith.white@colorado.edu" RECENT_PROJECTS_KEY = "recent_project_dirs" +CONSOLE_AUTOSCROLL_KEY = "console_autoscroll_enabled" MAX_RECENT_PROJECTS = 10 REPO_ROOT = Path(__file__).resolve().parents[4] EQUIVALENT_SPHERE_MIX_TEMPLATE_NAMES = { "template_pydream_poly_lma_hs_mix_approx", "template_pydream_poly_lma_hs_legacy", } +POWERPOINT_COLOR_MAP_OPTIONS = ( + "viridis", + "plasma", + "cividis", + "magma", + "inferno", + "turbo", +) @dataclass(frozen=True, slots=True) @@ -141,6 +194,464 @@ class TemplateInstallRequest: model_description: str +class MainUISettingsDialog(QDialog): + def __init__( + self, + *, + dream_settings: DreamRunSettings, + powerpoint_settings: PowerPointExportSettings, + powerpoint_enabled: bool, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._powerpoint_enabled = bool(powerpoint_enabled) + self.setWindowTitle("Main UI Settings") + self.resize(760, 780) + layout = QVBoxLayout(self) + + self.tabs = QTabWidget() + layout.addWidget(self.tabs) + + self._build_dream_output_tab(dream_settings) + self._build_powerpoint_tab(powerpoint_settings) + + buttons = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok + | QDialogButtonBox.StandardButton.Cancel + ) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + layout.addWidget(buttons) + + def _build_dream_output_tab(self, settings: DreamRunSettings) -> None: + tab = QWidget() + layout = QVBoxLayout(tab) + form_layout = QFormLayout() + + self.verbose_checkbox = QCheckBox("Verbose sampler output") + self.verbose_checkbox.setChecked(settings.verbose) + self.verbose_checkbox.setToolTip( + "Enable or disable verbose DREAM sampler progress output." + ) + self.interval_spin = QDoubleSpinBox() + self.interval_spin.setRange(0.1, 30.0) + self.interval_spin.setDecimals(1) + self.interval_spin.setSingleStep(0.1) + self.interval_spin.setValue(settings.verbose_output_interval_seconds) + self.interval_spin.setToolTip( + "Minimum number of seconds between DREAM runtime output " + "updates shown in the UI while verbose output is enabled." + ) + self.interval_spin.setEnabled(self.verbose_checkbox.isChecked()) + self.verbose_checkbox.toggled.connect(self.interval_spin.setEnabled) + form_layout.addRow(self.verbose_checkbox) + form_layout.addRow("Output interval (s)", self.interval_spin) + layout.addLayout(form_layout) + layout.addStretch(1) + self.tabs.addTab(tab, "DREAM Output") + + def _build_powerpoint_tab( + self, + settings: PowerPointExportSettings, + ) -> None: + tab = QWidget() + tab_layout = QVBoxLayout(tab) + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + tab_layout.addWidget(scroll_area) + + container = QWidget() + scroll_area.setWidget(container) + layout = QVBoxLayout(container) + + note_label = QLabel( + "Adjust the PowerPoint export styling, slide content, and " + "supplemental files generated with the report." + ) + note_label.setWordWrap(True) + layout.addWidget(note_label) + self._powerpoint_disabled_label = QLabel( + "Load a project to edit PowerPoint export settings." + ) + self._powerpoint_disabled_label.setWordWrap(True) + self._powerpoint_disabled_label.setVisible( + not self._powerpoint_enabled + ) + layout.addWidget(self._powerpoint_disabled_label) + + appearance_group = QGroupBox("Appearance") + appearance_layout = QFormLayout(appearance_group) + self.powerpoint_font_combo = QFontComboBox() + self.powerpoint_font_combo.setCurrentFont(QFont(settings.font_family)) + appearance_layout.addRow("Font family", self.powerpoint_font_combo) + + self.component_cmap_combo = QComboBox() + self._populate_color_map_combo( + self.component_cmap_combo, + settings.component_color_map, + ) + appearance_layout.addRow( + "Component trace palette", + self.component_cmap_combo, + ) + + self.prior_cmap_combo = QComboBox() + self._populate_color_map_combo( + self.prior_cmap_combo, + settings.prior_histogram_color_map, + ) + appearance_layout.addRow( + "Prior histogram palette", + self.prior_cmap_combo, + ) + + self.solvent_sort_prior_cmap_combo = QComboBox() + self._populate_color_map_combo( + self.solvent_sort_prior_cmap_combo, + settings.solvent_sort_histogram_color_map, + ) + appearance_layout.addRow( + "Solvent-sort histogram palette", + self.solvent_sort_prior_cmap_combo, + ) + layout.addWidget(appearance_group) + + colors_group = QGroupBox("Colors") + colors_layout = QGridLayout(colors_group) + colors_layout.setColumnStretch(1, 1) + color_rows = [ + ("Text", settings.text_color, "powerpoint_text_color_button"), + ( + "Experimental trace", + settings.experimental_trace_color, + "powerpoint_experimental_color_button", + ), + ( + "Model trace", + settings.model_trace_color, + "powerpoint_model_color_button", + ), + ( + "Residual trace", + settings.residual_trace_color, + "powerpoint_residual_color_button", + ), + ( + "Solvent trace", + settings.solvent_trace_color, + "powerpoint_solvent_color_button", + ), + ( + "Structure factor trace", + settings.structure_factor_color, + "powerpoint_structure_factor_color_button", + ), + ( + "Table header fill", + settings.table_header_fill, + "powerpoint_table_header_fill_button", + ), + ( + "Table even row fill", + settings.table_even_row_fill, + "powerpoint_table_even_fill_button", + ), + ( + "Table odd row fill", + settings.table_odd_row_fill, + "powerpoint_table_odd_fill_button", + ), + ( + "Table header rule", + settings.table_rule_color, + "powerpoint_table_rule_color_button", + ), + ] + for row_index, (label, color, attribute_name) in enumerate(color_rows): + color_label = QLabel(label) + button = QPushButton() + self._configure_color_button(button, color=color, label=label) + button.clicked.connect( + lambda checked=False, btn=button, title=label: self._choose_color( + btn, + title, + ) + ) + setattr(self, attribute_name, button) + colors_layout.addWidget(color_label, row_index, 0) + colors_layout.addWidget(button, row_index, 1) + layout.addWidget(colors_group) + + content_group = QGroupBox("Slides and Summary Content") + content_layout = QVBoxLayout(content_group) + self.include_prior_histograms_checkbox = QCheckBox( + "Include prior histogram slides" + ) + self.include_prior_histograms_checkbox.setChecked( + settings.include_prior_histograms + ) + content_layout.addWidget(self.include_prior_histograms_checkbox) + + self.include_initial_traces_checkbox = QCheckBox( + "Include initial SAXS traces slide" + ) + self.include_initial_traces_checkbox.setChecked( + settings.include_initial_traces + ) + content_layout.addWidget(self.include_initial_traces_checkbox) + + self.include_prefit_model_checkbox = QCheckBox( + "Include prefit model slide" + ) + self.include_prefit_model_checkbox.setChecked( + settings.include_prefit_model + ) + content_layout.addWidget(self.include_prefit_model_checkbox) + + self.include_prefit_parameters_checkbox = QCheckBox( + "Include prefit parameter table slides" + ) + self.include_prefit_parameters_checkbox.setChecked( + settings.include_prefit_parameters + ) + content_layout.addWidget(self.include_prefit_parameters_checkbox) + + self.include_geometry_table_checkbox = QCheckBox( + "Include computed geometry parameter slides" + ) + self.include_geometry_table_checkbox.setChecked( + settings.include_geometry_table + ) + content_layout.addWidget(self.include_geometry_table_checkbox) + + self.include_estimator_metrics_checkbox = QCheckBox( + "Include estimator metrics slides" + ) + self.include_estimator_metrics_checkbox.setChecked( + settings.include_estimator_metrics + ) + content_layout.addWidget(self.include_estimator_metrics_checkbox) + + self.include_dream_settings_checkbox = QCheckBox( + "Include DREAM settings and assessment slides" + ) + self.include_dream_settings_checkbox.setChecked( + settings.include_dream_settings + ) + content_layout.addWidget(self.include_dream_settings_checkbox) + + self.include_dream_prior_table_checkbox = QCheckBox( + "Include DREAM prior distribution slides" + ) + self.include_dream_prior_table_checkbox.setChecked( + settings.include_dream_prior_table + ) + content_layout.addWidget(self.include_dream_prior_table_checkbox) + + self.include_dream_output_model_checkbox = QCheckBox( + "Include DREAM output model slides" + ) + self.include_dream_output_model_checkbox.setChecked( + settings.include_dream_output_model + ) + content_layout.addWidget(self.include_dream_output_model_checkbox) + + self.include_posterior_comparisons_checkbox = QCheckBox( + "Include posterior comparison plots" + ) + self.include_posterior_comparisons_checkbox.setChecked( + settings.include_posterior_comparisons + ) + content_layout.addWidget(self.include_posterior_comparisons_checkbox) + + self.include_output_summary_checkbox = QCheckBox( + "Include output summary text" + ) + self.include_output_summary_checkbox.setChecked( + settings.include_output_summary + ) + content_layout.addWidget(self.include_output_summary_checkbox) + + self.include_directory_summary_checkbox = QCheckBox( + "Include directory summary text" + ) + self.include_directory_summary_checkbox.setChecked( + settings.include_directory_summary + ) + content_layout.addWidget(self.include_directory_summary_checkbox) + layout.addWidget(content_group) + + output_group = QGroupBox("Supplemental Output Data") + output_layout = QVBoxLayout(output_group) + self.generate_manifest_checkbox = QCheckBox( + "Generate report manifest JSON" + ) + self.generate_manifest_checkbox.setChecked(settings.generate_manifest) + output_layout.addWidget(self.generate_manifest_checkbox) + + self.export_figure_assets_checkbox = QCheckBox( + "Keep rendered figure PNG assets" + ) + self.export_figure_assets_checkbox.setChecked( + settings.export_figure_assets + ) + output_layout.addWidget(self.export_figure_assets_checkbox) + layout.addWidget(output_group) + + self._powerpoint_groups = ( + appearance_group, + colors_group, + content_group, + output_group, + ) + for group in self._powerpoint_groups: + group.setEnabled(self._powerpoint_enabled) + layout.addStretch(1) + self.tabs.addTab(tab, "PowerPoint Export") + + @staticmethod + def _populate_color_map_combo( + combo: QComboBox, + selected_name: str, + ) -> None: + for name in POWERPOINT_COLOR_MAP_OPTIONS: + combo.addItem(name, name) + current_name = str(selected_name).strip() or "viridis" + index = combo.findData(current_name) + combo.setCurrentIndex(index if index >= 0 else 0) + + @staticmethod + def _configure_color_button( + button: QPushButton, + *, + color: str, + label: str, + ) -> None: + normalized = str(color).strip() or "#000000" + qcolor = QColor(normalized) + foreground = "#ffffff" + if qcolor.isValid() and qcolor.lightness() > 128: + foreground = "#000000" + button.setText(normalized.upper()) + button.setToolTip(f"{label}: {normalized.upper()}") + button.setMinimumWidth(120) + button.setStyleSheet( + "QPushButton {" + f"background-color: {normalized};" + f"color: {foreground};" + "border: 1px solid #666666;" + "padding: 4px 8px;" + "}" + ) + + def _choose_color(self, button: QPushButton, label: str) -> None: + chosen = QColorDialog.getColor( + QColor(button.text().strip()), + self, + f"Choose {label.lower()}", + ) + if not chosen.isValid(): + return + self._configure_color_button( + button, + color=chosen.name(), + label=label, + ) + + def dream_output_values(self) -> tuple[bool, float]: + return ( + bool(self.verbose_checkbox.isChecked()), + float(self.interval_spin.value()), + ) + + def powerpoint_settings_value(self) -> PowerPointExportSettings: + return PowerPointExportSettings( + font_family=( + self.powerpoint_font_combo.currentFont().family().strip() + or "Arial" + ), + component_color_map=str( + self.component_cmap_combo.currentData() or "viridis" + ), + prior_histogram_color_map=str( + self.prior_cmap_combo.currentData() or "viridis" + ), + solvent_sort_histogram_color_map=str( + self.solvent_sort_prior_cmap_combo.currentData() or "summer" + ), + text_color=self.powerpoint_text_color_button.text().strip(), + experimental_trace_color=( + self.powerpoint_experimental_color_button.text().strip() + ), + model_trace_color=( + self.powerpoint_model_color_button.text().strip() + ), + residual_trace_color=( + self.powerpoint_residual_color_button.text().strip() + ), + solvent_trace_color=( + self.powerpoint_solvent_color_button.text().strip() + ), + structure_factor_color=( + self.powerpoint_structure_factor_color_button.text().strip() + ), + table_header_fill=( + self.powerpoint_table_header_fill_button.text().strip() + ), + table_even_row_fill=( + self.powerpoint_table_even_fill_button.text().strip() + ), + table_odd_row_fill=( + self.powerpoint_table_odd_fill_button.text().strip() + ), + table_rule_color=( + self.powerpoint_table_rule_color_button.text().strip() + ), + include_prior_histograms=bool( + self.include_prior_histograms_checkbox.isChecked() + ), + include_initial_traces=bool( + self.include_initial_traces_checkbox.isChecked() + ), + include_prefit_model=bool( + self.include_prefit_model_checkbox.isChecked() + ), + include_prefit_parameters=bool( + self.include_prefit_parameters_checkbox.isChecked() + ), + include_geometry_table=bool( + self.include_geometry_table_checkbox.isChecked() + ), + include_estimator_metrics=bool( + self.include_estimator_metrics_checkbox.isChecked() + ), + include_dream_settings=bool( + self.include_dream_settings_checkbox.isChecked() + ), + include_dream_prior_table=bool( + self.include_dream_prior_table_checkbox.isChecked() + ), + include_dream_output_model=bool( + self.include_dream_output_model_checkbox.isChecked() + ), + include_posterior_comparisons=bool( + self.include_posterior_comparisons_checkbox.isChecked() + ), + include_output_summary=bool( + self.include_output_summary_checkbox.isChecked() + ), + include_directory_summary=bool( + self.include_directory_summary_checkbox.isChecked() + ), + generate_manifest=bool( + self.generate_manifest_checkbox.isChecked() + ), + export_figure_assets=bool( + self.export_figure_assets_checkbox.isChecked() + ), + ) + + class InstallModelDialog(QDialog): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) @@ -290,6 +801,12 @@ class SAXSMainWindow(QMainWindow): """Main Qt window for SAXS project setup, prefit, and DREAM refinement.""" + DREAM_REFRESH_DELAY_MS = 75 + DREAM_REFRESH_STYLE = 1 + DREAM_REFRESH_VIOLIN = 2 + DREAM_REFRESH_SUMMARY = 3 + DREAM_REFRESH_FULL = 4 + def __init__( self, initial_project_dir: str | Path | None = None, @@ -324,8 +841,12 @@ def __init__( self._loaded_dream_run_dir: Path | None = None self._warn_on_prefit_template_change = True self._restoring_prefit_template = False + self._prefit_missing_components_warning_shown = False self._updating_deprecated_template_visibility = False self._show_deprecated_templates = False + self._console_autoscroll_enabled = ( + self._load_console_autoscroll_setting() + ) self._ui_scale = 1.0 self._base_font_point_size = self._resolve_base_font_point_size() self._scale_shortcuts: list[QShortcut] = [] @@ -333,9 +854,28 @@ def __init__( self._solute_volume_fraction_tool_window: ( SoluteVolumeFractionToolWindow | None ) = None + self._number_density_tool_window: ( + NumberDensityEstimateToolWindow | None + ) = None + self._attenuation_tool_window: AttenuationEstimateToolWindow | None = ( + None + ) + self._fluorescence_tool_window: ( + FluorescenceEstimateToolWindow | None + ) = None + self._last_solution_scattering_estimate: ( + SolutionScatteringEstimate | None + ) = None self._pending_prefit_sf_approximation_change: ( tuple[str, str] | None ) = None + self._pending_dream_refresh_scope = 0 + self._dream_refresh_timer = QTimer(self) + self._dream_refresh_timer.setSingleShot(True) + self._dream_refresh_timer.setInterval(self.DREAM_REFRESH_DELAY_MS) + self._dream_refresh_timer.timeout.connect( + self._flush_pending_dream_refresh + ) self._build_ui() self._capture_scale_baselines(self) self._register_scale_shortcuts() @@ -352,14 +892,28 @@ def __init__( self.dream_tab.clear_plots() def _build_ui(self) -> None: - self.setWindowTitle("SAXSShell (saxs)") + self.setWindowTitle("SAXSShell") + self.setWindowIcon(load_saxshell_icon()) self.resize(self._default_window_size()) self._build_menu_bar() self.tabs = QTabWidget() + self.tabs.setCornerWidget( + build_saxshell_brand_widget(self.tabs), + Qt.Corner.TopLeftCorner, + ) self.project_setup_tab = ProjectSetupTab() self.prefit_tab = PrefitTab() self.dream_tab = DreamTab() + self.project_setup_tab.set_console_autoscroll_enabled( + self._console_autoscroll_enabled + ) + self.prefit_tab.set_console_autoscroll_enabled( + self._console_autoscroll_enabled + ) + self.dream_tab.set_console_autoscroll_enabled( + self._console_autoscroll_enabled + ) self.tabs.addTab(self.project_setup_tab, "Project Setup") self.tabs.addTab(self.prefit_tab, "SAXS Prefit") self.tabs.addTab(self.dream_tab, "SAXS DREAM Fit") @@ -378,6 +932,9 @@ def _build_ui(self) -> None: self.project_setup_tab.autosave_project_requested.connect( self._autosave_project_from_tab ) + self.project_setup_tab.model_only_mode_changed.connect( + self._on_model_only_mode_changed + ) self.project_setup_tab.scan_clusters_requested.connect( self.scan_clusters_from_tab ) @@ -415,6 +972,9 @@ def _build_ui(self) -> None: self.prefit_tab.show_deprecated_templates_changed.connect( self._on_show_deprecated_templates_changed ) + self.prefit_tab.field_interaction_requested.connect( + self._on_prefit_field_interaction_requested + ) self.prefit_tab.autosave_toggled.connect(self._on_autosave_changed) self.prefit_tab.update_model_requested.connect( self.update_prefit_model @@ -437,6 +997,9 @@ def _build_ui(self) -> None: self.restore_prefit_state ) self.prefit_tab.reset_requested.connect(self.reset_prefit_entries) + self.prefit_tab.parameter_reset_requested.connect( + self.reset_single_prefit_parameter + ) self.prefit_tab.compute_cluster_geometry_requested.connect( self.compute_prefit_cluster_geometry ) @@ -456,7 +1019,7 @@ def _build_ui(self) -> None: self._on_prefit_cluster_geometry_ionic_radius_type_changed ) self.prefit_tab.solute_volume_fraction_widget.estimate_calculated.connect( - self._on_solute_volume_fraction_estimate_calculated + self._on_solution_scattering_estimate_calculated ) self.dream_tab.edit_parameter_map_requested.connect( @@ -474,6 +1037,12 @@ def _build_ui(self) -> None: self.load_selected_results ) self.dream_tab.save_report_requested.connect(self.save_dream_report) + self.dream_tab.recycle_output_requested.connect( + self.recycle_dream_output_to_prefit + ) + self.dream_tab.export_model_report_requested.connect( + self.export_dream_model_report + ) self.dream_tab.save_model_fit_requested.connect( self.save_dream_model_fit ) @@ -484,7 +1053,29 @@ def _build_ui(self) -> None: self._on_dream_settings_preset_changed ) self.dream_tab.visualization_settings_changed.connect( - self._refresh_loaded_dream_results + lambda: self._schedule_dream_results_refresh( + self.DREAM_REFRESH_STYLE + ) + ) + self.dream_tab.results_settings_changed.connect( + lambda: self._schedule_dream_results_refresh( + self.DREAM_REFRESH_FULL + ) + ) + self.dream_tab.summary_settings_changed.connect( + lambda: self._schedule_dream_results_refresh( + self.DREAM_REFRESH_SUMMARY + ) + ) + self.dream_tab.violin_data_settings_changed.connect( + lambda: self._schedule_dream_results_refresh( + self.DREAM_REFRESH_VIOLIN + ) + ) + self.dream_tab.violin_style_settings_changed.connect( + lambda: self._schedule_dream_results_refresh( + self.DREAM_REFRESH_STYLE + ) ) self._refresh_recent_projects_menu() self._update_file_menu_state() @@ -541,6 +1132,24 @@ def _build_menu_bar(self) -> None: ) self.tools_menu.addAction(self.bondanalysis_action) + self.clusterdynamics_action = QAction( + "Open Cluster Dynamics", + self, + ) + self.clusterdynamics_action.triggered.connect( + self._open_clusterdynamics_tool + ) + self.tools_menu.addAction(self.clusterdynamics_action) + + self.clusterdynamicsml_action = QAction( + "Open Cluster Dynamics ML", + self, + ) + self.clusterdynamicsml_action.triggered.connect( + self._open_clusterdynamicsml_tool + ) + self.tools_menu.addAction(self.clusterdynamicsml_action) + self.fullrmc_action = QAction("Open fullrmc Setup", self) self.fullrmc_action.triggered.connect(self._open_fullrmc_tool) self.tools_menu.addAction(self.fullrmc_action) @@ -553,16 +1162,31 @@ def _build_menu_bar(self) -> None: self._open_solute_volume_fraction_tool ) self.tools_menu.addAction(self.volume_fraction_action) + self.number_density_action = QAction( + "Open Number Density Estimate", self + ) + self.number_density_action.triggered.connect( + self._open_number_density_tool + ) + self.tools_menu.addAction(self.number_density_action) + self.attenuation_estimate_action = QAction( + "Open Attenuation Estimate", + self, + ) + self.attenuation_estimate_action.triggered.connect( + self._open_attenuation_tool + ) + self.tools_menu.addAction(self.attenuation_estimate_action) + self.fluorescence_estimate_action = QAction( + "Open Fluorescence Estimate", + self, + ) + self.fluorescence_estimate_action.triggered.connect( + self._open_fluorescence_tool + ) + self.tools_menu.addAction(self.fluorescence_estimate_action) placeholder_specs = [ ("PDF Calculation (Coming Soon)", "PDF Calculation"), - ( - "Number Density Estimate (Coming Soon)", - "Number Density Estimate", - ), - ( - "Bond Association/Dissociation Analysis (Coming Soon)", - "Bond Association/Dissociation Analysis", - ), ] self._placeholder_tool_actions: list[QAction] = [] for label, tool_name in placeholder_specs: @@ -576,14 +1200,27 @@ def _build_menu_bar(self) -> None: self._placeholder_tool_actions.append(action) self.settings_menu = menu_bar.addMenu("Settings") - self.dream_output_settings_action = QAction( - "DREAM Output Settings...", + self.console_autoscroll_action = QAction( + "Autoscroll Console Output", + self, + ) + self.console_autoscroll_action.setCheckable(True) + self.console_autoscroll_action.setChecked( + bool(self._console_autoscroll_enabled) + ) + self.console_autoscroll_action.triggered.connect( + self._toggle_console_autoscroll + ) + self.settings_menu.addAction(self.console_autoscroll_action) + self.main_ui_settings_action = QAction( + "Main UI Settings...", self, ) - self.dream_output_settings_action.triggered.connect( - self._open_dream_output_settings_dialog + self.main_ui_settings_action.triggered.connect( + self._open_main_ui_settings_dialog ) - self.settings_menu.addAction(self.dream_output_settings_action) + self.settings_menu.addAction(self.main_ui_settings_action) + self.dream_output_settings_action = self.main_ui_settings_action self.help_menu = menu_bar.addMenu("Help") self.version_info_action = QAction("Version Information", self) @@ -782,6 +1419,7 @@ def _open_project_from_menu(self) -> None: def load_project(self, project_dir: str | Path) -> None: settings = self.project_manager.load_project(project_dir) + self._last_solution_scattering_estimate = None self.current_settings = settings self._apply_project_settings(settings) self._remember_recent_project(settings.project_dir) @@ -812,6 +1450,8 @@ def scan_clusters_from_tab(self) -> None: def build_project_components(self) -> None: try: settings = self._settings_from_project_tab() + if not self._confirm_default_q_range_for_component_build(): + return self._save_settings( settings, status_message="Project auto-saved before building SAXS components", @@ -832,6 +1472,66 @@ def build_project_components(self) -> None: except Exception as exc: self._show_error("Build failed", str(exc)) + def _current_project_has_built_components(self) -> bool: + project_dir: Path | None + if self.current_settings is not None: + project_dir = Path(self.current_settings.project_dir).expanduser() + else: + project_dir = self.project_setup_tab.project_dir() + if project_dir is None: + return False + paths = build_project_paths(project_dir) + return any(paths.scattering_components_dir.glob("*.txt")) + + def _on_prefit_field_interaction_requested(self) -> None: + if self._current_project_has_built_components(): + self._prefit_missing_components_warning_shown = False + return + message = ( + "Build SAXS components in the Project Setup tab before editing " + "Prefit fields for this model." + ) + if self._prefit_missing_components_warning_shown: + self.statusBar().showMessage(message, 5000) + return + self._prefit_missing_components_warning_shown = True + QMessageBox.warning( + self, + "Build SAXS components first", + message, + ) + self.statusBar().showMessage(message, 5000) + + def _confirm_default_q_range_for_component_build(self) -> bool: + if ( + not self.project_setup_tab.q_range_matches_loaded_experimental_defaults() + ): + return True + default_range = self.project_setup_tab.default_experimental_q_range() + if default_range is None: + return True + q_min, q_max = default_range + response = QMessageBox.warning( + self, + "Build SAXS components with default q-range?", + ( + "The q-range still matches the full experimental-data " + f"default ({q_min:.6g} to {q_max:.6g}). If you intended to " + "crop the SAXS range, adjust q min and q max before " + "building.\n\n" + "Continue building SAXS components with the default q-range?" + ), + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if response == QMessageBox.StandardButton.Yes: + return True + self.statusBar().showMessage( + "SAXS component build canceled so the q-range can be adjusted.", + 5000, + ) + return False + def build_prior_weights(self) -> None: try: settings = self._settings_from_project_tab() @@ -1005,7 +1705,7 @@ def save_project_state(self) -> None: try: settings = self._settings_from_project_tab() saved_path = self._save_settings(settings) - self._apply_project_settings(settings) + self._sync_live_project_settings_after_save(settings) self.project_setup_tab.append_summary( f"Saved project state to {saved_path}" ) @@ -1023,6 +1723,37 @@ def save_project_state(self) -> None: except Exception as exc: self._show_error("Save project state failed", str(exc)) + def _sync_live_project_settings_after_save( + self, + settings: ProjectSettings, + ) -> None: + active_settings = ProjectSettings.from_dict(settings.to_dict()) + self.current_settings = active_settings + self.project_setup_tab.set_project_selected(True) + self._refresh_component_plot() + self._refresh_prior_plot() + + if self.prefit_workflow is not None: + self.prefit_workflow.parameter_entries = ( + self.prefit_tab.parameter_entries() + ) + self.prefit_workflow.apply_project_settings(active_settings) + self.current_settings = self.prefit_workflow.settings + self._load_prefit_preview() + self._refresh_saved_prefit_states() + + if self.dream_workflow is not None: + self.dream_workflow.apply_project_settings(self.current_settings) + self._dream_workflow_project_dir = str( + Path(self.current_settings.project_dir).resolve() + ) + self._invalidate_written_dream_bundle() + self._last_dream_constraint_warning_signature = None + self._refresh_saved_dream_runs() + + self._refresh_model_only_mode_state() + self._update_file_menu_state() + @Slot(str) def _autosave_project_from_tab(self, reason: str) -> None: if self.project_setup_tab.project_dir() is None: @@ -1363,9 +2094,24 @@ def _ensure_dream_progress_dialog(self) -> SAXSProgressDialog: self._dream_progress_dialog = SAXSProgressDialog(self) return self._dream_progress_dialog - def _show_dream_progress_dialog(self, message: str) -> None: - dialog = self._ensure_dream_progress_dialog() - dialog.begin_busy(message, title="SAXS DREAM Progress") + def _show_dream_progress_dialog( + self, + message: str, + *, + total: int | None = None, + unit_label: str = "runs", + title: str = "SAXS DREAM Progress", + ) -> None: + dialog = self._ensure_dream_progress_dialog() + if total is None: + dialog.begin_busy(message, title=title) + return + dialog.begin( + total, + message, + unit_label=unit_label, + title=title, + ) def _close_dream_progress_dialog(self) -> None: if self._dream_progress_dialog is not None: @@ -1609,117 +2355,283 @@ def _refresh_prefit_cluster_geometry_section(self) -> None: def _refresh_prefit_volume_fraction_section(self) -> None: target = self._current_volume_fraction_target() - visible = target is not None + solvent_weight_target = self._current_solvent_weight_target() + visible = self.prefit_workflow is not None self.prefit_tab.set_solute_volume_fraction_visible(visible) - if target is None: - self.prefit_tab.set_solute_volume_fraction_target(None, None) - else: + parameter_name = None + fraction_kind = None + if target is not None: parameter_name, fraction_kind = target - self.prefit_tab.set_solute_volume_fraction_target( - parameter_name, - fraction_kind, - ) - self._sync_volume_fraction_tool_target() + self.prefit_tab.set_solute_volume_fraction_target( + parameter_name, + fraction_kind, + solvent_weight_target, + ) + self._sync_solution_scattering_tool_targets() def _current_volume_fraction_target(self) -> tuple[str, str] | None: if self.prefit_workflow is None: return None return self.prefit_workflow.volume_fraction_estimator_target() - def _sync_volume_fraction_tool_target(self) -> None: + def _current_solvent_weight_target(self) -> str | None: + if self.prefit_workflow is None: + return None + return self.prefit_workflow.solvent_weight_estimator_target() + + def _sync_solution_scattering_tool_targets(self) -> None: target = self._current_volume_fraction_target() if target is None: parameter_name = None fraction_kind = None else: parameter_name, fraction_kind = target - if self._solute_volume_fraction_tool_window is not None: - self._solute_volume_fraction_tool_window.estimator_widget.set_target_parameter( + solvent_weight_parameter = self._current_solvent_weight_target() + for window in ( + self._solute_volume_fraction_tool_window, + self._number_density_tool_window, + self._attenuation_tool_window, + self._fluorescence_tool_window, + ): + if window is None: + continue + window.estimator_widget.set_target_parameter( parameter_name, fraction_kind, + solvent_weight_parameter, ) + def _apply_estimator_parameter_to_prefit( + self, + parameter_name: str, + parameter_value: float, + ) -> None: + current_entry = next( + ( + entry + for entry in self.prefit_tab.parameter_entries() + if entry.name == parameter_name + ), + None, + ) + minimum = 0.0 + maximum = max(parameter_value, 1.0) + if current_entry is not None: + minimum = min(float(current_entry.minimum), 0.0) + maximum = max( + float(current_entry.maximum), + parameter_value, + 1.0, + ) + self.prefit_tab.set_parameter_row( + parameter_name, + value=parameter_value, + minimum=minimum, + maximum=maximum, + vary=False, + ) + @Slot(object) - def _on_solute_volume_fraction_estimate_calculated( + def _on_solution_scattering_estimate_calculated( self, estimate_payload: object, ) -> None: - if not isinstance(estimate_payload, SoluteVolumeFractionEstimate): + if not isinstance(estimate_payload, SolutionScatteringEstimate): return - target = self._current_volume_fraction_target() + self._last_solution_scattering_estimate = estimate_payload widget = self.sender() - if target is None: + applied_notes: list[str] = [] + log_lines = ["Applied solution-scattering estimates."] + if self.prefit_workflow is None: if hasattr(widget, "append_application_note"): - widget.append_application_note( - "No active Prefit solute/solvent fraction parameter is " - "currently available, so this estimate was not applied " - "to the model." + cast(object, widget).append_application_note( + "The calculations completed, but there is no active " + "Prefit workflow to apply them to." ) - self.statusBar().showMessage("Volume fraction estimate calculated") - return - parameter_name, fraction_kind = target - parameter_value = ( - float(estimate_payload.solute_volume_fraction) - if fraction_kind == "solute" - else float(estimate_payload.solvent_volume_fraction) - ) - if self.prefit_workflow is None: + self.statusBar().showMessage( + "Solution-scattering estimate calculated" + ) return + + preview_changed = False try: - current_entry = next( - ( - entry - for entry in self.prefit_tab.parameter_entries() - if entry.name == parameter_name - ), - None, - ) - minimum = 0.0 - maximum = max(parameter_value, 1.0) - if current_entry is not None: - minimum = min(float(current_entry.minimum), 0.0) - maximum = max( - float(current_entry.maximum), + volume_target = self._current_volume_fraction_target() + interaction_estimate = ( + estimate_payload.interaction_contrast_estimate + ) + if interaction_estimate is not None and volume_target is not None: + parameter_name, fraction_kind = volume_target + parameter_value = ( + float( + interaction_estimate.saxs_effective_solute_interaction_ratio + ) + if fraction_kind == "solute" + else float( + interaction_estimate.saxs_effective_solvent_background_ratio + ) + ) + self._apply_estimator_parameter_to_prefit( + parameter_name, parameter_value, - 1.0, ) - self.prefit_tab.set_parameter_row( - parameter_name, - value=parameter_value, - minimum=minimum, - maximum=maximum, - ) - if hasattr(widget, "append_application_note"): - widget.append_application_note( - "Applied estimate to " - f"{parameter_name} = " - f"{parameter_value:.{DISPLAY_FRACTION_DECIMALS}f}." + applied_notes.append( + f"Applied {parameter_name} = " + f"{parameter_value:.{DISPLAY_FRACTION_DECIMALS}f} " + "from the SAXS-effective interaction ratio." ) - self.prefit_tab.append_log( - "Applied solution-based volume fraction estimate.\n" - f"Parameter: {parameter_name}\n" - f"Kind: {fraction_kind}\n" - f"Value: {parameter_value:.{DISPLAY_FRACTION_DECIMALS}f}\n" - f"Solute volume fraction: " - f"{estimate_payload.solute_volume_fraction:.{DISPLAY_FRACTION_DECIMALS}f}\n" - f"Solvent volume fraction: " - f"{estimate_payload.solvent_volume_fraction:.{DISPLAY_FRACTION_DECIMALS}f}" - ) + log_lines.extend( + [ + f"Volume-fraction target: {parameter_name}", + f"Model fraction kind: {fraction_kind}", + ( + "Physical solute-associated volume fraction: " + f"{interaction_estimate.physical_solute_associated_volume_fraction:.{DISPLAY_FRACTION_DECIMALS}f}" + ), + ( + "Physical solvent-associated volume fraction: " + f"{interaction_estimate.physical_solvent_associated_volume_fraction:.{DISPLAY_FRACTION_DECIMALS}f}" + ), + ( + "SAXS-effective solute interaction ratio: " + f"{interaction_estimate.saxs_effective_solute_interaction_ratio:.{DISPLAY_FRACTION_DECIMALS}f}" + ), + ( + "SAXS-effective solvent background ratio: " + f"{interaction_estimate.saxs_effective_solvent_background_ratio:.{DISPLAY_FRACTION_DECIMALS}f}" + ), + ( + "Contrast weight factor: " + f"{interaction_estimate.contrast_weight_factor:.6g}" + ), + ] + ) + preview_changed = True + elif ( + estimate_payload.volume_fraction_estimate is not None + and hasattr(widget, "append_application_note") + ): + cast(object, widget).append_application_note( + "Calculated the physical bulk fraction and the " + "SAXS-effective interaction ratio, but the active " + "template does not expose a model-facing solute/solvent " + "interaction fraction parameter." + ) + + solvent_weight_target = self._current_solvent_weight_target() + if ( + estimate_payload.attenuation_estimate is not None + and solvent_weight_target is not None + ): + attenuation_scale = float( + estimate_payload.attenuation_estimate.solvent_scattering_scale_factor + ) + uses_split_fraction_parameter = volume_target is not None + solvent_scale = attenuation_scale + if ( + not uses_split_fraction_parameter + and interaction_estimate is not None + ): + solvent_scale = ( + attenuation_scale + * interaction_estimate.saxs_effective_solvent_background_ratio + ) + self._apply_estimator_parameter_to_prefit( + solvent_weight_target, + solvent_scale, + ) + if ( + not uses_split_fraction_parameter + and interaction_estimate is not None + ): + applied_notes.append( + f"Applied {solvent_weight_target} = " + f"{solvent_scale:.{DISPLAY_FRACTION_DECIMALS}f} " + "from attenuation x the SAXS-effective solvent " + "background ratio." + ) + elif not uses_split_fraction_parameter: + applied_notes.append( + f"Applied {solvent_weight_target} = " + f"{solvent_scale:.{DISPLAY_FRACTION_DECIMALS}f} " + "from attenuation only because the SAXS-effective " + "interaction ratio was not available." + ) + else: + applied_notes.append( + f"Applied {solvent_weight_target} = " + f"{solvent_scale:.{DISPLAY_FRACTION_DECIMALS}f} " + "as the attenuation solvent scale." + ) + log_lines.extend( + [ + f"Solvent-weight target: {solvent_weight_target}", + ( + "Attenuation solvent scale factor: " + f"{attenuation_scale:.{DISPLAY_FRACTION_DECIMALS}f}" + ), + ( + "Sample transmission: " + f"{estimate_payload.attenuation_estimate.sample_transmission:.{DISPLAY_FRACTION_DECIMALS}f}" + ), + ( + "Neat-solvent transmission: " + f"{estimate_payload.attenuation_estimate.neat_solvent_transmission:.{DISPLAY_FRACTION_DECIMALS}f}" + ), + ] + ) + if ( + not uses_split_fraction_parameter + and interaction_estimate is not None + ): + log_lines.append( + "Combined solvent background multiplier: " + f"{solvent_scale:.{DISPLAY_FRACTION_DECIMALS}f}" + ) + elif ( + uses_split_fraction_parameter + and interaction_estimate is not None + ): + log_lines.append( + "SAXS-effective solvent background ratio is carried by " + f"{volume_target[0]} = " + f"{interaction_estimate.saxs_effective_solvent_background_ratio:.{DISPLAY_FRACTION_DECIMALS}f}" + ) + preview_changed = True + elif estimate_payload.attenuation_estimate is not None and hasattr( + widget, "append_application_note" + ): + cast(object, widget).append_application_note( + "Calculated the attenuation-based solvent contribution, " + "but the active template does not expose a solvent " + "background parameter such as solv_w or solvent_scale." + ) + + if applied_notes and hasattr(widget, "append_application_note"): + cast(object, widget).append_application_note( + "\n".join(applied_notes) + ) + + if not preview_changed: + if estimate_payload.fluorescence_estimate is not None: + self.statusBar().showMessage( + "Solution-scattering estimate calculated" + ) + return + self.prefit_workflow.parameter_entries = ( self.prefit_tab.parameter_entries() ) + self.prefit_tab.append_log("\n".join(log_lines)) evaluation = self._load_prefit_preview() if evaluation is None: self.statusBar().showMessage( - "Volume fraction estimate applied; preview is waiting " - "on the remaining template metadata" + "Estimator values applied; preview is waiting on the " + "remaining template metadata" ) else: - self.statusBar().showMessage( - "Volume fraction estimate applied" - ) + self.statusBar().showMessage("Estimator values applied") except Exception as exc: - self._show_error("Apply volume fraction failed", str(exc)) + self._show_error("Apply estimator values failed", str(exc)) def _sync_prefit_cluster_geometry_rows(self) -> None: if ( @@ -2049,10 +2961,8 @@ def update_prefit_model(self) -> None: try: previous_evaluation = self.prefit_tab.current_evaluation() self._sync_prefit_cluster_geometry_rows() - entries = [ - entry for entry in self.prefit_workflow.parameter_entries - ] - self.prefit_tab.populate_parameter_table(entries) + entries = self.prefit_tab.parameter_entries() + self.prefit_workflow.parameter_entries = entries evaluation = self.prefit_workflow.evaluate(entries) self.prefit_tab.plot_evaluation(evaluation) run_config = self.prefit_tab.run_config() @@ -2130,13 +3040,17 @@ def run_prefit(self) -> None: "Build a project and load its SAXS components first.", ) return + if not self.prefit_workflow.can_run_prefit(): + self._show_error( + "Prefit unavailable", + "Disable Model Only Mode and load experimental SAXS data to run a prefit.", + ) + return try: self._sync_prefit_cluster_geometry_rows() config = self.prefit_tab.run_config() - entries = [ - entry for entry in self.prefit_workflow.parameter_entries - ] - self.prefit_tab.populate_parameter_table(entries) + entries = self.prefit_tab.parameter_entries() + self.prefit_workflow.parameter_entries = entries self.prefit_tab.append_log( "Running prefit.\n" f"Template: {self.prefit_workflow.template_spec.name}\n" @@ -2190,6 +3104,12 @@ def save_prefit(self) -> None: "Build a project and load its SAXS components first.", ) return + if not self.prefit_workflow.can_run_prefit(): + self._show_error( + "Save Prefit unavailable", + "Disable Model Only Mode and load experimental SAXS data before saving a prefit report.", + ) + return try: self._sync_prefit_cluster_geometry_rows() entries = self.prefit_tab.parameter_entries() @@ -2257,23 +3177,28 @@ def save_prefit_plot_data(self) -> None: ) if destination is None: return - columns = [ - "q", - "experimental_intensity", - "model_intensity", - "residual", - "solvent_intensity", - "solvent_contribution", - ] - matrix = np.column_stack( - [ - np.asarray(evaluation.q_values, dtype=float), + columns = ["q"] + matrix_columns = [np.asarray(evaluation.q_values, dtype=float)] + if evaluation.experimental_intensities is not None: + columns.append("experimental_intensity") + matrix_columns.append( np.asarray( evaluation.experimental_intensities, dtype=float, - ), - np.asarray(evaluation.model_intensities, dtype=float), - np.asarray(evaluation.residuals, dtype=float), + ) + ) + columns.append("model_intensity") + matrix_columns.append( + np.asarray(evaluation.model_intensities, dtype=float) + ) + if evaluation.residuals is not None: + columns.append("residual") + matrix_columns.append( + np.asarray(evaluation.residuals, dtype=float) + ) + columns.extend(["solvent_intensity", "solvent_contribution"]) + matrix_columns.extend( + [ ( np.asarray(evaluation.solvent_intensities, dtype=float) if evaluation.solvent_intensities is not None @@ -2289,6 +3214,18 @@ def save_prefit_plot_data(self) -> None: ), ] ) + columns.append("structure_factor") + matrix_columns.append( + ( + np.asarray( + evaluation.structure_factor_trace, + dtype=float, + ) + if evaluation.structure_factor_trace is not None + else np.full_like(evaluation.q_values, np.nan) + ) + ) + matrix = np.column_stack(matrix_columns) if destination.suffix.lower() == ".csv": self._write_prefit_plot_csv( destination, @@ -2327,6 +3264,52 @@ def reset_prefit_entries(self) -> None: ) self.update_prefit_model() + def reset_single_prefit_parameter( + self, + structure: str, + motif: str, + parameter_name: str, + ) -> None: + if self.prefit_workflow is None: + return + try: + default_entry = next( + ( + entry + for entry in self.prefit_workflow.load_template_reset_entries() + if ( + entry.structure == structure + and entry.motif == motif + and entry.name == parameter_name + ) + ), + None, + ) + if default_entry is None: + raise ValueError( + f"No template-default entry is available for {parameter_name}." + ) + self.prefit_tab.set_parameter_row( + parameter_name, + structure=structure, + motif=motif, + value=default_entry.value, + minimum=default_entry.minimum, + maximum=default_entry.maximum, + vary=default_entry.vary, + ) + self.prefit_workflow.parameter_entries = ( + self.prefit_tab.parameter_entries() + ) + self.prefit_tab.append_log( + "Reset individual parameter to the template-default prefit " + f"preset.\nParameter: {parameter_name}" + ) + self.update_prefit_model() + self.statusBar().showMessage(f"Reset {parameter_name}") + except Exception as exc: + self._show_error("Reset parameter failed", str(exc)) + def set_best_prefit_parameters(self) -> None: if self.prefit_workflow is None: self._show_error( @@ -2859,6 +3842,8 @@ def _load_dream_results_from_run_dir( parameter_map_entries = load_parameter_map(parameter_map_path) except Exception: parameter_map_entries = [] + self._dream_refresh_timer.stop() + self._pending_dream_refresh_scope = 0 self._last_results_loader = None self.dream_tab.set_settings(display_settings, preset_name=None) self._last_results_loader = SAXSDreamResultsLoader( @@ -2948,6 +3933,231 @@ def save_dream_report(self) -> None: except Exception as exc: self._show_error("Save report failed", str(exc)) + def recycle_dream_output_to_prefit(self) -> None: + if self._last_results_loader is None: + self._show_error( + "Recycle DREAM output failed", + "Load DREAM results first.", + ) + return + if self.prefit_workflow is None: + self._show_error( + "Recycle DREAM output failed", + "Load or build a project first so the Prefit workflow is available.", + ) + return + try: + settings = self.dream_tab.settings_payload() + summary = self._last_results_loader.get_summary( + bestfit_method=settings.bestfit_method, + posterior_filter_mode=settings.posterior_filter_mode, + posterior_top_percent=settings.posterior_top_percent, + posterior_top_n=settings.posterior_top_n, + credible_interval_low=settings.credible_interval_low, + credible_interval_high=settings.credible_interval_high, + ) + current_entries = self.prefit_tab.parameter_entries() + updated_entries = [] + summary_lookup = { + str(name): float(summary.bestfit_params[index]) + for index, name in enumerate(summary.full_parameter_names) + } + matched_names: list[str] = [] + unmatched_prefit_names: list[str] = [] + for entry in current_entries: + copied_entry = PrefitParameterEntry.from_dict(entry.to_dict()) + if copied_entry.name in summary_lookup: + copied_entry.value = float( + summary_lookup[copied_entry.name] + ) + matched_names.append(copied_entry.name) + else: + unmatched_prefit_names.append(copied_entry.name) + updated_entries.append(copied_entry) + if not matched_names: + raise ValueError( + "The loaded DREAM result did not contain any parameters " + "that match the active Prefit table." + ) + + self.prefit_tab.populate_parameter_table(updated_entries) + self.prefit_workflow.parameter_entries = list(updated_entries) + self._invalidate_dream_workflow_cache() + self.tabs.setCurrentWidget(self.prefit_tab) + self.prefit_tab.update_button.setFocus( + Qt.FocusReason.OtherFocusReason + ) + + unmatched_dream_names = [ + str(name) + for name in summary.full_parameter_names + if str(name) not in {entry.name for entry in current_entries} + ] + log_lines = [ + "Recycled DREAM output into Prefit.", + f"DREAM run: {summary.run_dir}", + ( + "Best-fit selection: " + f"{settings.bestfit_method} with " + f"{self._describe_posterior_filter(settings)}" + ), + ( + "Matched Prefit parameters: " + f"{len(matched_names)} / {len(current_entries)}" + ), + ( + "Prefit preview refresh: deferred to keep recycle " + "responsive. Click Update Model when you're ready to " + "rerender the Prefit plot." + ), + ] + if ( + self.prefit_workflow.template_spec.name + != self._last_results_loader.template_name + ): + log_lines.append( + "Template mismatch note: active Prefit template is " + f"{self.prefit_workflow.template_spec.name}, while the " + f"loaded DREAM run used {self._last_results_loader.template_name}. " + "Only overlapping parameter names were recycled." + ) + if unmatched_prefit_names: + preview = ", ".join(unmatched_prefit_names[:8]) + if len(unmatched_prefit_names) > 8: + preview += ", ..." + log_lines.append( + "Prefit-only parameters left unchanged: " + preview + ) + if unmatched_dream_names: + preview = ", ".join(unmatched_dream_names[:8]) + if len(unmatched_dream_names) > 8: + preview += ", ..." + log_lines.append( + "DREAM-only parameters not copied: " + preview + ) + self.prefit_tab.append_log("\n".join(log_lines)) + self.dream_tab.append_log( + "Recycled the current DREAM best fit into the Prefit tab.\n" + + "\n".join(log_lines[1:]) + ) + self.statusBar().showMessage( + "DREAM output copied into Prefit; preview refresh deferred" + ) + except Exception as exc: + self._show_error("Recycle DREAM output failed", str(exc)) + + def export_dream_model_report(self) -> None: + if self._last_results_loader is None: + self._show_error( + "Export model report failed", + "Load DREAM results first.", + ) + return + if self.current_settings is None: + self._show_error( + "Export model report failed", + "Load or build a project first.", + ) + return + progress_started = False + progress_total = 1 + wait_message = ( + "Generating DREAM model report PowerPoint. Please wait..." + ) + try: + settings = self.dream_tab.settings_payload() + paths = build_project_paths(self.current_settings.project_dir) + paths.reports_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + base_name = f"dream_model_report_{timestamp}" + output_path = paths.reports_dir / f"{base_name}.pptx" + asset_dir = paths.reports_dir / f"{base_name}_assets" + self.dream_tab.append_log(wait_message) + self.dream_tab.begin_progress( + progress_total, + wait_message, + unit_label="steps", + ) + self._show_dream_progress_dialog( + wait_message, + total=progress_total, + unit_label="steps", + title="DREAM Report Export", + ) + self.statusBar().showMessage(wait_message) + QApplication.processEvents() + progress_started = True + context = self._build_dream_model_report_context( + settings=settings, + output_path=output_path, + asset_dir=asset_dir, + ) + + def _on_export_progress( + processed: int, + total: int, + message: str, + ) -> None: + nonlocal progress_total + progress_total = max(int(total), 1) + self.dream_tab.update_progress( + processed, + progress_total, + message, + unit_label="steps", + ) + if ( + self._dream_progress_dialog is not None + and self._dream_progress_dialog.isVisible() + ): + self._dream_progress_dialog.update_progress( + processed, + progress_total, + message, + unit_label="steps", + ) + self.statusBar().showMessage(message) + QApplication.processEvents() + + result = export_dream_model_report_pptx( + context, + progress_callback=_on_export_progress, + ) + self.dream_tab.finish_progress( + "DREAM model report exported.", + total=progress_total, + unit_label="steps", + ) + log_lines = [ + "Exported DREAM model report to:", + f"{result.report_path}", + ] + if result.manifest_path is not None: + log_lines.append(f"Manifest: {result.manifest_path}") + if result.figure_paths: + log_lines.append( + "The report assets folder contains the rendered figures " + "used to assemble the PowerPoint." + ) + else: + log_lines.append( + "Supplemental rendered figure assets were disabled for " + "this export." + ) + self.dream_tab.append_log("\n".join(log_lines)) + self.statusBar().showMessage("DREAM model report exported") + except Exception as exc: + if progress_started: + self.dream_tab.finish_progress( + "DREAM model report export failed.", + total=progress_total, + unit_label="steps", + ) + self._show_error("Export model report failed", str(exc)) + finally: + if progress_started: + self._close_dream_progress_dialog() + def _build_dream_export_context( self, settings: DreamRunSettings, @@ -3110,6 +4320,9 @@ def _dream_export_metadata_payload( "includes_solvent_contribution": bool( model_plot.solvent_contribution is not None ), + "includes_structure_factor": bool( + model_plot.structure_factor_trace is not None + ), "fit_metrics": { "rmse": float(model_plot.rmse), "mean_abs_residual": float(model_plot.mean_abs_residual), @@ -3308,6 +4521,15 @@ def _export_dream_model_fit_bundle( dtype=float, ) ) + structure_factor = ( + np.asarray(model_plot.structure_factor_trace, dtype=float) + if model_plot.structure_factor_trace is not None + else np.full_like( + np.asarray(model_plot.q_values, dtype=float), + np.nan, + dtype=float, + ) + ) np.savetxt( output_path, np.column_stack( @@ -3316,12 +4538,13 @@ def _export_dream_model_fit_bundle( model_plot.experimental_intensities, model_plot.model_intensities, solvent_contribution, + structure_factor, ] ), delimiter=",", header=( "q,experimental_intensity,model_intensity," - "solvent_contribution" + "solvent_contribution,structure_factor" ), comments="", ) @@ -3673,15 +4896,883 @@ def save_dream_violin_data(self) -> None: save_pkl=bool(export_options.save_pkl), auto_generated=False, ) - self.dream_tab.append_log( - "Exported DREAM violin data bundle to:\n" - + "\n".join(str(path) for path in saved_paths) - + "\nThis condensed export lives in exported_results/data. " - "Full DREAM run artifacts remain in the DREAM run folder." + self.dream_tab.append_log( + "Exported DREAM violin data bundle to:\n" + + "\n".join(str(path) for path in saved_paths) + + "\nThis condensed export lives in exported_results/data. " + "Full DREAM run artifacts remain in the DREAM run folder." + ) + self.statusBar().showMessage("DREAM violin data exported") + except Exception as exc: + self._show_error("Save violin data failed", str(exc)) + + def _build_dream_model_report_context( + self, + *, + settings: DreamRunSettings, + output_path: Path, + asset_dir: Path, + ) -> DreamModelReportContext: + if self.current_settings is None or self._last_results_loader is None: + raise RuntimeError("Load DREAM results first.") + powerpoint_settings = self._effective_powerpoint_export_settings() + summary, model_plot, violin_plot, plot_payload = ( + self._build_dream_export_context(settings) + ) + project_paths = build_project_paths(self.current_settings.project_dir) + prefit_entries = tuple( + self.prefit_tab.parameter_entries() + if self.prefit_workflow is not None + else [] + ) + prefit_evaluation = None + if self.prefit_workflow is not None: + try: + prefit_evaluation = self.prefit_workflow.evaluate( + list(prefit_entries) + ) + except Exception: + prefit_evaluation = None + prefit_statistics = self._latest_prefit_statistics_payload() + assessments = tuple(self._report_dream_filter_assessments(settings)) + filter_views = self._build_dream_report_filter_views(settings) + prior_requests = tuple( + self._build_report_prior_requests(powerpoint_settings) + ) + q_values = np.asarray(model_plot.q_values, dtype=float) + q_range_text = self._format_selected_q_range_text(q_values) + supported_q_range = load_built_component_q_range( + self.current_settings.project_dir + ) + supported_q_range_text = ( + None + if supported_q_range is None + else ( + f"{float(supported_q_range[0]):.6g} to " + f"{float(supported_q_range[1]):.6g}" + ) + ) + ( + template_display_name, + template_module_path, + model_equation_text, + model_context_lines, + model_definition_lines, + model_reference_lines, + ) = self._build_report_template_details( + template_name=str(model_plot.template_name), + q_range_text=q_range_text, + supported_q_range_text=supported_q_range_text, + q_sampling_text=self._report_q_sampling_text(), + prefit_parameter_count=len(prefit_entries), + dream_active_parameter_names=tuple( + str(name).strip() + for name in summary.active_parameter_names + if str(name).strip() + ), + includes_solvent=bool(model_plot.solvent_contribution is not None), + includes_structure_factor=bool( + model_plot.structure_factor_trace is not None + ), + ) + output_summary_lines = [ + f"Report file: {output_path}", + f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", + f"Template: {model_plot.template_name}", + f"Best-fit method: {settings.bestfit_method}", + ( + "Posterior filter: " + f"{self._describe_posterior_filter(settings)}" + ), + f"Posterior samples kept: {summary.posterior_sample_count}", + f"DREAM RMSE: {model_plot.rmse:.6g}", + f"DREAM Mean |res|: {model_plot.mean_abs_residual:.6g}", + f"DREAM R^2: {model_plot.r_squared:.6g}", + ] + if ( + prefit_evaluation is not None + and prefit_evaluation.experimental_intensities is not None + ): + prefit_metrics = self._fit_quality_metrics_from_curves( + np.asarray( + prefit_evaluation.experimental_intensities, + dtype=float, + ), + np.asarray(prefit_evaluation.model_intensities, dtype=float), + ) + output_summary_lines.extend( + [ + f"Prefit RMSE: {prefit_metrics.rmse:.6g}", + f"Prefit Mean |res|: {prefit_metrics.mean_abs_residual:.6g}", + f"Prefit R^2: {prefit_metrics.r_squared:.6g}", + ] + ) + directory_lines = [ + f"Project directory: {project_paths.project_dir}", + f"Exported data directory: {project_paths.exported_data_dir}", + f"Exported plots directory: {project_paths.exported_plots_dir}", + f"Prefit directory: {project_paths.prefit_dir}", + f"DREAM run directory: {summary.run_dir}", + ( + "Prior weights source: " + f"{self._report_prior_json_path() or 'Not available'}" + ), + ] + if ( + powerpoint_settings.export_figure_assets + or powerpoint_settings.generate_manifest + ): + directory_lines.insert(0, f"Report assets: {asset_dir}") + if powerpoint_settings.generate_manifest: + directory_lines.insert( + 1, + f"Report manifest: {asset_dir / 'report_manifest.json'}", + ) + return DreamModelReportContext( + output_path=output_path, + asset_dir=asset_dir, + project_name=self.current_settings.project_name, + project_dir=Path(self.current_settings.project_dir).resolve(), + generated_at=datetime.now(), + powerpoint_settings=powerpoint_settings, + user_q_range_text=q_range_text, + supported_q_range_text=supported_q_range_text, + q_sampling_text=self._report_q_sampling_text(), + template_name=str(model_plot.template_name), + template_display_name=template_display_name, + template_module_path=template_module_path, + model_equation_text=model_equation_text, + model_context_lines=tuple(model_context_lines), + model_definition_lines=tuple(model_definition_lines), + model_reference_lines=tuple(model_reference_lines), + prior_histograms=prior_requests, + component_plot_without_solvent=self._build_report_component_plot_data( + include_solvent=False, + powerpoint_settings=powerpoint_settings, + ), + component_plot_with_solvent=self._build_report_component_plot_data( + include_solvent=True, + powerpoint_settings=powerpoint_settings, + ), + prefit_evaluation=prefit_evaluation, + prefit_parameter_entries=prefit_entries, + prefit_statistics=prefit_statistics, + cluster_geometry_rows=tuple(self._report_cluster_geometry_rows()), + solution_scattering_estimate=( + self._current_report_solution_scattering_estimate() + ), + dream_settings=self._copy_dream_settings(settings), + dream_summary=summary, + dream_model_plot=model_plot, + dream_violin_plot=violin_plot, + dream_violin_payload=plot_payload, + dream_parameter_map_entries=tuple( + DreamParameterEntry.from_dict(dict(entry)) + for entry in self._last_results_loader.parameter_map_entries + if isinstance(entry, dict) + ), + dream_filter_assessments=assessments, + dream_filter_views=filter_views, + output_summary_lines=tuple(output_summary_lines), + directory_lines=tuple(directory_lines), + ) + + def _build_report_template_details( + self, + *, + template_name: str, + q_range_text: str, + supported_q_range_text: str | None, + q_sampling_text: str, + prefit_parameter_count: int, + dream_active_parameter_names: tuple[str, ...], + includes_solvent: bool, + includes_structure_factor: bool, + ) -> tuple[str, Path | None, str | None, list[str], list[str], list[str]]: + try: + spec = load_template_spec(template_name) + except Exception: + return ( + template_name, + None, + None, + [ + f"Template name: {template_name}", + "Template metadata could not be loaded for this report.", + f"User selected q-range: {q_range_text}", + ( + "Supported component q-range: " + f"{supported_q_range_text or 'Unavailable'}" + ), + f"q-grid: {q_sampling_text}", + ], + [], + [], + ) + + description_sections = self._report_named_sections_from_text( + spec.description + ) + source_sections = self._report_named_sections_from_template_source( + spec.module_path + ) + model_equation_lines = self._report_template_section_lines( + source_sections, + description_sections, + "model equation", + ) + model_equation_text = ( + " ".join( + line.strip() for line in model_equation_lines if line.strip() + ) + or None + ) + + context_lines = [ + f"Template display name: {spec.display_name}", + f"Template key: {spec.name}", + f"Template module: {spec.module_path}", + f"LMFit entrypoint: {spec.lmfit_model_name}", + f"pyDREAM entrypoint: {spec.dream_model_name}", + f"LMFit inputs: {', '.join(spec.lmfit_inputs) or 'None'}", + f"pyDREAM inputs: {', '.join(spec.dream_inputs) or 'None'}", + ( + "Parameter columns: " + f"{', '.join(spec.param_columns) or 'None'}" + ), + f"Declared static template parameters: {len(spec.parameters)}", + f"Prefit parameter rows in this report: {prefit_parameter_count}", + ( + "DREAM active parameters: " + f"{', '.join(dream_active_parameter_names) or 'None'}" + ), + f"User selected q-range: {q_range_text}", + ( + "Supported component q-range: " + f"{supported_q_range_text or 'Unavailable'}" + ), + f"q-grid: {q_sampling_text}", + ( + "Best-fit plot includes solvent contribution: " + f"{'yes' if includes_solvent else 'no'}" + ), + ( + "Best-fit plot includes structure factor trace: " + f"{'yes' if includes_structure_factor else 'no'}" + ), + ( + "Cluster geometry metadata support: " + f"{'enabled' if spec.cluster_geometry_support.supported else 'disabled'}" + ), + ] + if spec.cluster_geometry_support.supported: + context_lines.extend( + [ + ( + "Allowed structure-factor approximations: " + f"{', '.join(spec.cluster_geometry_support.allowed_sf_approximations)}" + ), + ( + "Cluster metadata fields: " + f"{', '.join(spec.cluster_geometry_support.metadata_fields)}" + ), + ( + "Dynamic geometry parameters: " + f"{'yes' if spec.cluster_geometry_support.dynamic_parameters else 'no'}" + ), + ( + "Runtime bindings: " + + ", ".join( + ( + f"{binding.runtime_name} <- " + f"{binding.metadata_field}" + ) + for binding in ( + spec.cluster_geometry_support.runtime_bindings + ) + ) + ), + ] + ) + + definition_lines: list[str] = [] + for heading in ( + "purpose", + "scientific scope", + "structure factor", + "form factor", + "model organization", + "internal abundance normalization", + "model parameters", + "parameter definitions", + "fitting guidance", + "likelihood convention", + "required pydream globals", + "practical notes", + "cluster geometry metadata", + ): + section_lines = self._report_template_section_lines( + source_sections, + description_sections, + heading, + ) + if not section_lines: + continue + if definition_lines: + definition_lines.append("") + definition_lines.append(f"{heading.title()}:") + definition_lines.extend(section_lines) + + reference_lines = self._report_template_section_lines( + source_sections, + description_sections, + "relevant resources", + "references", + ) + if not reference_lines and spec.metadata_path is not None: + reference_lines = [ + "No explicit literature references were declared in the " + f"template metadata or source comments for {spec.name}.", + ] + + return ( + spec.display_name, + spec.module_path, + model_equation_text, + context_lines, + definition_lines, + reference_lines, + ) + + @staticmethod + def _report_named_sections_from_template_source( + module_path: Path, + ) -> dict[str, list[str]]: + try: + raw_lines = module_path.read_text(encoding="utf-8").splitlines() + except Exception: + return {} + comment_lines: list[str] = [] + for raw_line in raw_lines: + stripped = raw_line.strip() + if stripped.startswith("def ") or stripped.startswith("class "): + break + if stripped.startswith("#"): + comment_lines.append(stripped[1:].strip()) + elif not stripped: + comment_lines.append("") + return SAXSMainWindow._report_named_sections_from_lines(comment_lines) + + @staticmethod + def _report_named_sections_from_text( + text: str, + ) -> dict[str, list[str]]: + return SAXSMainWindow._report_named_sections_from_lines( + text.splitlines() + ) + + @staticmethod + def _report_named_sections_from_lines( + raw_lines: list[str], + ) -> dict[str, list[str]]: + recognized_headings = { + "purpose", + "scientific scope", + "structure factor", + "form factor", + "model organization", + "model equation", + "internal abundance normalization", + "model parameters", + "parameter definitions", + "fitting guidance", + "likelihood convention", + "required pydream globals", + "practical notes", + "cluster geometry metadata", + "relevant resources", + "references", + } + sections: dict[str, list[str]] = {} + current_heading: str | None = None + current_lines: list[str] = [] + for raw_line in raw_lines: + stripped = str(raw_line).strip() + normalized_heading = stripped[:-1].strip().lower() + if ( + stripped.endswith(":") + and ":" not in stripped[:-1] + and not stripped.startswith("-") + and normalized_heading in recognized_headings + ): + if current_heading is not None: + sections[current_heading] = [ + line for line in current_lines if line or line == "" + ] + current_heading = normalized_heading + current_lines = [] + continue + if current_heading is None: + continue + current_lines.append(stripped) + if current_heading is not None: + sections[current_heading] = [ + line for line in current_lines if line or line == "" + ] + normalized: dict[str, list[str]] = {} + for heading, lines in sections.items(): + cleaned_lines: list[str] = [] + previous_blank = True + for line in lines: + if not line: + if not previous_blank: + cleaned_lines.append("") + previous_blank = True + continue + cleaned_lines.append(line) + previous_blank = False + while cleaned_lines and not cleaned_lines[0]: + cleaned_lines.pop(0) + while cleaned_lines and not cleaned_lines[-1]: + cleaned_lines.pop() + normalized[heading] = cleaned_lines + return normalized + + @staticmethod + def _report_template_section_lines( + source_sections: dict[str, list[str]], + description_sections: dict[str, list[str]], + *section_names: str, + ) -> list[str]: + for section_name in section_names: + normalized = str(section_name).strip().lower() + lines = source_sections.get(normalized) + if lines: + return list(lines) + lines = description_sections.get(normalized) + if lines: + return list(lines) + return [] + + def _report_prior_json_path(self) -> Path | None: + current_prior_path = self.project_setup_tab.current_prior_json_path() + if current_prior_path is not None and current_prior_path.is_file(): + return current_prior_path + if self.current_settings is None: + return None + candidate = ( + build_project_paths(self.current_settings.project_dir).project_dir + / "md_prior_weights.json" + ) + return candidate if candidate.is_file() else None + + def _report_secondary_element(self) -> str | None: + return self.project_setup_tab.selected_prior_secondary_element() + + def _build_report_prior_requests( + self, + powerpoint_settings: PowerPointExportSettings, + ) -> list[PriorHistogramRequest]: + prior_json_path = self._report_prior_json_path() + if prior_json_path is None: + return [] + secondary_element = self._report_secondary_element() + return [ + PriorHistogramRequest( + title="Structure Fraction Histogram", + json_path=prior_json_path, + mode="structure_fraction", + cmap=powerpoint_settings.prior_histogram_color_map, + ), + PriorHistogramRequest( + title="Atom Fraction Histogram", + json_path=prior_json_path, + mode="atom_fraction", + cmap=powerpoint_settings.prior_histogram_color_map, + ), + PriorHistogramRequest( + title="Solvent Sort Structure Fraction Histogram", + json_path=prior_json_path, + mode="solvent_sort_structure_fraction", + cmap=powerpoint_settings.solvent_sort_histogram_color_map, + secondary_element=secondary_element, + ), + PriorHistogramRequest( + title="Solvent Sort Atom Fraction Histogram", + json_path=prior_json_path, + mode="solvent_sort_atom_fraction", + cmap=powerpoint_settings.solvent_sort_histogram_color_map, + secondary_element=secondary_element, + ), + ] + + def _report_data_summary( + self, + *, + solvent: bool, + ) -> ExperimentalDataSummary | None: + if self.current_settings is None: + return None + settings = self.current_settings + if solvent: + preferred_paths = [ + settings.copied_solvent_data_file, + settings.solvent_data_path, + ] + skiprows = int(settings.solvent_header_rows) + q_column = settings.solvent_q_column + intensity_column = settings.solvent_intensity_column + error_column = settings.solvent_error_column + else: + preferred_paths = [ + settings.copied_experimental_data_file, + settings.experimental_data_path, + ] + skiprows = int(settings.experimental_header_rows) + q_column = settings.experimental_q_column + intensity_column = settings.experimental_intensity_column + error_column = settings.experimental_error_column + resolved_path = next( + ( + Path(candidate).expanduser().resolve() + for candidate in preferred_paths + if candidate + and Path(candidate).expanduser().resolve().is_file() + ), + None, + ) + if resolved_path is None: + return None + return load_experimental_data_file( + resolved_path, + skiprows=skiprows, + q_column=q_column, + intensity_column=intensity_column, + error_column=error_column, + ) + + def _build_report_component_plot_data( + self, + *, + include_solvent: bool, + powerpoint_settings: PowerPointExportSettings, + ) -> ReportComponentPlotData | None: + if self.current_settings is None: + return None + paths = build_project_paths(self.current_settings.project_dir) + component_paths = sorted(paths.scattering_components_dir.glob("*.txt")) + experimental_summary = self._report_data_summary(solvent=False) + solvent_summary = ( + self._report_data_summary(solvent=True) + if include_solvent + else None + ) + if not component_paths and experimental_summary is None: + return None + try: + cmap = colormaps[powerpoint_settings.component_color_map] + except Exception: + cmap = colormaps["viridis"] + if len(component_paths) <= 1: + positions = np.asarray([0.68], dtype=float) + else: + positions = np.linspace(0.15, 0.9, len(component_paths)) + component_series: list[ReportComponentSeries] = [] + for component_path, position in zip( + component_paths, + positions, + strict=False, + ): + raw_data = np.loadtxt(component_path, comments="#") + if raw_data.ndim == 1: + raw_data = raw_data.reshape(1, -1) + component_series.append( + ReportComponentSeries( + label=component_path.stem, + q_values=np.asarray(raw_data[:, 0], dtype=float), + intensities=np.asarray(raw_data[:, 1], dtype=float), + color=to_hex(cmap(float(position)), keep_alpha=False), + ) + ) + title = ( + "Initial SAXS traces without solvent" + if not include_solvent + else "Initial SAXS traces with solvent" + ) + return ReportComponentPlotData( + title=title, + selected_q_min=( + float(self.current_settings.q_min) + if self.current_settings.q_min is not None + else None + ), + selected_q_max=( + float(self.current_settings.q_max) + if self.current_settings.q_max is not None + else None + ), + use_experimental_grid=bool( + self.current_settings.use_experimental_grid + and not self.current_settings.model_only_mode + ), + log_x=bool( + self.project_setup_tab.component_log_x_checkbox.isChecked() + ), + log_y=bool( + self.project_setup_tab.component_log_y_checkbox.isChecked() + ), + experimental_q_values=( + None + if experimental_summary is None + else np.asarray(experimental_summary.q_values, dtype=float) + ), + experimental_intensities=( + None + if experimental_summary is None + else np.asarray(experimental_summary.intensities, dtype=float) + ), + solvent_q_values=( + None + if solvent_summary is None + else np.asarray(solvent_summary.q_values, dtype=float) + ), + solvent_intensities=( + None + if solvent_summary is None + else np.asarray(solvent_summary.intensities, dtype=float) + ), + component_series=tuple(component_series), + ) + + def _latest_prefit_statistics_payload(self) -> dict[str, object]: + if self.current_settings is None: + return {} + paths = build_project_paths(self.current_settings.project_dir) + state_path = paths.prefit_dir / "prefit_state.json" + if not state_path.is_file(): + return {} + payload = json.loads(state_path.read_text(encoding="utf-8")) + statistics = dict(payload.get("statistics", {})) + statistics["saved_at"] = str(payload.get("saved_at", "")) + latest_snapshot = next( + iter(sorted(paths.prefit_dir.glob("prefit_*"), reverse=True)), + None, + ) + if latest_snapshot is not None: + statistics["snapshot_dir"] = str(latest_snapshot) + latest_report_path = latest_snapshot / "prefit_report.txt" + if latest_report_path.is_file(): + statistics["report_path"] = str(latest_report_path) + return statistics + + def _report_cluster_geometry_rows(self) -> list[object]: + if self.prefit_workflow is None: + return [] + try: + return list(self.prefit_workflow.cluster_geometry_rows()) + except Exception: + return [] + + @staticmethod + def _estimate_section_count( + estimate: SolutionScatteringEstimate, + ) -> int: + return sum( + section is not None + for section in ( + estimate.number_density_estimate, + estimate.volume_fraction_estimate, + estimate.attenuation_estimate, + estimate.fluorescence_estimate, + ) + ) + + def _current_report_solution_scattering_estimate( + self, + ) -> SolutionScatteringEstimate | None: + candidates: list[SolutionScatteringEstimate] = [] + embedded_estimate = ( + self.prefit_tab.solute_volume_fraction_widget.current_estimate() + ) + if embedded_estimate is not None: + candidates.append(embedded_estimate) + for window in ( + self._solute_volume_fraction_tool_window, + self._number_density_tool_window, + self._attenuation_tool_window, + self._fluorescence_tool_window, + ): + if window is None: + continue + estimate = window.estimator_widget.current_estimate() + if estimate is not None: + candidates.append(estimate) + if self._last_solution_scattering_estimate is not None: + candidates.append(self._last_solution_scattering_estimate) + if not candidates: + return None + return max( + candidates, + key=lambda estimate: ( + self._estimate_section_count(estimate), + estimate is self._last_solution_scattering_estimate, + ), + ) + + def _report_dream_filter_assessments( + self, + settings: DreamRunSettings, + ) -> list[dict[str, object]]: + if self._last_dream_filter_assessments: + return [dict(item) for item in self._last_dream_filter_assessments] + assessments, recommendation = self._evaluate_dream_posterior_filters( + settings + ) + if recommendation is not None: + self._last_dream_filter_recommendation = dict(recommendation) + self._last_dream_filter_assessments = [ + dict(item) for item in assessments + ] + return [dict(item) for item in assessments] + + def _build_dream_report_filter_views( + self, + settings: DreamRunSettings, + ) -> tuple[DreamFilterReportView, ...]: + if self._last_results_loader is None: + return () + views: list[DreamFilterReportView] = [] + for mode, label in ( + ("all_post_burnin", "All Post-burnin"), + ("top_percent_logp", "Top % by Log-posterior"), + ("top_n_logp", "Top N by Log-posterior"), + ): + candidate_settings = self._copy_dream_settings(settings) + candidate_settings.posterior_filter_mode = mode + summary = self._last_results_loader.get_summary( + bestfit_method=candidate_settings.bestfit_method, + posterior_filter_mode=candidate_settings.posterior_filter_mode, + posterior_top_percent=candidate_settings.posterior_top_percent, + posterior_top_n=candidate_settings.posterior_top_n, + credible_interval_low=candidate_settings.credible_interval_low, + credible_interval_high=( + candidate_settings.credible_interval_high + ), + ) + model_plot = self._last_results_loader.build_model_fit_data( + bestfit_method=candidate_settings.bestfit_method, + posterior_filter_mode=candidate_settings.posterior_filter_mode, + posterior_top_percent=candidate_settings.posterior_top_percent, + posterior_top_n=candidate_settings.posterior_top_n, + credible_interval_low=candidate_settings.credible_interval_low, + credible_interval_high=( + candidate_settings.credible_interval_high + ), + ) + violin_plot = self._last_results_loader.build_violin_data( + mode=self._effective_dream_violin_mode(candidate_settings), + posterior_filter_mode=candidate_settings.posterior_filter_mode, + posterior_top_percent=candidate_settings.posterior_top_percent, + posterior_top_n=candidate_settings.posterior_top_n, + credible_interval_low=candidate_settings.credible_interval_low, + credible_interval_high=( + candidate_settings.credible_interval_high + ), + sample_source=candidate_settings.violin_sample_source, + weight_order=candidate_settings.violin_weight_order, + ) + weights_violin_plot = self._last_results_loader.build_violin_data( + mode="weights_only", + posterior_filter_mode=candidate_settings.posterior_filter_mode, + posterior_top_percent=candidate_settings.posterior_top_percent, + posterior_top_n=candidate_settings.posterior_top_n, + credible_interval_low=candidate_settings.credible_interval_low, + credible_interval_high=( + candidate_settings.credible_interval_high + ), + sample_source=candidate_settings.violin_sample_source, + weight_order=candidate_settings.violin_weight_order, + ) + effective_radii_violin_plot = ( + self._last_results_loader.build_violin_data( + mode="effective_radii_only", + posterior_filter_mode=( + candidate_settings.posterior_filter_mode + ), + posterior_top_percent=( + candidate_settings.posterior_top_percent + ), + posterior_top_n=candidate_settings.posterior_top_n, + credible_interval_low=( + candidate_settings.credible_interval_low + ), + credible_interval_high=( + candidate_settings.credible_interval_high + ), + sample_source=candidate_settings.violin_sample_source, + weight_order=candidate_settings.violin_weight_order, + ) + ) + title = label + is_active = mode == settings.posterior_filter_mode + if is_active: + title += " [Active]" + views.append( + DreamFilterReportView( + title=title, + description=self._describe_posterior_filter( + candidate_settings + ), + filter_mode=mode, + is_active=is_active, + summary=summary, + model_plot=model_plot, + violin_plot=violin_plot, + violin_payload=self.dream_tab.prepare_violin_plot_payload( + summary, + violin_plot, + ), + weights_violin_payload=( + self.dream_tab.prepare_violin_plot_payload( + summary, + weights_violin_plot, + ) + ), + effective_radii_violin_payload=( + self.dream_tab.prepare_violin_plot_payload( + summary, + effective_radii_violin_plot, + ) + ), + ) ) - self.statusBar().showMessage("DREAM violin data exported") - except Exception as exc: - self._show_error("Save violin data failed", str(exc)) + return tuple(views) + + def _report_q_sampling_text(self) -> str: + if self.current_settings is None: + return "Unavailable" + if self.current_settings.model_only_mode: + return f"Model-only resampled grid ({self.current_settings.q_points or 0} points)" + if self.current_settings.use_experimental_grid: + return "Experimental grid" + if self.current_settings.q_points is not None: + return f"Resampled grid ({self.current_settings.q_points} points)" + return "Project q-grid" + + def _format_selected_q_range_text(self, q_values: np.ndarray) -> str: + lower = ( + float(self.current_settings.q_min) + if self.current_settings is not None + and self.current_settings.q_min is not None + else float(np.min(q_values)) + ) + upper = ( + float(self.current_settings.q_max) + if self.current_settings is not None + and self.current_settings.q_max is not None + else float(np.max(q_values)) + ) + return f"{lower:.6g} to {upper:.6g}" @staticmethod def _effective_dream_violin_mode(settings: DreamRunSettings) -> str: @@ -3689,6 +5780,10 @@ def _effective_dream_violin_mode(settings: DreamRunSettings) -> str: return "weights_only" if settings.violin_value_scale_mode == "normalized_all": return "all_parameters" + if settings.violin_value_scale_mode == "effective_radii_only": + return "effective_radii_only" + if settings.violin_value_scale_mode == "additional_parameters_only": + return "additional_parameters_only" return settings.violin_parameter_mode def save_prior_plot_data_as(self) -> None: @@ -3907,22 +6002,47 @@ def _build_prefit_plot_export_metadata( method: str, max_nfev: int, ) -> dict[str, object]: - residuals = np.asarray(evaluation.residuals, dtype=float) q_values = np.asarray(evaluation.q_values, dtype=float) - chi_square = float(np.sum(residuals**2)) - dof = max( - len(q_values) - sum(1 for entry in entries if entry.vary), - 1, - ) - reduced_chi_square = chi_square / dof - experimental = np.asarray( - evaluation.experimental_intensities, - dtype=float, - ) - ss_total = float(np.sum((experimental - np.mean(experimental)) ** 2)) - r_squared = ( - 1.0 - chi_square / ss_total if ss_total > 0.0 else float("nan") - ) + fit_metrics: dict[str, object] + if ( + evaluation.residuals is None + or evaluation.experimental_intensities is None + ): + fit_metrics = { + "mode": "model_only", + "chi_square": None, + "reduced_chi_square": None, + "r_squared": None, + "residual_rms": None, + "mean_absolute_residual": None, + } + else: + residuals = np.asarray(evaluation.residuals, dtype=float) + chi_square = float(np.sum(residuals**2)) + dof = max( + len(q_values) - sum(1 for entry in entries if entry.vary), + 1, + ) + reduced_chi_square = chi_square / dof + experimental = np.asarray( + evaluation.experimental_intensities, + dtype=float, + ) + ss_total = float( + np.sum((experimental - np.mean(experimental)) ** 2) + ) + fit_metrics = { + "mode": "fit", + "chi_square": chi_square, + "reduced_chi_square": reduced_chi_square, + "r_squared": ( + 1.0 - chi_square / ss_total + if ss_total > 0.0 + else float("nan") + ), + "residual_rms": float(np.sqrt(np.mean(residuals**2))), + "mean_absolute_residual": float(np.mean(np.abs(residuals))), + } return { "exported_at": datetime.now().isoformat(), "project_dir": str(self.current_settings.project_dir), @@ -3942,18 +6062,16 @@ def _build_prefit_plot_export_metadata( "point_count": int(len(q_values)), "q_min": float(np.min(q_values)) if len(q_values) else None, "q_max": float(np.max(q_values)) if len(q_values) else None, + "includes_structure_factor": bool( + evaluation.structure_factor_trace is not None + ), }, - "fit_metrics": { - "chi_square": chi_square, - "reduced_chi_square": reduced_chi_square, - "r_squared": r_squared, - "residual_rms": float(np.sqrt(np.mean(residuals**2))), - "mean_absolute_residual": float(np.mean(np.abs(residuals))), - }, + "fit_metrics": fit_metrics, "parameter_entries": [entry.to_dict() for entry in entries], } def _apply_project_settings(self, settings: ProjectSettings) -> None: + self._prefit_missing_components_warning_shown = False template_specs = self._template_specs_for_dropdown( include_deprecated=self._show_deprecated_templates, selected_names=[settings.selected_model_template or ""], @@ -3982,8 +6100,15 @@ def _apply_project_settings(self, settings: ProjectSettings) -> None: "Prefit summary is not available yet.\n" f"{exc}" ) self.prefit_tab.set_saved_states([], None) + self._refresh_model_only_mode_state() try: - self._load_dream_workflow() + if not settings.model_only_mode: + self._load_dream_workflow() + else: + raise ValueError( + "DREAM is disabled in Model Only Mode. Disable Model " + "Only Mode and add experimental SAXS data to enable DREAM." + ) except Exception as exc: self.dream_workflow = None self.dream_tab.set_available_saved_runs([]) @@ -3994,14 +6119,68 @@ def _apply_project_settings(self, settings: ProjectSettings) -> None: "DREAM summary is not available yet.\n" f"{exc}" ) self.dream_tab.clear_plots() + self._refresh_model_only_mode_state() self._update_file_menu_state() + def _set_dream_tab_enabled(self, enabled: bool) -> None: + dream_index = self.tabs.indexOf(self.dream_tab) + if dream_index < 0: + return + self.tabs.setTabEnabled(dream_index, enabled) + if not enabled and self.tabs.currentIndex() == dream_index: + self.tabs.setCurrentWidget(self.prefit_tab) + + def _refresh_model_only_mode_state(self) -> None: + model_only = bool( + self.current_settings is not None + and self.current_settings.model_only_mode + ) + self.project_setup_tab.set_model_only_mode(model_only) + self.prefit_tab.set_model_only_mode(model_only) + prefit_enabled = ( + self.prefit_workflow is not None + and self.prefit_workflow.can_run_prefit() + ) + self.prefit_tab.set_prefit_execution_enabled(prefit_enabled) + dream_enabled = ( + self.current_settings is not None + and not model_only + and self.prefit_workflow is not None + and self.prefit_workflow.can_run_prefit() + ) + self._set_dream_tab_enabled(dream_enabled) + + @Slot(bool) + def _on_model_only_mode_changed(self, enabled: bool) -> None: + if self.current_settings is None: + return + self.current_settings.model_only_mode = bool(enabled) + if enabled: + self.current_settings.use_experimental_grid = False + if self.current_settings.q_points is None: + self.current_settings.q_points = 500 + if self.prefit_workflow is not None: + self.prefit_workflow.parameter_entries = ( + self.prefit_tab.parameter_entries() + ) + self.prefit_workflow.set_model_only_mode(enabled) + self.current_settings = self.prefit_workflow.settings + self._load_prefit_preview() + self.prefit_tab.set_log_text(self._format_prefit_console_intro()) + self._invalidate_dream_workflow_cache() + self._refresh_model_only_mode_state() + self.statusBar().showMessage( + "Model Only Mode enabled" + if enabled + else "Model Only Mode disabled" + ) + def _settings_from_project_tab(self) -> ProjectSettings: project_dir = self.project_setup_tab.project_dir() if project_dir is None: raise ValueError("Select a project directory.") base = ( - self.current_settings + ProjectSettings.from_dict(self.current_settings.to_dict()) if self.current_settings is not None else ProjectSettings( project_name=project_dir.name, @@ -4010,6 +6189,7 @@ def _settings_from_project_tab(self) -> ProjectSettings: ) base.project_name = project_dir.name base.project_dir = str(project_dir) + base.model_only_mode = self.project_setup_tab.model_only_mode() base.clusters_dir = ( str(self.project_setup_tab.clusters_dir()) if self.project_setup_tab.clusters_dir() is not None @@ -4116,6 +6296,7 @@ def _load_prefit_workflow(self) -> SAXSPrefitWorkflow: "Loaded the Best Prefit preset from the project file." ) self._refresh_saved_prefit_states() + self._refresh_model_only_mode_state() return self.prefit_workflow def _apply_prefit_template_fallback( @@ -4128,6 +6309,8 @@ def _apply_prefit_template_fallback( ) self.prefit_tab.set_templates(template_specs, selected_template) self.prefit_tab.set_autosave(settings.autosave_prefits) + self.prefit_tab.set_model_only_mode(settings.model_only_mode) + self.prefit_tab.set_prefit_execution_enabled(False) if selected_template: self._restoring_prefit_template = True try: @@ -4142,11 +6325,24 @@ def _apply_prefit_template_fallback( except Exception: template_spec = None target = self._volume_fraction_target_for_template_spec(template_spec) - self.prefit_tab.set_solute_volume_fraction_visible(target is not None) + solvent_weight_target = self._solvent_weight_target_for_template_spec( + template_spec + ) + self.prefit_tab.set_solute_volume_fraction_visible( + template_spec is not None + ) if target is None: - self.prefit_tab.set_solute_volume_fraction_target(None, None) + self.prefit_tab.set_solute_volume_fraction_target( + None, + None, + solvent_weight_target, + ) else: - self.prefit_tab.set_solute_volume_fraction_target(*target) + self.prefit_tab.set_solute_volume_fraction_target( + *target, + solvent_weight_target, + ) + self._sync_solution_scattering_tool_targets() supports_cluster_geometry = bool( template_spec is not None and template_spec.cluster_geometry_support.supported @@ -4189,9 +6385,37 @@ def _volume_fraction_target_for_template_spec( return candidate, "solvent" return None + @staticmethod + def _solvent_weight_target_for_template_spec( + template_spec, + ) -> str | None: + if template_spec is None: + return None + parameter_names = { + str(parameter.name).strip() + for parameter in template_spec.parameters + if str(parameter.name).strip() + } + for candidate in SOLVENT_WEIGHT_PARAMETER_NAMES: + if candidate in parameter_names: + return candidate + return None + def _load_dream_workflow(self) -> SAXSDreamWorkflow: if self.current_settings is None: raise ValueError("No project is currently loaded.") + if self.current_settings.model_only_mode: + raise ValueError( + "DREAM is disabled in Model Only Mode. Disable Model Only " + "Mode and add experimental SAXS data to enable DREAM." + ) + if ( + self.prefit_workflow is not None + and not self.prefit_workflow.can_run_prefit() + ): + raise ValueError( + "DREAM requires experimental SAXS data and an enabled prefit workflow." + ) project_dir = str(Path(self.current_settings.project_dir).resolve()) is_new_project = self._dream_workflow_project_dir != project_dir if self.dream_workflow is None or is_new_project: @@ -4447,6 +6671,63 @@ def _copy_dream_settings( ) -> DreamRunSettings: return DreamRunSettings.from_dict(settings.to_dict()) + def _load_console_autoscroll_setting(self) -> bool: + raw_value = self._recent_projects_settings().value( + CONSOLE_AUTOSCROLL_KEY, + True, + ) + if isinstance(raw_value, bool): + return raw_value + return str(raw_value).strip().lower() not in { + "", + "0", + "false", + "no", + "off", + } + + def _toggle_console_autoscroll(self, enabled: bool) -> None: + self._set_console_autoscroll_enabled(enabled, persist=True) + + def _set_console_autoscroll_enabled( + self, + enabled: bool, + *, + persist: bool, + ) -> None: + self._console_autoscroll_enabled = bool(enabled) + if hasattr(self, "console_autoscroll_action"): + self.console_autoscroll_action.blockSignals(True) + self.console_autoscroll_action.setChecked( + self._console_autoscroll_enabled + ) + self.console_autoscroll_action.blockSignals(False) + if hasattr(self, "project_setup_tab"): + self.project_setup_tab.set_console_autoscroll_enabled( + self._console_autoscroll_enabled + ) + if hasattr(self, "prefit_tab"): + self.prefit_tab.set_console_autoscroll_enabled( + self._console_autoscroll_enabled + ) + if hasattr(self, "dream_tab"): + self.dream_tab.set_console_autoscroll_enabled( + self._console_autoscroll_enabled + ) + if persist: + self._recent_projects_settings().setValue( + CONSOLE_AUTOSCROLL_KEY, + self._console_autoscroll_enabled, + ) + self.statusBar().showMessage( + "Console autoscroll " + + ( + "enabled" + if self._console_autoscroll_enabled + else "disabled" + ) + ) + def _recent_projects_settings(self) -> QSettings: return QSettings("SAXShell", "SAXS") @@ -4618,6 +6899,44 @@ def _open_bondanalysis_tool(self) -> None: else: self.statusBar().showMessage("Opened bond analysis") + def _open_clusterdynamics_tool(self) -> None: + from saxshell.clusterdynamics.ui.main_window import ( + ClusterDynamicsMainWindow, + ) + + project_dir = None + if self.current_settings is not None: + project_dir = Path(self.current_settings.project_dir).resolve() + window = ClusterDynamicsMainWindow(initial_project_dir=project_dir) + window.show() + window.raise_() + self._child_tool_windows.append(window) + if project_dir is not None: + self.statusBar().showMessage( + f"Opened cluster dynamics for {project_dir}" + ) + else: + self.statusBar().showMessage("Opened cluster dynamics") + + def _open_clusterdynamicsml_tool(self) -> None: + from saxshell.clusterdynamicsml.ui.main_window import ( + ClusterDynamicsMLMainWindow, + ) + + project_dir = None + if self.current_settings is not None: + project_dir = Path(self.current_settings.project_dir).resolve() + window = ClusterDynamicsMLMainWindow(initial_project_dir=project_dir) + window.show() + window.raise_() + self._child_tool_windows.append(window) + if project_dir is not None: + self.statusBar().showMessage( + f"Opened cluster dynamics AI for {project_dir}" + ) + else: + self.statusBar().showMessage("Opened cluster dynamics AI") + def _open_fullrmc_tool(self) -> None: from saxshell.fullrmc.ui.main_window import RMCSetupMainWindow @@ -4635,24 +6954,62 @@ def _open_fullrmc_tool(self) -> None: else: self.statusBar().showMessage("Opened fullrmc setup") - def _open_solute_volume_fraction_tool(self) -> None: - window = SoluteVolumeFractionToolWindow() + def _open_solution_scattering_tool_window( + self, + *, + attribute_name: str, + factory, + status_message: str, + ) -> None: + existing_window = getattr(self, attribute_name) + if existing_window is not None: + existing_window.show() + existing_window.raise_() + existing_window.activateWindow() + self.statusBar().showMessage(status_message) + return + + window = factory() window.estimator_widget.estimate_calculated.connect( - self._on_solute_volume_fraction_estimate_calculated + self._on_solution_scattering_estimate_calculated ) - self._solute_volume_fraction_tool_window = window - self._sync_volume_fraction_tool_target() + setattr(self, attribute_name, window) + self._sync_solution_scattering_tool_targets() window.destroyed.connect( - lambda *_args: setattr( - self, - "_solute_volume_fraction_tool_window", - None, - ) + lambda *_args, name=attribute_name: setattr(self, name, None) ) window.show() window.raise_() self._child_tool_windows.append(window) - self.statusBar().showMessage("Opened volume fraction estimate") + self.statusBar().showMessage(status_message) + + def _open_solute_volume_fraction_tool(self) -> None: + self._open_solution_scattering_tool_window( + attribute_name="_solute_volume_fraction_tool_window", + factory=SoluteVolumeFractionToolWindow, + status_message="Opened volume fraction estimate", + ) + + def _open_number_density_tool(self) -> None: + self._open_solution_scattering_tool_window( + attribute_name="_number_density_tool_window", + factory=NumberDensityEstimateToolWindow, + status_message="Opened number density estimate", + ) + + def _open_attenuation_tool(self) -> None: + self._open_solution_scattering_tool_window( + attribute_name="_attenuation_tool_window", + factory=AttenuationEstimateToolWindow, + status_message="Opened attenuation estimate", + ) + + def _open_fluorescence_tool(self) -> None: + self._open_solution_scattering_tool_window( + attribute_name="_fluorescence_tool_window", + factory=FluorescenceEstimateToolWindow, + status_message="Opened fluorescence estimate", + ) def _show_placeholder_tool_message(self, tool_name: str) -> None: QMessageBox.information( @@ -4665,47 +7022,36 @@ def _show_placeholder_tool_message(self, tool_name: str) -> None: ) self.statusBar().showMessage(f"{tool_name} is not available yet", 5000) - def _open_dream_output_settings_dialog(self) -> None: - settings = self.dream_tab.settings_payload() - dialog = QDialog(self) - dialog.setWindowTitle("DREAM Output Settings") - layout = QVBoxLayout(dialog) - - form_layout = QFormLayout() - verbose_checkbox = QCheckBox("Verbose sampler output") - verbose_checkbox.setChecked(settings.verbose) - verbose_checkbox.setToolTip( - "Enable or disable verbose DREAM sampler progress output." - ) - interval_spin = QDoubleSpinBox() - interval_spin.setRange(0.1, 30.0) - interval_spin.setDecimals(1) - interval_spin.setSingleStep(0.1) - interval_spin.setValue(settings.verbose_output_interval_seconds) - interval_spin.setToolTip( - "Minimum number of seconds between DREAM runtime output " - "updates shown in the UI while verbose output is enabled." + def _effective_powerpoint_export_settings( + self, + ) -> PowerPointExportSettings: + if self.current_settings is None: + return PowerPointExportSettings() + return PowerPointExportSettings.from_dict( + self.current_settings.powerpoint_export_settings.to_dict() ) - interval_spin.setEnabled(verbose_checkbox.isChecked()) - verbose_checkbox.toggled.connect(interval_spin.setEnabled) - form_layout.addRow(verbose_checkbox) - form_layout.addRow("Output interval (s)", interval_spin) - layout.addLayout(form_layout) - buttons = QDialogButtonBox( - QDialogButtonBox.StandardButton.Ok - | QDialogButtonBox.StandardButton.Cancel + def _open_main_ui_settings_dialog(self) -> None: + dialog = MainUISettingsDialog( + dream_settings=self.dream_tab.settings_payload(), + powerpoint_settings=self._effective_powerpoint_export_settings(), + powerpoint_enabled=self.current_settings is not None, + parent=self, ) - buttons.accepted.connect(dialog.accept) - buttons.rejected.connect(dialog.reject) - layout.addWidget(buttons) - if dialog.exec() != QDialog.DialogCode.Accepted: return + verbose, interval_seconds = dialog.dream_output_values() self._apply_dream_output_settings( - verbose=verbose_checkbox.isChecked(), - interval_seconds=interval_spin.value(), + verbose=verbose, + interval_seconds=interval_seconds, ) + if self.current_settings is not None: + self._apply_powerpoint_export_settings( + dialog.powerpoint_settings_value() + ) + + def _open_dream_output_settings_dialog(self) -> None: + self._open_main_ui_settings_dialog() def _apply_dream_output_settings( self, @@ -4732,6 +7078,46 @@ def _apply_dream_output_settings( ) self.statusBar().showMessage("DREAM output settings updated") + def _apply_powerpoint_export_settings( + self, + settings: PowerPointExportSettings, + ) -> None: + if self.current_settings is None: + return + normalized = PowerPointExportSettings.from_dict(settings.to_dict()) + enabled_section_count = sum( + int(value) + for value in ( + normalized.include_prior_histograms, + normalized.include_initial_traces, + normalized.include_prefit_model, + normalized.include_prefit_parameters, + normalized.include_geometry_table, + normalized.include_estimator_metrics, + normalized.include_dream_settings, + normalized.include_dream_prior_table, + normalized.include_dream_output_model, + normalized.include_posterior_comparisons, + normalized.include_output_summary, + normalized.include_directory_summary, + ) + ) + self.current_settings.powerpoint_export_settings = normalized + self.dream_tab.append_log( + "Updated PowerPoint export settings.\n" + f"Font: {normalized.font_family}\n" + f"Component palette: {normalized.component_color_map}\n" + "Prior palettes: " + f"{normalized.prior_histogram_color_map} / " + f"{normalized.solvent_sort_histogram_color_map}\n" + f"Slides enabled: {enabled_section_count}/12\n" + f"Manifest export: {'on' if normalized.generate_manifest else 'off'}\n" + "Rendered figure assets: " + f"{'kept' if normalized.export_figure_assets else 'temporary only'}\n" + "Save the project if you want to persist this change." + ) + self.statusBar().showMessage("PowerPoint export settings updated") + def _show_version_information(self) -> None: QMessageBox.information( self, @@ -4933,12 +7319,36 @@ def _on_prior_histogram_window_destroyed( if open_window is not window ] - def _refresh_loaded_dream_results(self) -> None: + def _schedule_dream_results_refresh(self, scope: int) -> None: + if self._last_results_loader is None: + return + self._pending_dream_refresh_scope = max( + int(self._pending_dream_refresh_scope), + int(scope), + ) + self._dream_refresh_timer.start() + + def _flush_pending_dream_refresh(self) -> None: + scope = int(self._pending_dream_refresh_scope) + self._pending_dream_refresh_scope = 0 + if scope <= 0: + return + self._refresh_loaded_dream_results(scope=scope) + + def _refresh_loaded_dream_results( + self, + *, + scope: int | None = None, + ) -> None: if self._last_results_loader is None: return + refresh_scope = ( + int(scope) if scope is not None else self.DREAM_REFRESH_FULL + ) try: settings = self.dream_tab.settings_payload() - summary = self._last_results_loader.get_summary( + loader = self._last_results_loader + summary = loader.get_summary( bestfit_method=settings.bestfit_method, posterior_filter_mode=settings.posterior_filter_mode, posterior_top_percent=settings.posterior_top_percent, @@ -4946,7 +7356,54 @@ def _refresh_loaded_dream_results(self) -> None: credible_interval_low=settings.credible_interval_low, credible_interval_high=settings.credible_interval_high, ) - model_plot = self._last_results_loader.build_model_fit_data( + self.dream_tab.set_summary_text( + self._format_dream_summary(summary, settings=settings) + ) + if refresh_scope <= self.DREAM_REFRESH_STYLE: + if self.dream_tab.current_violin_plot_data() is None: + violin_plot = loader.build_violin_data( + mode=self._effective_dream_violin_mode(settings), + posterior_filter_mode=settings.posterior_filter_mode, + posterior_top_percent=settings.posterior_top_percent, + posterior_top_n=settings.posterior_top_n, + credible_interval_low=settings.credible_interval_low, + credible_interval_high=settings.credible_interval_high, + sample_source=settings.violin_sample_source, + weight_order=settings.violin_weight_order, + ) + self.dream_tab.plot_violin_plot(summary, violin_plot) + else: + self.dream_tab.redraw_current_violin_plot() + return + if refresh_scope == self.DREAM_REFRESH_SUMMARY: + violin_plot = self.dream_tab.current_violin_plot_data() + if violin_plot is None: + violin_plot = loader.build_violin_data( + mode=self._effective_dream_violin_mode(settings), + posterior_filter_mode=settings.posterior_filter_mode, + posterior_top_percent=settings.posterior_top_percent, + posterior_top_n=settings.posterior_top_n, + credible_interval_low=settings.credible_interval_low, + credible_interval_high=settings.credible_interval_high, + sample_source=settings.violin_sample_source, + weight_order=settings.violin_weight_order, + ) + self.dream_tab.plot_violin_plot(summary, violin_plot) + return + if refresh_scope == self.DREAM_REFRESH_VIOLIN: + violin_plot = loader.build_violin_data( + mode=self._effective_dream_violin_mode(settings), + posterior_filter_mode=settings.posterior_filter_mode, + posterior_top_percent=settings.posterior_top_percent, + posterior_top_n=settings.posterior_top_n, + credible_interval_low=settings.credible_interval_low, + credible_interval_high=settings.credible_interval_high, + sample_source=settings.violin_sample_source, + weight_order=settings.violin_weight_order, + ) + self.dream_tab.plot_violin_plot(summary, violin_plot) + return + model_plot = loader.build_model_fit_data( bestfit_method=settings.bestfit_method, posterior_filter_mode=settings.posterior_filter_mode, posterior_top_percent=settings.posterior_top_percent, @@ -4954,7 +7411,7 @@ def _refresh_loaded_dream_results(self) -> None: credible_interval_low=settings.credible_interval_low, credible_interval_high=settings.credible_interval_high, ) - violin_plot = self._last_results_loader.build_violin_data( + violin_plot = loader.build_violin_data( mode=self._effective_dream_violin_mode(settings), posterior_filter_mode=settings.posterior_filter_mode, posterior_top_percent=settings.posterior_top_percent, @@ -4964,9 +7421,6 @@ def _refresh_loaded_dream_results(self) -> None: sample_source=settings.violin_sample_source, weight_order=settings.violin_weight_order, ) - self.dream_tab.set_summary_text( - self._format_dream_summary(summary, settings=settings) - ) self.dream_tab.plot_model_fit(model_plot) self.dream_tab.plot_violin_plot(summary, violin_plot) except Exception as exc: @@ -5360,12 +7814,23 @@ def _format_prefit_console_intro(self) -> str: try: q_values = self.prefit_workflow._component_q_values() except Exception: - q_values = np.asarray( - self.prefit_workflow.experimental_data.q_values, - dtype=float, - ) + if self.prefit_workflow.experimental_data is not None: + q_values = np.asarray( + self.prefit_workflow.experimental_data.q_values, + dtype=float, + ) + else: + q_values = np.asarray([], dtype=float) run_config = self.prefit_tab.run_config() - if settings.use_experimental_grid: + if q_values.size == 0: + grid_text = "The active q-grid is not available yet." + elif settings.model_only_mode: + grid_text = ( + "Model Only Mode is active. Using a forward-model q-grid with " + f"{len(q_values)} points from {float(q_values.min()):.6g} to " + f"{float(q_values.max()):.6g}." + ) + elif settings.use_experimental_grid: grid_text = ( "Using the experimental q-grid cropped to the nearest " "available q-points inside the requested range " @@ -5397,6 +7862,11 @@ def _format_prefit_console_intro(self) -> str: if preview_block_reason is not None else "" ) + + ( + "Prefit fitting is disabled while Model Only Mode is active.\n" + if settings.model_only_mode + else "" + ) + "Recommended order: refine scale first, then scale + offset. " "Component weights w<##> are not recommended for prefit refinement.\n" f"Default minimizer: {run_config.method}\n" @@ -5414,10 +7884,7 @@ def _format_prefit_summary( ) -> str: if self.prefit_workflow is None: return "Prefit summary is not available." - residuals = np.asarray(evaluation.residuals, dtype=float) q_values = np.asarray(evaluation.q_values, dtype=float) - rms_residual = float(np.sqrt(np.mean(residuals**2))) - mean_abs_residual = float(np.mean(np.abs(residuals))) lines = [ "Prefit summary:", f"Template: {self.prefit_workflow.template_spec.name}", @@ -5426,8 +7893,6 @@ def _format_prefit_summary( f"q-range: {float(q_values.min()):.6g} to " f"{float(q_values.max()):.6g}" ), - f"Residual RMS: {rms_residual:.6g}", - f"Mean |residual|: {mean_abs_residual:.6g}", f"Configured minimizer: {self.prefit_tab.run_config().method}", f"Configured max nfev: {self.prefit_tab.run_config().max_nfev}", ( @@ -5439,6 +7904,24 @@ def _format_prefit_summary( ) ), ] + if ( + evaluation.residuals is None + or evaluation.experimental_intensities is None + ): + lines.extend( + [ + "Mode: Model Only", + "Experimental fit metrics: unavailable", + ] + ) + else: + residuals = np.asarray(evaluation.residuals, dtype=float) + rms_residual = float(np.sqrt(np.mean(residuals**2))) + mean_abs_residual = float(np.mean(np.abs(residuals))) + lines[4:4] = [ + f"Residual RMS: {rms_residual:.6g}", + f"Mean |residual|: {mean_abs_residual:.6g}", + ] if fit_result is not None: lines.extend( [ @@ -5496,6 +7979,12 @@ def apply_recommended_scale_settings(self) -> None: "Build a project and load its SAXS components first.", ) return + if not self.prefit_workflow.can_run_prefit(): + self._show_error( + "Scale recommendation failed", + "Disable Model Only Mode and load experimental SAXS data before applying autoscale.", + ) + return try: entries = self.prefit_tab.parameter_entries() recommendation = self.prefit_workflow.recommend_scale_settings( @@ -5508,15 +7997,46 @@ def apply_recommended_scale_settings(self) -> None: maximum=recommendation.recommended_maximum, vary=True, ) - self.prefit_tab.append_log( - "Applied autoscale settings.\n" - f"Current scale: {recommendation.current_scale:.6g}\n" - f"Recommended scale: {recommendation.recommended_scale:.6g}\n" - f"Scale min: {recommendation.recommended_minimum:.6g}\n" - f"Scale max: {recommendation.recommended_maximum:.6g}\n" - f"Adjustment factor: {recommendation.adjustment_factor:.6g}\n" - f"Points used: {recommendation.points_used}" + if recommendation.recommended_offset is not None: + offset_kwargs: dict[str, object] = { + "value": recommendation.recommended_offset, + } + if recommendation.recommended_offset_minimum is not None: + offset_kwargs["minimum"] = ( + recommendation.recommended_offset_minimum + ) + if recommendation.recommended_offset_maximum is not None: + offset_kwargs["maximum"] = ( + recommendation.recommended_offset_maximum + ) + self.prefit_tab.set_parameter_row( + "offset", + **offset_kwargs, + ) + self.prefit_workflow.parameter_entries = ( + self.prefit_tab.parameter_entries() + ) + offset_text = ( + "" + if recommendation.recommended_offset is None + else ( + f"Current offset: " + f"{(recommendation.current_offset or 0.0):.6g}\n" + f"Recommended offset: " + f"{recommendation.recommended_offset:.6g}\n" + ) ) + message = ( + "Applied autoscale settings.\n" + + f"Current scale: {recommendation.current_scale:.6g}\n" + + f"Recommended scale: {recommendation.recommended_scale:.6g}\n" + + f"Scale min: {recommendation.recommended_minimum:.6g}\n" + + f"Scale max: {recommendation.recommended_maximum:.6g}\n" + + offset_text + + f"Adjustment factor: {recommendation.adjustment_factor:.6g}\n" + + f"Points used: {recommendation.points_used}" + ) + self.prefit_tab.append_log(message) self.update_prefit_model() self.statusBar().showMessage("Autoscale applied") except Exception as exc: @@ -5526,15 +8046,27 @@ def _append_scale_recommendation_log( self, recommendation: PrefitScaleRecommendation, ) -> None: - self.prefit_tab.append_log( + offset_text = ( + "" + if recommendation.recommended_offset is None + else ( + f"Current offset: " + f"{(recommendation.current_offset or 0.0):.6g}\n" + f"Recommended offset: " + f"{recommendation.recommended_offset:.6g}\n" + ) + ) + message = ( "Recommended scale estimate available.\n" - f"Current scale: {recommendation.current_scale:.6g}\n" - f"Recommended scale: {recommendation.recommended_scale:.6g}\n" - f"Suggested range: {recommendation.recommended_minimum:.6g} " - f"to {recommendation.recommended_maximum:.6g}\n" - f"Adjustment factor: {recommendation.adjustment_factor:.6g}\n" - f"Points used: {recommendation.points_used}" + + f"Current scale: {recommendation.current_scale:.6g}\n" + + f"Recommended scale: {recommendation.recommended_scale:.6g}\n" + + f"Suggested range: {recommendation.recommended_minimum:.6g} " + + f"to {recommendation.recommended_maximum:.6g}\n" + + offset_text + + f"Adjustment factor: {recommendation.adjustment_factor:.6g}\n" + + f"Points used: {recommendation.points_used}" ) + self.prefit_tab.append_log(message) def _maybe_append_scale_recommendation( self, @@ -5571,9 +8103,20 @@ def launch_saxs_ui( app = QApplication.instance() owns_app = app is None if app is None: + prepare_saxshell_application_identity() app = QApplication([]) - window = SAXSMainWindow(initial_project_dir=initial_project_dir) - window.show() + configure_saxshell_application(app) + splash = create_saxshell_startup_splash() + splash.show() + app.processEvents() + window: SAXSMainWindow | None = None + try: + window = SAXSMainWindow(initial_project_dir=initial_project_dir) + window.show() + splash.finish(window) + except Exception: + splash.close() + raise if owns_app: return int(app.exec()) return 0 diff --git a/src/saxshell/saxs/ui/prefit_tab.py b/src/saxshell/saxs/ui/prefit_tab.py index d73348e..c2cb2a8 100644 --- a/src/saxshell/saxs/ui/prefit_tab.py +++ b/src/saxshell/saxs/ui/prefit_tab.py @@ -8,9 +8,10 @@ NavigationToolbar2QT, ) from matplotlib.figure import Figure -from PySide6.QtCore import QPoint, Qt, Signal -from PySide6.QtGui import QColor +from PySide6.QtCore import QEvent, QPoint, Qt, QTimer, Signal +from PySide6.QtGui import QColor, QTextCursor from PySide6.QtWidgets import ( + QAbstractSpinBox, QCheckBox, QComboBox, QGridLayout, @@ -18,10 +19,12 @@ QHBoxLayout, QHeaderView, QLabel, + QLineEdit, QMessageBox, QProgressBar, QPushButton, QScrollArea, + QScrollBar, QSpinBox, QSplitter, QTableWidget, @@ -37,6 +40,7 @@ ClusterGeometryMetadataRow, PrefitEvaluation, PrefitParameterEntry, + resolve_prefit_parameter_entries, ) from saxshell.saxs.prefit.cluster_geometry import ( DEFAULT_IONIC_RADIUS_TYPE, @@ -109,9 +113,14 @@ def showPopup(self) -> None: class PrefitTab(QWidget): + PARAMETER_VALUE_ROLE = int(Qt.ItemDataRole.UserRole) + PARAMETER_VARY_MEMORY_ROLE = int(Qt.ItemDataRole.UserRole) + 1 + template_changed = Signal(str) show_deprecated_templates_changed = Signal(bool) autosave_toggled = Signal(bool) + field_interaction_requested = Signal() + parameter_reset_requested = Signal(str, str, str) update_model_requested = Signal() run_fit_requested = Signal() apply_recommended_scale_requested = Signal() @@ -184,9 +193,12 @@ class PrefitTab(QWidget): CLUSTER_COL_ANISOTROPY = 9 CLUSTER_COL_MAP_TO = 10 CLUSTER_COL_NOTES = 11 + PARAMETER_SCROLL_RESOLUTION = 2000 + PARAMETER_SCROLL_LOG_DECADE_THRESHOLD = 2.0 def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) + self._console_autoscroll_enabled = True self._current_evaluation: PrefitEvaluation | None = None self._summary_text = "" self._base_log_text = "" @@ -200,11 +212,16 @@ def __init__(self, parent: QWidget | None = None) -> None: ) self._expanded_cluster_geometry_path_rows: set[int] = set() self._expanded_cluster_geometry_note_rows: set[int] = set() + self._model_only_mode = False + self._prefit_execution_enabled = True + self._updating_parameter_table = False + self._updating_parameter_scrollbar = False self._last_cluster_geometry_radii_type = DEFAULT_RADIUS_TYPE self._last_cluster_geometry_ionic_radius_type = ( DEFAULT_IONIC_RADIUS_TYPE ) self._build_ui() + self._install_field_interaction_watchers() def _build_ui(self) -> None: root = QVBoxLayout(self) @@ -272,6 +289,46 @@ def _build_ui(self) -> None: self._scroll_area.setWidget(content) root.addWidget(self._scroll_area) + def _install_field_interaction_watchers(self) -> None: + watched_widgets: list[QWidget] = [] + for child in self.findChildren(QWidget): + if isinstance( + child, + ( + QAbstractSpinBox, + QCheckBox, + QComboBox, + QLineEdit, + ), + ): + watched_widgets.append(child) + watched_widgets.extend( + [ + self.parameter_table, + self.parameter_table.viewport(), + self.cluster_geometry_table, + self.cluster_geometry_table.viewport(), + ] + ) + seen: set[int] = set() + for widget in watched_widgets: + widget_id = id(widget) + if widget_id in seen: + continue + seen.add(widget_id) + widget.installEventFilter(self) + + def eventFilter(self, watched: object, event: QEvent) -> bool: + if isinstance(watched, QWidget) and watched.isEnabled(): + if event.type() in ( + QEvent.Type.KeyPress, + QEvent.Type.MouseButtonDblClick, + QEvent.Type.MouseButtonPress, + QEvent.Type.Wheel, + ): + self.field_interaction_requested.emit() + return super().eventFilter(watched, event) + def _build_controls_group(self) -> QGroupBox: group = QGroupBox("Prefit Controls") layout = QGridLayout(group) @@ -396,6 +453,7 @@ def _build_controls_group(self) -> QGroupBox: ) self.reset_button.clicked.connect(self.reset_requested.emit) run_cell = QWidget() + self._run_button_cell = run_cell run_cell_layout = QVBoxLayout(run_cell) run_cell_layout.setContentsMargins(0, 0, 0, 0) run_cell_layout.setSpacing(4) @@ -403,20 +461,20 @@ def _build_controls_group(self) -> QGroupBox: run_button_row.setContentsMargins(0, 0, 0, 0) run_button_row.addWidget(self.run_button) run_button_row.addWidget(self.prefit_help_button) + run_button_row.addStretch(1) run_cell_layout.addLayout(run_button_row) - run_cell_layout.addWidget(self.autosave_checkbox) - button_grid.addWidget(self.update_button, 0, 0) - button_grid.addWidget(run_cell, 0, 1) + self._prefit_control_button_grid = button_grid + button_grid.addWidget(run_cell, 0, 0) + button_grid.addWidget(self.autosave_checkbox, 0, 1) button_grid.addWidget(self.save_button, 1, 0) button_grid.addWidget(self.reset_button, 1, 1) button_grid.addWidget(self.set_best_button, 2, 0) button_grid.addWidget(self.reset_best_button, 2, 1) layout.addLayout(button_grid, 4, 0, 1, 3) - layout.addWidget(self.recommended_scale_button, 5, 0, 1, 3) return group def _build_solute_volume_fraction_group(self) -> QGroupBox: - group = QGroupBox("Solute Volume Fraction Estimator") + group = QGroupBox("Solution Scattering Estimators") layout = QVBoxLayout(group) header_row = QHBoxLayout() @@ -430,8 +488,9 @@ def _build_solute_volume_fraction_group(self) -> QGroupBox: ) header_row.addWidget(self.solute_volume_fraction_collapse_button) self.solute_volume_fraction_status_label = QLabel( - "This template does not expose a solute or solvent volume-" - "fraction parameter." + "These estimators can calculate solution volume fractions, " + "solvent attenuation scaling, and fluorescence background " + "proxies." ) self.solute_volume_fraction_status_label.setWordWrap(True) header_row.addWidget( @@ -440,9 +499,8 @@ def _build_solute_volume_fraction_group(self) -> QGroupBox: self.solute_volume_fraction_help_button = QToolButton() self.solute_volume_fraction_help_button.setText("?") self.solute_volume_fraction_help_button.setToolTip( - "How the solution-based solute volume fraction estimate is " - "defined. Click for the additive-volume formula and citation " - "link." + "How the solution-scattering estimators are defined. Click for " + "the volume-fraction, attenuation, and fluorescence summary." ) self.solute_volume_fraction_help_button.clicked.connect( self._show_solute_volume_fraction_help @@ -476,6 +534,13 @@ def _build_plot_group(self) -> QGroupBox: self.show_solvent_trace_checkbox.toggled.connect( self._redraw_current_plot ) + self.show_structure_factor_trace_checkbox = QCheckBox( + "Structure factor" + ) + self.show_structure_factor_trace_checkbox.setChecked(False) + self.show_structure_factor_trace_checkbox.toggled.connect( + self._redraw_current_plot + ) self.log_x_checkbox = QCheckBox("Log X") self.log_x_checkbox.setChecked(True) self.log_x_checkbox.toggled.connect(self._redraw_current_plot) @@ -489,6 +554,7 @@ def _build_plot_group(self) -> QGroupBox: controls.addWidget(self.show_experimental_trace_checkbox) controls.addWidget(self.show_model_trace_checkbox) controls.addWidget(self.show_solvent_trace_checkbox) + controls.addWidget(self.show_structure_factor_trace_checkbox) controls.addWidget(self.log_x_checkbox) controls.addWidget(self.log_y_checkbox) controls.addWidget(self.save_plot_data_button) @@ -507,9 +573,84 @@ def _build_plot_group(self) -> QGroupBox: def _build_parameter_group(self) -> QGroupBox: group = QGroupBox("Parameters") layout = QVBoxLayout(group) - self.parameter_table = QTableWidget(0, 7) + action_row = QHBoxLayout() + self._parameter_action_layout = action_row + action_row.addWidget(self.recommended_scale_button) + action_row.addWidget(self.update_button) + self.auto_update_checkbox = QCheckBox( + "Auto-update on parameter change" + ) + self.auto_update_checkbox.setChecked(False) + self.auto_update_checkbox.setToolTip( + "Automatically refresh the SAXS model preview when a parameter " + "value in the table changes." + ) + self.auto_update_checkbox.toggled.connect( + self._on_auto_update_checkbox_toggled + ) + action_row.addWidget(self.auto_update_checkbox) + self.scrollable_parameter_checkbox = QCheckBox("Scrollable parameter") + self.scrollable_parameter_checkbox.setChecked(False) + self.scrollable_parameter_checkbox.setEnabled(False) + self.scrollable_parameter_checkbox.setToolTip( + "Show a scrollbar for the selected parameter and update the " + "model as you scrub through the allowed range." + ) + self.scrollable_parameter_checkbox.toggled.connect( + self._on_scrollable_parameter_toggled + ) + action_row.addWidget(self.scrollable_parameter_checkbox) + action_row.addStretch(1) + layout.addLayout(action_row) + self.parameter_scroll_panel = QWidget() + self.parameter_scroll_panel.setVisible(False) + scroll_panel_layout = QVBoxLayout(self.parameter_scroll_panel) + scroll_panel_layout.setContentsMargins(0, 0, 0, 0) + scroll_panel_layout.setSpacing(4) + info_row = QHBoxLayout() + info_row.setContentsMargins(0, 0, 0, 0) + self.parameter_scroll_name_label = QLabel( + "Select a parameter row to scrub its value." + ) + self.parameter_scroll_name_label.setWordWrap(True) + info_row.addWidget(self.parameter_scroll_name_label, stretch=1) + self.parameter_scroll_mode_label = QLabel("") + info_row.addWidget(self.parameter_scroll_mode_label) + self.parameter_scroll_value_label = QLabel("") + info_row.addWidget(self.parameter_scroll_value_label) + scroll_panel_layout.addLayout(info_row) + self.parameter_scroll_bar = QScrollBar(Qt.Orientation.Horizontal) + self.parameter_scroll_bar.setRange( + 0, + self.PARAMETER_SCROLL_RESOLUTION, + ) + self.parameter_scroll_bar.setSingleStep(1) + self.parameter_scroll_bar.setPageStep( + max(self.PARAMETER_SCROLL_RESOLUTION // 20, 1) + ) + self.parameter_scroll_bar.valueChanged.connect( + self._on_parameter_scrollbar_value_changed + ) + scroll_panel_layout.addWidget(self.parameter_scroll_bar) + layout.addWidget(self.parameter_scroll_panel) + self.parameter_table = QTableWidget(0, 8) self.parameter_table.setHorizontalHeaderLabels( - ["Structure", "Motif", "Param", "Value", "Vary", "Min", "Max"] + [ + "Structure", + "Motif", + "Param", + "Value", + "Vary", + "Min", + "Max", + "Reset", + ] + ) + self.parameter_table.itemChanged.connect( + self._on_parameter_table_item_changed + ) + self.parameter_table.currentCellChanged.connect( + self._on_parameter_table_current_cell_changed ) header = self.parameter_table.horizontalHeader() header.setSectionResizeMode(QHeaderView.ResizeMode.Stretch) @@ -681,6 +822,15 @@ def set_autosave(self, enabled: bool) -> None: self.autosave_checkbox.setChecked(enabled) self.autosave_checkbox.blockSignals(False) + def set_model_only_mode(self, enabled: bool) -> None: + self._model_only_mode = bool(enabled) + self._update_prefit_execution_control_state() + self._update_plot_group_title() + + def set_prefit_execution_enabled(self, enabled: bool) -> None: + self._prefit_execution_enabled = bool(enabled) + self._update_prefit_execution_control_state() + def set_run_config(self, *, method: str, max_nfev: int) -> None: method_index = self.method_combo.findText(method) if method_index >= 0: @@ -703,7 +853,13 @@ def set_saved_states( if self.saved_state_combo.currentIndex() < 0 and state_names: self.saved_state_combo.setCurrentIndex(0) self.saved_state_combo.blockSignals(False) - self.restore_state_button.setEnabled(bool(state_names)) + has_states = bool(state_names) + self.saved_state_combo.setEnabled( + has_states and self._prefit_execution_enabled + ) + self.restore_state_button.setEnabled( + has_states and self._prefit_execution_enabled + ) def selected_saved_state_name(self) -> str | None: text = self.saved_state_combo.currentText().strip() @@ -713,41 +869,62 @@ def populate_parameter_table( self, entries: list[PrefitParameterEntry], ) -> None: - self.parameter_table.setRowCount(len(entries)) - for row, entry in enumerate(entries): - self.parameter_table.setItem( - row, 0, QTableWidgetItem(entry.structure) - ) - self.parameter_table.setItem(row, 1, QTableWidgetItem(entry.motif)) - self.parameter_table.setItem(row, 2, QTableWidgetItem(entry.name)) - self.parameter_table.setItem( - row, - 3, - QTableWidgetItem(f"{entry.value:.6g}"), - ) - vary_item = QTableWidgetItem() - vary_item.setFlags( - Qt.ItemFlag.ItemIsSelectable - | Qt.ItemFlag.ItemIsEnabled - | Qt.ItemFlag.ItemIsUserCheckable - ) - vary_item.setCheckState( - Qt.CheckState.Checked - if entry.vary - else Qt.CheckState.Unchecked - ) - self.parameter_table.setItem(row, 4, vary_item) - self.parameter_table.setItem( - row, - 5, - QTableWidgetItem(f"{entry.minimum:.6g}"), - ) - self.parameter_table.setItem( - row, - 6, - QTableWidgetItem(f"{entry.maximum:.6g}"), - ) + self._updating_parameter_table = True + self.parameter_table.blockSignals(True) + try: + self.parameter_table.setColumnCount(8) + self.parameter_table.setRowCount(len(entries)) + for row, entry in enumerate(entries): + self.parameter_table.setItem( + row, 0, QTableWidgetItem(entry.structure) + ) + self.parameter_table.setItem( + row, 1, QTableWidgetItem(entry.motif) + ) + self.parameter_table.setItem( + row, 2, QTableWidgetItem(entry.name) + ) + vary_item = QTableWidgetItem() + vary_item.setCheckState( + Qt.CheckState.Checked + if entry.vary + else Qt.CheckState.Unchecked + ) + self.parameter_table.setItem(row, 4, vary_item) + self._set_parameter_value_item( + row, + value=float(entry.value), + value_expression=entry.value_expression, + initial_value_expression=entry.initial_value_expression, + ) + self.parameter_table.setItem( + row, + 5, + QTableWidgetItem(f"{entry.minimum:.6g}"), + ) + self.parameter_table.setItem( + row, + 6, + QTableWidgetItem(f"{entry.maximum:.6g}"), + ) + reset_button = QPushButton("Reset") + reset_button.setToolTip( + "Reset this parameter to the template-default prefit " + "value, vary setting, and bounds." + ) + reset_button.clicked.connect( + lambda _checked=False, structure=entry.structure, motif=entry.motif, name=entry.name: self.parameter_reset_requested.emit( + structure, + motif, + name, + ) + ) + self.parameter_table.setCellWidget(row, 7, reset_button) + finally: + self.parameter_table.blockSignals(False) + self._updating_parameter_table = False self.parameter_table.resizeRowsToContents() + self._refresh_parameter_scroll_panel() def set_cluster_geometry_visible(self, visible: bool) -> None: self._cluster_geometry_group.setVisible(bool(visible)) @@ -762,26 +939,41 @@ def set_solute_volume_fraction_target( self, parameter_name: str | None, fraction_kind: str | None, + solvent_weight_parameter: str | None = None, ) -> None: + target_messages: list[str] = [] if parameter_name and fraction_kind: target_label = ( "solute" if str(fraction_kind).strip() == "solute" else "solvent" ) + target_messages.append( + f"{target_label} SAXS-effective interaction fraction -> {parameter_name}" + ) + if solvent_weight_parameter: + if parameter_name and fraction_kind: + target_messages.append( + f"attenuation solvent scale -> {solvent_weight_parameter}" + ) + else: + target_messages.append( + "combined solvent background multiplier -> " + f"{solvent_weight_parameter}" + ) + if target_messages: self.solute_volume_fraction_status_label.setText( - "Estimate the physical " - f"{target_label} volume fraction and apply it to " - f"{parameter_name}." + "Automatic Prefit targets: " + "; ".join(target_messages) + "." ) else: self.solute_volume_fraction_status_label.setText( - "This template does not expose a solute or solvent " - "volume-fraction parameter." + "These estimators are available for diagnostics, but the " + "active template does not expose an automatic Prefit target." ) self.solute_volume_fraction_widget.set_target_parameter( parameter_name, fraction_kind, + solvent_weight_parameter, ) def solute_volume_fraction_is_collapsed(self) -> bool: @@ -981,16 +1173,35 @@ def set_cluster_geometry_active_ionic_radius_type( def parameter_entries(self) -> list[PrefitParameterEntry]: entries: list[PrefitParameterEntry] = [] for row in range(self.parameter_table.rowCount()): + value_text = self._item_text(row, 3) + value_expression: str | None = None + initial_value_expression: str | None = None + vary = ( + self.parameter_table.item(row, 4).checkState() + == Qt.CheckState.Checked + ) + try: + value = float(value_text) + except (TypeError, ValueError): + if not value_text: + raise ValueError( + "Each prefit parameter requires a numeric value or " + "a linked-parameter expression." + ) + if vary: + initial_value_expression = value_text + else: + value_expression = value_text + value = self._parameter_item_numeric_value( + self.parameter_table.item(row, 3) + ) entries.append( PrefitParameterEntry( structure=self._item_text(row, 0), motif=self._item_text(row, 1), name=self._item_text(row, 2), - value=float(self._item_text(row, 3)), - vary=( - self.parameter_table.item(row, 4).checkState() - == Qt.CheckState.Checked - ), + value=value, + vary=vary, minimum=float(self._item_text(row, 5)), maximum=float(self._item_text(row, 6)), category=( @@ -998,9 +1209,25 @@ def parameter_entries(self) -> list[PrefitParameterEntry]: if self._item_text(row, 2).startswith("w") else "fit" ), + value_expression=value_expression, + initial_value_expression=initial_value_expression, ) ) - return entries + resolved_entries = resolve_prefit_parameter_entries(entries) + self._updating_parameter_table = True + self.parameter_table.blockSignals(True) + try: + for row, entry in enumerate(resolved_entries): + value_item = self.parameter_table.item(row, 3) + if value_item is not None: + value_item.setData( + self.PARAMETER_VALUE_ROLE, + float(entry.value), + ) + finally: + self.parameter_table.blockSignals(False) + self._updating_parameter_table = False + return resolved_entries def find_parameter_row(self, parameter_name: str) -> int: for row in range(self.parameter_table.rowCount()): @@ -1008,42 +1235,351 @@ def find_parameter_row(self, parameter_name: str) -> int: return row return -1 + def find_parameter_row_by_signature( + self, + structure: str, + motif: str, + parameter_name: str, + ) -> int: + for row in range(self.parameter_table.rowCount()): + if ( + self._item_text(row, 0) == structure + and self._item_text(row, 1) == motif + and self._item_text(row, 2) == parameter_name + ): + return row + return -1 + def set_parameter_row( self, parameter_name: str, *, + structure: str | None = None, + motif: str | None = None, value: float | None = None, minimum: float | None = None, maximum: float | None = None, vary: bool | None = None, ) -> None: - row = self.find_parameter_row(parameter_name) + row = ( + self.find_parameter_row_by_signature( + structure, + motif, + parameter_name, + ) + if structure is not None and motif is not None + else self.find_parameter_row(parameter_name) + ) if row < 0: raise ValueError(f"Parameter {parameter_name} was not found.") - if value is not None: - self.parameter_table.setItem( - row, - 3, - QTableWidgetItem(f"{float(value):.6g}"), + self._updating_parameter_table = True + self.parameter_table.blockSignals(True) + try: + if value is not None: + self._set_parameter_value_item( + row, + value=float(value), + value_expression=None, + initial_value_expression=None, + ) + if vary is not None: + vary_item = self.parameter_table.item(row, 4) + if vary_item is not None: + vary_item.setCheckState( + Qt.CheckState.Checked + if vary + else Qt.CheckState.Unchecked + ) + if minimum is not None: + self.parameter_table.setItem( + row, + 5, + QTableWidgetItem(f"{float(minimum):.6g}"), + ) + if maximum is not None: + self.parameter_table.setItem( + row, + 6, + QTableWidgetItem(f"{float(maximum):.6g}"), + ) + finally: + self.parameter_table.blockSignals(False) + self._updating_parameter_table = False + self._refresh_parameter_scroll_panel() + + def auto_update_on_parameter_change(self) -> bool: + return bool(self.auto_update_checkbox.isChecked()) + + def scrollable_parameter_enabled(self) -> bool: + return bool(self.scrollable_parameter_checkbox.isChecked()) + + def _on_parameter_table_item_changed( + self, + item: QTableWidgetItem, + ) -> None: + if item.column() in {3, 4}: + self._sync_parameter_row_link_state(item.row()) + if not self._updating_parameter_table and item.column() in {3, 4}: + try: + self.parameter_entries() + except Exception: + pass + self._refresh_parameter_scroll_panel() + if self._updating_parameter_table or item.column() != 3: + return + if not self.auto_update_on_parameter_change(): + return + try: + self.parameter_entries() + except Exception: + return + self.update_model_requested.emit() + + def _on_auto_update_checkbox_toggled(self, enabled: bool) -> None: + self.scrollable_parameter_checkbox.setEnabled(bool(enabled)) + if enabled: + self._refresh_parameter_scroll_panel() + return + self.scrollable_parameter_checkbox.blockSignals(True) + self.scrollable_parameter_checkbox.setChecked(False) + self.scrollable_parameter_checkbox.blockSignals(False) + self._refresh_parameter_scroll_panel() + + def _on_scrollable_parameter_toggled(self, enabled: bool) -> None: + if enabled and not self.auto_update_on_parameter_change(): + self.scrollable_parameter_checkbox.blockSignals(True) + self.scrollable_parameter_checkbox.setChecked(False) + self.scrollable_parameter_checkbox.blockSignals(False) + self._refresh_parameter_scroll_panel() + + def _on_parameter_table_current_cell_changed( + self, + current_row: int, + current_column: int, + previous_row: int, + previous_column: int, + ) -> None: + del current_column, previous_row, previous_column + if current_row < 0: + return + self._refresh_parameter_scroll_panel() + + def _refresh_parameter_scroll_panel(self) -> None: + if ( + not self.auto_update_on_parameter_change() + or not self.scrollable_parameter_enabled() + ): + self.parameter_scroll_panel.setVisible(False) + return + row = self.parameter_table.currentRow() + if row < 0: + self.parameter_scroll_panel.setVisible(False) + return + parameter_name = self._item_text(row, 2) + value_item = self.parameter_table.item(row, 3) + value_text = self._item_text(row, 3) + uses_expression = self._parameter_value_uses_expression(value_text) + try: + value = ( + self._parameter_item_numeric_value(value_item) + if uses_expression + else float(value_text) ) - if vary is not None: + minimum = float(self._item_text(row, 5)) + maximum = float(self._item_text(row, 6)) + except (TypeError, ValueError): + self.parameter_scroll_name_label.setText( + f"{parameter_name or 'Selected parameter'} has no numeric range." + ) + self.parameter_scroll_mode_label.setText("") + self.parameter_scroll_value_label.setText("") + self.parameter_scroll_bar.setEnabled(False) + self.parameter_scroll_panel.setVisible(True) + return + lower = min(minimum, maximum) + upper = max(minimum, maximum) + if not np.isfinite(lower) or not np.isfinite(upper) or lower == upper: + self.parameter_scroll_name_label.setText( + f"{parameter_name or 'Selected parameter'} has no usable range." + ) + self.parameter_scroll_mode_label.setText("") + self.parameter_scroll_value_label.setText(f"Value {value:.6g}") + self.parameter_scroll_bar.setEnabled(False) + self.parameter_scroll_panel.setVisible(True) + return + scroll_mode = self._parameter_scroll_mode(lower, upper) + self.parameter_scroll_name_label.setText( + f"{parameter_name} [{lower:.6g}, {upper:.6g}]" + ) + if uses_expression: vary_item = self.parameter_table.item(row, 4) - if vary_item is not None: - vary_item.setCheckState( - Qt.CheckState.Checked if vary else Qt.CheckState.Unchecked - ) - if minimum is not None: - self.parameter_table.setItem( - row, - 5, - QTableWidgetItem(f"{float(minimum):.6g}"), + expression_mode = ( + "Initial expression seed" + if vary_item is not None + and vary_item.checkState() == Qt.CheckState.Checked + else "Dependent expression" ) - if maximum is not None: - self.parameter_table.setItem( - row, - 6, - QTableWidgetItem(f"{float(maximum):.6g}"), + self.parameter_scroll_mode_label.setText( + f"{scroll_mode.capitalize()} scroll | {expression_mode}" + ) + else: + self.parameter_scroll_mode_label.setText( + f"{scroll_mode.capitalize()} scroll" ) + self.parameter_scroll_value_label.setText(f"Value {value:.6g}") + position = self._parameter_scroll_position_for_value( + value, + lower, + upper, + scroll_mode, + ) + self._updating_parameter_scrollbar = True + self.parameter_scroll_bar.blockSignals(True) + try: + self.parameter_scroll_bar.setEnabled(True) + self.parameter_scroll_bar.setValue(position) + finally: + self.parameter_scroll_bar.blockSignals(False) + self._updating_parameter_scrollbar = False + self.parameter_scroll_panel.setVisible(True) + + def _on_parameter_scrollbar_value_changed(self, position: int) -> None: + if self._updating_parameter_scrollbar: + return + row = self.parameter_table.currentRow() + if row < 0: + return + try: + minimum = float(self._item_text(row, 5)) + maximum = float(self._item_text(row, 6)) + except (TypeError, ValueError): + return + lower = min(minimum, maximum) + upper = max(minimum, maximum) + if not np.isfinite(lower) or not np.isfinite(upper) or lower == upper: + return + scroll_mode = self._parameter_scroll_mode(lower, upper) + value = self._parameter_scroll_value_for_position( + position, + lower, + upper, + scroll_mode, + ) + item = self.parameter_table.item(row, 3) + if item is None: + item = QTableWidgetItem() + self.parameter_table.setItem(row, 3, item) + formatted = f"{float(value):.6g}" + self.parameter_scroll_value_label.setText(f"Value {formatted}") + if item.text().strip() == formatted: + return + item.setText(formatted) + + def _parameter_scroll_mode( + self, + minimum: float, + maximum: float, + ) -> str: + if minimum == 0.0 or maximum == 0.0 or minimum * maximum < 0.0: + return "linear" + decade_span = abs( + np.log10(abs(float(maximum))) - np.log10(abs(float(minimum))) + ) + return ( + "log" + if decade_span >= self.PARAMETER_SCROLL_LOG_DECADE_THRESHOLD + else "linear" + ) + + def _parameter_scroll_position_for_value( + self, + value: float, + minimum: float, + maximum: float, + scroll_mode: str, + ) -> int: + clamped_value = min(max(float(value), minimum), maximum) + lower = self._parameter_scroll_transform( + minimum, + minimum, + maximum, + scroll_mode, + ) + upper = self._parameter_scroll_transform( + maximum, + minimum, + maximum, + scroll_mode, + ) + transformed_value = self._parameter_scroll_transform( + clamped_value, + minimum, + maximum, + scroll_mode, + ) + if np.isclose(upper, lower): + return 0 + fraction = (transformed_value - lower) / (upper - lower) + fraction = float(np.clip(fraction, 0.0, 1.0)) + return int(round(fraction * self.PARAMETER_SCROLL_RESOLUTION)) + + def _parameter_scroll_value_for_position( + self, + position: int, + minimum: float, + maximum: float, + scroll_mode: str, + ) -> float: + lower = self._parameter_scroll_transform( + minimum, + minimum, + maximum, + scroll_mode, + ) + upper = self._parameter_scroll_transform( + maximum, + minimum, + maximum, + scroll_mode, + ) + fraction = float(position) / float(self.PARAMETER_SCROLL_RESOLUTION) + fraction = float(np.clip(fraction, 0.0, 1.0)) + transformed_value = lower + fraction * (upper - lower) + value = self._parameter_scroll_inverse_transform( + transformed_value, + minimum, + maximum, + scroll_mode, + ) + return float(np.clip(value, minimum, maximum)) + + @staticmethod + def _parameter_scroll_transform( + value: float, + minimum: float, + maximum: float, + scroll_mode: str, + ) -> float: + numeric_value = float(value) + if scroll_mode != "log": + return numeric_value + if minimum < 0.0 and maximum < 0.0: + return -float(np.log10(-numeric_value)) + return float(np.log10(numeric_value)) + + @staticmethod + def _parameter_scroll_inverse_transform( + value: float, + minimum: float, + maximum: float, + scroll_mode: str, + ) -> float: + numeric_value = float(value) + if scroll_mode != "log": + return numeric_value + if minimum < 0.0 and maximum < 0.0: + return -(10.0 ** (-numeric_value)) + return 10.0**numeric_value def run_config(self) -> PrefitRunConfig: return PrefitRunConfig( @@ -1062,6 +1598,7 @@ def plot_evaluation( axis.set_xscale("linear") self.figure.clear() self._update_prefit_trace_toggle_state(evaluation) + self._update_plot_group_title() if evaluation is None: axis = self.figure.add_subplot(111) axis.text( @@ -1075,13 +1612,23 @@ def plot_evaluation( self.canvas.draw() return - grid = self.figure.add_gridspec(2, 1, height_ratios=[3, 1]) - top = self.figure.add_subplot(grid[0, 0]) - bottom = self.figure.add_subplot(grid[1, 0], sharex=top) + has_experimental = evaluation.experimental_intensities is not None + has_residuals = evaluation.residuals is not None + if has_experimental and has_residuals: + grid = self.figure.add_gridspec(2, 1, height_ratios=[3, 1]) + top = self.figure.add_subplot(grid[0, 0]) + bottom = self.figure.add_subplot(grid[1, 0], sharex=top) + else: + top = self.figure.add_subplot(111) + bottom = None plotted_lines = [] + structure_axis = None - if self.show_experimental_trace_checkbox.isChecked(): + if ( + has_experimental + and self.show_experimental_trace_checkbox.isChecked() + ): (experimental_line,) = top.plot( evaluation.q_values, evaluation.experimental_intensities, @@ -1111,6 +1658,35 @@ def plot_evaluation( ) plotted_lines.append(solvent_line) + if ( + self.show_structure_factor_trace_checkbox.isChecked() + and evaluation.structure_factor_trace is not None + ): + structure_values = np.asarray( + evaluation.structure_factor_trace, + dtype=float, + ) + structure_mask = np.isfinite(structure_values) + if np.any(structure_mask): + structure_axis = top.twinx() + structure_axis.set_xscale( + "log" if self.log_x_checkbox.isChecked() else "linear" + ) + (structure_line,) = structure_axis.plot( + np.asarray(evaluation.q_values, dtype=float)[ + structure_mask + ], + structure_values[structure_mask], + color="tab:purple", + linestyle="--", + linewidth=1.5, + label="Structure factor S(q)", + ) + structure_axis.set_ylabel("S(q)", color="tab:purple") + structure_axis.tick_params(axis="y", colors="tab:purple") + structure_axis.spines["right"].set_color("tab:purple") + plotted_lines.append(structure_line) + if self.show_model_trace_checkbox.isChecked(): (model_line,) = top.plot( evaluation.q_values, @@ -1140,17 +1716,20 @@ def plot_evaluation( if plotted_lines: self._build_interactive_legend(top, plotted_lines) - bottom.axhline(0.0, color="0.5", linewidth=1.0) - bottom.plot( - evaluation.q_values, - evaluation.residuals, - color="tab:blue", - ) - bottom.set_xscale( - "log" if self.log_x_checkbox.isChecked() else "linear" - ) - bottom.set_xlabel("q (Å⁻¹)") - bottom.set_ylabel("Residual") + if bottom is not None and evaluation.residuals is not None: + bottom.axhline(0.0, color="0.5", linewidth=1.0) + bottom.plot( + evaluation.q_values, + evaluation.residuals, + color="tab:blue", + ) + bottom.set_xscale( + "log" if self.log_x_checkbox.isChecked() else "linear" + ) + bottom.set_xlabel("q (Å⁻¹)") + bottom.set_ylabel("Residual") + else: + top.set_xlabel("q (Å⁻¹)") self.figure.tight_layout() self.canvas.draw() @@ -1161,6 +1740,14 @@ def current_evaluation(self) -> PrefitEvaluation | None: def _prefit_metric_lines( evaluation: PrefitEvaluation, ) -> list[str]: + if ( + evaluation.experimental_intensities is None + or evaluation.residuals is None + ): + return [ + "Model Only Mode", + "Experimental fit metrics unavailable", + ] experimental_values = np.asarray( evaluation.experimental_intensities, dtype=float, @@ -1179,11 +1766,19 @@ def _prefit_metric_lines( if total_sum_squares > 0.0 else 1.0 ) - return [ + metric_lines = [ f"RMSE: {rmse:.4g}", f"Mean |res|: {mean_abs_residual:.4g}", f"R²: {r_squared:.4g}", ] + non_positive_model_points = int( + np.count_nonzero(np.isfinite(model_values) & (model_values <= 0.0)) + ) + if non_positive_model_points: + metric_lines.append( + f"Model <= 0 at {non_positive_model_points} q-points" + ) + return metric_lines def append_log(self, message: str) -> None: stripped = message.strip() @@ -1199,24 +1794,161 @@ def set_summary_text(self, text: str) -> None: self._summary_text = text.strip() self._render_output() + def set_console_autoscroll_enabled(self, enabled: bool) -> None: + self._console_autoscroll_enabled = bool(enabled) + if self._console_autoscroll_enabled: + self._scroll_output_to_end() + def _item_text(self, row: int, column: int) -> str: item = self.parameter_table.item(row, column) return item.text().strip() if item is not None else "" + def _parameter_item_numeric_value( + self, + item: QTableWidgetItem | None, + ) -> float: + if item is None: + return 0.0 + raw_value = item.data(self.PARAMETER_VALUE_ROLE) + try: + return float(raw_value) + except (TypeError, ValueError): + return 0.0 + + def _parameter_value_uses_expression(self, value_text: str) -> bool: + stripped = value_text.strip() + if not stripped: + return False + try: + float(stripped) + except (TypeError, ValueError): + return True + return False + + def _set_parameter_value_item( + self, + row: int, + *, + value: float, + value_expression: str | None, + initial_value_expression: str | None, + ) -> None: + display_text = ( + value_expression.strip() + if value_expression is not None and value_expression.strip() + else ( + initial_value_expression.strip() + if initial_value_expression is not None + and initial_value_expression.strip() + else f"{float(value):.6g}" + ) + ) + value_item = QTableWidgetItem(display_text) + value_item.setData(self.PARAMETER_VALUE_ROLE, float(value)) + self.parameter_table.setItem(row, 3, value_item) + self._sync_parameter_row_link_state(row) + + def _sync_parameter_row_link_state(self, row: int) -> None: + if row < 0: + return + vary_item = self.parameter_table.item(row, 4) + value_item = self.parameter_table.item(row, 3) + if vary_item is None or value_item is None: + return + linked = self._parameter_value_uses_expression(self._item_text(row, 3)) + resolved_value = self._parameter_item_numeric_value(value_item) + was_updating = self._updating_parameter_table + self._updating_parameter_table = True + self.parameter_table.blockSignals(True) + try: + vary_item.setFlags( + Qt.ItemFlag.ItemIsSelectable + | Qt.ItemFlag.ItemIsEnabled + | Qt.ItemFlag.ItemIsUserCheckable + ) + if not linked: + vary_item.setToolTip( + "Enable to vary this parameter during fitting." + ) + value_item.setToolTip("") + return + if vary_item.checkState() == Qt.CheckState.Checked: + vary_item.setToolTip( + "Artemis-style guess behavior: with Vary enabled, the " + "Value expression is evaluated into the starting numeric " + "value and the parameter may still refine within Min/Max." + ) + value_item.setToolTip( + "Initial expression seed. With Vary enabled, the Value " + "expression resolves to the current numeric value " + f"({resolved_value:.6g}) before fitting, but the " + "parameter then varies independently." + ) + return + vary_item.setToolTip( + "Artemis-style def behavior: with Vary disabled, the Value " + "expression is treated as a live dependent parameter and its " + "Min/Max are ignored during fitting." + ) + value_item.setToolTip( + "Dependent expression. With Vary disabled, this parameter " + "follows the expression entered in the Value column. Current " + f"resolved value: {resolved_value:.6g}." + ) + finally: + self.parameter_table.blockSignals(False) + self._updating_parameter_table = was_updating + def _redraw_current_plot(self) -> None: self.plot_evaluation(self._current_evaluation) + def _update_prefit_execution_control_state(self) -> None: + enabled = bool(self._prefit_execution_enabled) + self.method_combo.setEnabled(enabled) + self.nfev_spin.setEnabled(enabled) + self.saved_state_combo.setEnabled( + enabled and self.saved_state_combo.count() > 0 + ) + self.restore_state_button.setEnabled( + enabled and self.saved_state_combo.count() > 0 + ) + self.run_button.setEnabled(enabled) + self.recommended_scale_button.setEnabled(enabled) + self.autosave_checkbox.setEnabled(enabled) + self.save_button.setEnabled(enabled) + + def _update_plot_group_title(self) -> None: + has_experimental = ( + self._current_evaluation is not None + and self._current_evaluation.experimental_intensities is not None + ) + if self._model_only_mode or not has_experimental: + self._plot_group.setTitle("Model Preview") + return + self._plot_group.setTitle("Model vs Experimental") + def _update_prefit_trace_toggle_state( self, evaluation: PrefitEvaluation | None, ) -> None: has_evaluation = evaluation is not None + has_experimental = ( + has_evaluation and evaluation.experimental_intensities is not None + ) has_solvent = ( has_evaluation and evaluation.solvent_contribution is not None ) - self.show_experimental_trace_checkbox.setEnabled(has_evaluation) + has_structure_factor = ( + has_evaluation and evaluation.structure_factor_trace is not None + ) + self.show_experimental_trace_checkbox.setEnabled( + bool(has_experimental) + ) self.show_model_trace_checkbox.setEnabled(has_evaluation) self.show_solvent_trace_checkbox.setEnabled(bool(has_solvent)) + self.show_structure_factor_trace_checkbox.setEnabled( + bool(has_structure_factor) + ) def _on_cluster_geometry_mapping_changed( self, @@ -1794,7 +2526,7 @@ def reset_cluster_geometry_progress(self) -> None: self.cluster_geometry_progress_bar.setFormat("%v / %m files") def _build_interactive_legend(self, axis, lines: list[object]) -> None: - legend = axis.legend() + legend = axis.legend(handles=lines, loc="best") if legend is None: return legend_handles = getattr(legend, "legend_handles", None) @@ -1847,6 +2579,7 @@ def set_selected_template( self._update_template_tooltip() def _render_output(self, *, scroll_to_end: bool = False) -> None: + del scroll_to_end sections: list[str] = [] if self._summary_text: sections.append("Prefit Summary\n" + self._summary_text) @@ -1857,10 +2590,39 @@ def _render_output(self, *, scroll_to_end: bool = False) -> None: ] if history_parts: sections.append("Prefit Console\n" + "\n\n".join(history_parts)) + scrollbar = self.output_box.verticalScrollBar() + previous_value = scrollbar.value() + previous_maximum = max(scrollbar.maximum(), 1) self.output_box.setPlainText("\n\n".join(sections).strip()) - if scroll_to_end: - scrollbar = self.output_box.verticalScrollBar() - scrollbar.setValue(scrollbar.maximum()) + if self._console_autoscroll_enabled: + self._scroll_output_to_end() + return + updated_scrollbar = self.output_box.verticalScrollBar() + if updated_scrollbar.maximum() > 0: + position_fraction = previous_value / previous_maximum + updated_scrollbar.setValue( + int(round(position_fraction * updated_scrollbar.maximum())) + ) + + def _scroll_output_to_end(self) -> None: + cursor = self.output_box.textCursor() + cursor.movePosition(QTextCursor.MoveOperation.End) + self.output_box.setTextCursor(cursor) + self.output_box.ensureCursorVisible() + scrollbar = self.output_box.verticalScrollBar() + scrollbar.setValue(scrollbar.maximum()) + QTimer.singleShot( + 0, + self._scroll_output_to_end_once, + ) + + def _scroll_output_to_end_once(self) -> None: + cursor = self.output_box.textCursor() + cursor.movePosition(QTextCursor.MoveOperation.End) + self.output_box.setTextCursor(cursor) + self.output_box.ensureCursorVisible() + scrollbar = self.output_box.verticalScrollBar() + scrollbar.setValue(scrollbar.maximum()) def _show_ionic_radius_help(self) -> None: QMessageBox.information( @@ -1872,7 +2634,7 @@ def _show_ionic_radius_help(self) -> None: def _show_solute_volume_fraction_help(self) -> None: QMessageBox.information( self, - "Solute Volume Fraction Estimate Help", + "Solution Scattering Estimate Help", self.SOLUTE_VOLUME_FRACTION_HELP_TEXT, ) diff --git a/src/saxshell/saxs/ui/project_setup_tab.py b/src/saxshell/saxs/ui/project_setup_tab.py index 30dcdfd..776439c 100644 --- a/src/saxshell/saxs/ui/project_setup_tab.py +++ b/src/saxshell/saxs/ui/project_setup_tab.py @@ -12,8 +12,8 @@ ) from matplotlib.colors import to_hex from matplotlib.figure import Figure -from PySide6.QtCore import Qt, Signal -from PySide6.QtGui import QColor +from PySide6.QtCore import Qt, QTimer, Signal +from PySide6.QtGui import QColor, QTextCursor from PySide6.QtWidgets import ( QAbstractItemView, QCheckBox, @@ -85,9 +85,12 @@ class ProjectSetupTab(QWidget): save_component_plot_data_requested = Signal() save_prior_plot_data_requested = Signal() show_deprecated_templates_changed = Signal(bool) + model_only_mode_changed = Signal(bool) def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) + self._console_autoscroll_enabled = True + self._summary_text = "" self._experimental_header_rows = 0 self._experimental_q_column: int | None = None self._experimental_intensity_column: int | None = None @@ -232,6 +235,17 @@ def _build_inputs_group(self) -> QGroupBox: ) layout.addRow("Clusters folder", self._clusters_row()) + self.model_only_mode_checkbox = QCheckBox("Model Only Mode") + self.model_only_mode_checkbox.setChecked(False) + self.model_only_mode_checkbox.setToolTip( + "Disable experimental-data-dependent fitting and use the project " + "as a forward-model-only SAXS simulator." + ) + self.model_only_mode_checkbox.toggled.connect( + self._on_model_only_mode_toggled + ) + layout.addRow("", self.model_only_mode_checkbox) + self.experimental_data_edit = QLineEdit() self.experimental_data_edit.setReadOnly(True) layout.addRow("Experimental data", self._experimental_data_row()) @@ -325,11 +339,13 @@ def _build_inputs_group(self) -> QGroupBox: def _build_model_group(self) -> QGroupBox: group = QGroupBox("Model and Build") - layout = QHBoxLayout(group) + layout = QVBoxLayout(group) - controls_widget = QWidget() - controls_layout = QFormLayout(controls_widget) + header_widget = QWidget() + self._model_build_header_widget = header_widget + header_layout = QFormLayout(header_widget) self.template_combo = QComboBox() + self.template_combo.setMinimumWidth(420) self.template_combo.currentIndexChanged.connect( self._on_template_combo_changed ) @@ -348,15 +364,24 @@ def _build_model_group(self) -> QGroupBox: self.show_deprecated_templates_checkbox.toggled.connect( self.show_deprecated_templates_changed.emit ) - controls_layout.addRow( + header_layout.addRow( "Selected template", self._template_row(), ) self.active_template_edit = QLineEdit() self.active_template_edit.setReadOnly(True) - controls_layout.addRow("Active template", self.active_template_edit) - + self.active_template_edit.setMinimumWidth(420) + header_layout.addRow("Active template", self.active_template_edit) + layout.addWidget(header_widget) + + lower_layout = QHBoxLayout() + self._model_build_lower_layout = lower_layout + button_widget = QWidget() + self._model_build_button_widget = button_widget + button_layout = QVBoxLayout(button_widget) + button_layout.setContentsMargins(0, 0, 0, 0) + button_layout.setSpacing(8) self.build_components_button = QPushButton("Build SAXS Components") self.build_components_button.clicked.connect( self.build_components_requested.emit @@ -369,12 +394,19 @@ def _build_model_group(self) -> QGroupBox: self.install_model_button.clicked.connect( self.install_model_requested.emit ) - controls_layout.addRow("", self.build_prior_weights_button) - controls_layout.addRow("", self.build_components_button) - controls_layout.addRow("", self.install_model_button) - layout.addWidget(controls_widget, stretch=4) + button_layout.addWidget(self.build_prior_weights_button) + button_layout.addWidget(self.build_components_button) + button_layout.addWidget(self.install_model_button) + button_layout.addStretch(1) + button_widget.setSizePolicy( + QSizePolicy.Policy.Fixed, + QSizePolicy.Policy.Preferred, + ) + button_widget.setMinimumWidth(220) + lower_layout.addWidget(button_widget, stretch=0) clusters_group = QGroupBox("Recognized Clusters") + self._recognized_clusters_group = clusters_group clusters_layout = QVBoxLayout(clusters_group) self.recognized_clusters_table = QTableWidget(0, 8) self.recognized_clusters_table.setHorizontalHeaderLabels( @@ -425,7 +457,8 @@ def _build_model_group(self) -> QGroupBox: header.resizeSection(6, 72) header.resizeSection(7, 92) clusters_layout.addWidget(self.recognized_clusters_table) - layout.addWidget(clusters_group, stretch=6) + lower_layout.addWidget(clusters_group, stretch=1) + layout.addLayout(lower_layout, stretch=1) return group def _template_row(self) -> QWidget: @@ -657,17 +690,10 @@ def set_project_selected(self, selected: bool) -> None: self.save_project_button.setEnabled(selected) self._update_prior_control_state() if not selected: - self.data_status_label.setText( - "Choose an experimental SAXS file or folder after opening a " - "project.\n" - "The selected file, columns, q-range, and import settings " - "will be summarized here." - ) - self.solvent_status_label.setText( - "Optional solvent SAXS data can be loaded here and will be " - "carried into prefit and DREAM if the active model uses " - "solvent intensities." - ) + self._experimental_summary = None + self._solvent_summary = None + self._refresh_data_status_labels() + self._apply_model_only_mode_state() def set_project_settings( self, @@ -709,6 +735,7 @@ def set_project_settings( self._solvent_q_column = settings.solvent_q_column self._solvent_intensity_column = settings.solvent_intensity_column self._solvent_error_column = settings.solvent_error_column + self.set_model_only_mode(settings.model_only_mode) self.qmin_edit.setText( "" if settings.q_min is None else f"{settings.q_min:g}" ) @@ -808,6 +835,8 @@ def set_project_settings( ) else: self._solvent_summary = None + self._refresh_data_status_labels() + self._apply_model_only_mode_state() self._update_data_trace_control_state() self._redraw_saxs_preview() @@ -879,6 +908,17 @@ def project_name(self) -> str | None: text = self.project_name_edit.text().strip() return text or None + def model_only_mode(self) -> bool: + return bool(self.model_only_mode_checkbox.isChecked()) + + def set_model_only_mode(self, enabled: bool) -> None: + self.model_only_mode_checkbox.blockSignals(True) + self.model_only_mode_checkbox.setChecked(bool(enabled)) + self.model_only_mode_checkbox.blockSignals(False) + if enabled and self.use_experimental_grid_checkbox.isChecked(): + self.use_experimental_grid_checkbox.setChecked(False) + self._apply_model_only_mode_state() + def open_project_dir(self) -> Path | None: text = self.open_project_dir_edit.text().strip() return Path(text).expanduser() if text else None @@ -895,6 +935,70 @@ def solvent_data_path(self) -> Path | None: text = self.solvent_data_edit.text().strip() return Path(text).expanduser() if text else None + def _default_experimental_status_text(self) -> str: + return ( + "Choose an experimental SAXS file or folder after opening a " + "project.\n" + "The selected file, columns, q-range, and import settings will " + "be summarized here." + ) + + def _default_solvent_status_text(self) -> str: + return ( + "Optional solvent SAXS data can be loaded here and will be " + "carried into prefit and DREAM if the active model uses " + "solvent intensities." + ) + + def _refresh_data_status_labels(self) -> None: + if self.model_only_mode(): + self.data_status_label.setText( + "Model Only Mode is enabled.\n" + "Experimental data input is locked and hidden from the " + "plots until Model Only Mode is turned off." + ) + self.solvent_status_label.setText( + "Model Only Mode is enabled.\n" + "Solvent data input is locked and hidden from the plots " + "until Model Only Mode is turned off." + ) + return + if self._experimental_summary is not None: + self.data_status_label.setText( + self._experimental_import_summary(self._experimental_summary) + ) + else: + self.data_status_label.setText( + self._default_experimental_status_text() + ) + if self._solvent_summary is not None: + self.solvent_status_label.setText( + self._experimental_import_summary(self._solvent_summary) + ) + else: + self.solvent_status_label.setText( + self._default_solvent_status_text() + ) + + def _apply_model_only_mode_state(self) -> None: + locked = self.model_only_mode() + self.experimental_data_edit.setEnabled(not locked) + self.experimental_file_button.setEnabled(not locked) + self.experimental_folder_button.setEnabled(not locked) + self.experimental_columns_button.setEnabled(not locked) + self.experimental_clear_button.setEnabled(not locked) + self.solvent_data_edit.setEnabled(not locked) + self.solvent_file_button.setEnabled(not locked) + self.solvent_columns_button.setEnabled(not locked) + self.solvent_clear_button.setEnabled(not locked) + self.use_experimental_grid_checkbox.setEnabled(not locked) + if locked and self.use_experimental_grid_checkbox.isChecked(): + self.use_experimental_grid_checkbox.setChecked(False) + self._update_resample_grid_state() + self._refresh_data_status_labels() + self._update_data_trace_control_state() + self._redraw_saxs_preview() + def experimental_header_rows(self) -> int: return int(self._experimental_header_rows) @@ -968,7 +1072,31 @@ def q_min(self) -> float | None: def q_max(self) -> float | None: return self._optional_float(self.qmax_edit.text()) + def default_experimental_q_range(self) -> tuple[float, float] | None: + if self.model_only_mode() or self._experimental_summary is None: + return None + q_values = np.asarray(self._experimental_summary.q_values, dtype=float) + if q_values.size == 0: + return None + return float(q_values.min()), float(q_values.max()) + + def q_range_matches_loaded_experimental_defaults(self) -> bool: + default_range = self.default_experimental_q_range() + if default_range is None: + return False + q_min = self.q_min() + q_max = self.q_max() + if q_min is None or q_max is None: + return False + default_q_min, default_q_max = default_range + return bool( + np.isclose(q_min, default_q_min, rtol=0.0, atol=1e-9) + and np.isclose(q_max, default_q_max, rtol=0.0, atol=1e-9) + ) + def use_experimental_grid(self) -> bool: + if self.model_only_mode(): + return False return bool(self.use_experimental_grid_checkbox.isChecked()) def q_points(self) -> int | None: @@ -985,6 +1113,10 @@ def prior_secondary_element(self) -> str | None: text = self.secondary_filter_combo.currentText().strip() return text or None + def selected_prior_secondary_element(self) -> str | None: + text = self.secondary_filter_combo.currentText().strip() + return text or None + def prior_cmap(self) -> str: return self.prior_color_combo.currentText().strip() or "summer" @@ -1020,7 +1152,14 @@ def prior_structure_motif_colors(self) -> dict[str, str] | None: return colors or None def append_summary(self, message: str) -> None: - self.summary_box.append(message) + stripped = str(message).strip() + if not stripped: + return + if self._summary_text: + self._summary_text += "\n" + stripped + else: + self._summary_text = stripped + self._render_summary_text() def component_plot_export_payload(self) -> dict[str, object]: traces: list[dict[str, object]] = [] @@ -1102,7 +1241,48 @@ def set_solvent_trace_settings( self._update_data_trace_control_state() def set_summary_text(self, text: str) -> None: - self.summary_box.setPlainText(text) + self._summary_text = str(text).strip() + self._render_summary_text() + + def set_console_autoscroll_enabled(self, enabled: bool) -> None: + self._console_autoscroll_enabled = bool(enabled) + if self._console_autoscroll_enabled: + self._scroll_summary_to_end() + + def _render_summary_text(self) -> None: + scrollbar = self.summary_box.verticalScrollBar() + previous_value = scrollbar.value() + previous_maximum = max(scrollbar.maximum(), 1) + self.summary_box.setPlainText(self._summary_text) + if self._console_autoscroll_enabled: + self._scroll_summary_to_end() + return + updated_scrollbar = self.summary_box.verticalScrollBar() + if updated_scrollbar.maximum() > 0: + position_fraction = previous_value / previous_maximum + updated_scrollbar.setValue( + int(round(position_fraction * updated_scrollbar.maximum())) + ) + + def _scroll_summary_to_end(self) -> None: + cursor = self.summary_box.textCursor() + cursor.movePosition(QTextCursor.MoveOperation.End) + self.summary_box.setTextCursor(cursor) + self.summary_box.ensureCursorVisible() + scrollbar = self.summary_box.verticalScrollBar() + scrollbar.setValue(scrollbar.maximum()) + QTimer.singleShot( + 0, + self._scroll_summary_to_end_once, + ) + + def _scroll_summary_to_end_once(self) -> None: + cursor = self.summary_box.textCursor() + cursor.movePosition(QTextCursor.MoveOperation.End) + self.summary_box.setTextCursor(cursor) + self.summary_box.ensureCursorVisible() + scrollbar = self.summary_box.verticalScrollBar() + scrollbar.setValue(scrollbar.maximum()) def draw_component_plot(self, component_paths: list[Path] | None) -> None: self._component_paths = component_paths @@ -1182,7 +1362,8 @@ def _redraw_saxs_preview(self) -> None: self._component_legend_lookup.clear() self._component_line_lookup.clear() self._component_color_lookup.clear() - has_data_preview = ( + show_data_preview = not self.model_only_mode() + has_data_preview = show_data_preview and ( self._experimental_summary is not None or self._solvent_summary is not None ) @@ -1193,9 +1374,14 @@ def _redraw_saxs_preview(self) -> None: 0.5, 0.5, fill( - "Select experimental data and build SAXS components " - "to preview the experimental range and averaged " - "cluster profiles.", + ( + "Build SAXS components to preview averaged cluster " + "profiles in Model Only Mode." + if self.model_only_mode() + else "Select experimental data and build SAXS " + "components to preview the experimental range and " + "averaged cluster profiles." + ), width=42, ), ha="center", @@ -1221,7 +1407,7 @@ def _redraw_saxs_preview(self) -> None: plotted_lines.extend( self._draw_experimental_preview( experimental_axis, - self._experimental_summary, + self._experimental_summary if show_data_preview else None, ) ) @@ -1484,22 +1670,31 @@ def _trace_color_button_value( def _experimental_data_row(self) -> QWidget: row = QWidget() + self.experimental_data_row_widget = row layout = QHBoxLayout(row) layout.setContentsMargins(0, 0, 0, 0) layout.addWidget(self.experimental_data_edit, stretch=1) - file_button = QPushButton("File…") - file_button.clicked.connect(self._choose_experimental_file) - folder_button = QPushButton("Folder…") - folder_button.clicked.connect(self._choose_experimental_folder) - columns_button = QPushButton("Columns…") - columns_button.clicked.connect(self._configure_experimental_columns) - clear_button = QPushButton("Clear") - clear_button.clicked.connect(self._clear_experimental_selection) - layout.addWidget(file_button) - layout.addWidget(folder_button) - layout.addWidget(columns_button) - layout.addWidget(clear_button) + self.experimental_file_button = QPushButton("File…") + self.experimental_file_button.clicked.connect( + self._choose_experimental_file + ) + self.experimental_folder_button = QPushButton("Folder…") + self.experimental_folder_button.clicked.connect( + self._choose_experimental_folder + ) + self.experimental_columns_button = QPushButton("Columns…") + self.experimental_columns_button.clicked.connect( + self._configure_experimental_columns + ) + self.experimental_clear_button = QPushButton("Clear") + self.experimental_clear_button.clicked.connect( + self._clear_experimental_selection + ) + layout.addWidget(self.experimental_file_button) + layout.addWidget(self.experimental_folder_button) + layout.addWidget(self.experimental_columns_button) + layout.addWidget(self.experimental_clear_button) self.experimental_trace_visible_checkbox = QCheckBox() self.experimental_trace_visible_checkbox.setChecked(True) self.experimental_trace_visible_checkbox.toggled.connect( @@ -1522,19 +1717,24 @@ def _experimental_data_row(self) -> QWidget: def _solvent_data_row(self) -> QWidget: row = QWidget() + self.solvent_data_row_widget = row layout = QHBoxLayout(row) layout.setContentsMargins(0, 0, 0, 0) layout.addWidget(self.solvent_data_edit, stretch=1) - file_button = QPushButton("File…") - file_button.clicked.connect(self._choose_solvent_file) - columns_button = QPushButton("Columns…") - columns_button.clicked.connect(self._configure_solvent_columns) - clear_button = QPushButton("Clear") - clear_button.clicked.connect(self._clear_solvent_selection) - layout.addWidget(file_button) - layout.addWidget(columns_button) - layout.addWidget(clear_button) + self.solvent_file_button = QPushButton("File…") + self.solvent_file_button.clicked.connect(self._choose_solvent_file) + self.solvent_columns_button = QPushButton("Columns…") + self.solvent_columns_button.clicked.connect( + self._configure_solvent_columns + ) + self.solvent_clear_button = QPushButton("Clear") + self.solvent_clear_button.clicked.connect( + self._clear_solvent_selection + ) + layout.addWidget(self.solvent_file_button) + layout.addWidget(self.solvent_columns_button) + layout.addWidget(self.solvent_clear_button) self.solvent_trace_visible_checkbox = QCheckBox() self.solvent_trace_visible_checkbox.setChecked(True) self.solvent_trace_visible_checkbox.toggled.connect( @@ -1581,7 +1781,7 @@ def _browse_existing_project_directory(self) -> None: project_file = build_project_paths(project_dir).project_file if not project_file.is_file(): self.open_project_dir_edit.clear() - self.summary_box.append( + self.append_summary( "Select a complete SAXS project folder that contains " "saxs_project.json, not a parent directory of multiple projects." ) @@ -1670,10 +1870,7 @@ def _clear_experimental_selection(self) -> None: self._experimental_intensity_column = None self._experimental_error_column = None self._experimental_summary = None - self.data_status_label.setText( - "No experimental data selected.\n" - "Choose an experimental SAXS file or folder to preview its range." - ) + self._refresh_data_status_labels() self._update_data_trace_control_state() self._redraw_saxs_preview() self.autosave_project_requested.emit("cleared experimental data") @@ -1715,11 +1912,7 @@ def _clear_solvent_selection(self) -> None: self._solvent_intensity_column = None self._solvent_error_column = None self._solvent_summary = None - self.solvent_status_label.setText( - "Optional solvent SAXS data can be loaded here and will be " - "carried into prefit and DREAM if the active model uses " - "solvent intensities." - ) + self._refresh_data_status_labels() self._update_data_trace_control_state() self._redraw_saxs_preview() self.autosave_project_requested.emit("cleared solvent data") @@ -1797,17 +1990,15 @@ def _set_experimental_status( self, summary: ExperimentalDataSummary, ) -> None: - self.data_status_label.setText( - self._experimental_import_summary(summary) - ) + del summary + self._refresh_data_status_labels() def _set_solvent_status( self, summary: ExperimentalDataSummary, ) -> None: - self.solvent_status_label.setText( - self._experimental_import_summary(summary) - ) + del summary + self._refresh_data_status_labels() def _load_experimental_summary_from_path( self, @@ -1944,8 +2135,12 @@ def _configure_solvent_columns(self) -> None: self._apply_solvent_file(selected_path, accepted_summary) def _update_data_trace_control_state(self) -> None: - has_experimental = self._experimental_summary is not None - has_solvent = self._solvent_summary is not None + if self.model_only_mode(): + has_experimental = False + has_solvent = False + else: + has_experimental = self._experimental_summary is not None + has_solvent = self._solvent_summary is not None self.experimental_trace_visible_checkbox.setEnabled(has_experimental) self.experimental_trace_color_button.setEnabled(has_experimental) self.solvent_trace_visible_checkbox.setEnabled(has_solvent) @@ -1987,6 +2182,17 @@ def _on_clusters_dir_edited(self) -> None: "cleared the clusters folder reference" ) + def _on_model_only_mode_toggled(self, enabled: bool) -> None: + if enabled and self.use_experimental_grid_checkbox.isChecked(): + self.use_experimental_grid_checkbox.setChecked(False) + self._apply_model_only_mode_state() + self.model_only_mode_changed.emit(bool(enabled)) + self.autosave_project_requested.emit( + "enabled Model Only Mode" + if enabled + else "disabled Model Only Mode" + ) + def _update_resample_grid_state(self) -> None: self.resample_points_spin.setEnabled( not self.use_experimental_grid_checkbox.isChecked() diff --git a/src/saxshell/saxs/ui/solute_volume_fraction_widget.py b/src/saxshell/saxs/ui/solute_volume_fraction_widget.py index ed2a6b5..f0c37c4 100644 --- a/src/saxshell/saxs/ui/solute_volume_fraction_widget.py +++ b/src/saxshell/saxs/ui/solute_volume_fraction_widget.py @@ -1,667 +1,27 @@ from __future__ import annotations -from PySide6.QtCore import Qt, Signal -from PySide6.QtWidgets import ( - QComboBox, - QDoubleSpinBox, - QFormLayout, - QGridLayout, - QGroupBox, - QHBoxLayout, - QLabel, - QLineEdit, - QPushButton, - QStackedWidget, - QTextEdit, - QToolButton, - QVBoxLayout, - QWidget, +from saxshell.saxs.ui.solution_scattering_widget import ( + SOLUTE_VOLUME_FRACTION_CITATION_URL, + SOLUTE_VOLUME_FRACTION_HELP_TEXT, + SOLUTION_SCATTERING_HELP_TEXT, + AttenuationEstimateToolWindow, + FluorescenceEstimateToolWindow, + NumberDensityEstimateToolWindow, + SoluteVolumeFractionToolWindow, + SoluteVolumeFractionWidget, + SolutionScatteringEstimatorWidget, + SolutionScatteringToolWindow, ) -from saxshell.fullrmc.solution_properties import ( - SolutionPropertiesSettings, - solution_properties_mode_hint_text, -) -from saxshell.fullrmc.solution_property_presets import ( - SolutionPropertiesPreset, - load_solution_property_presets, - ordered_solution_property_preset_names, - solution_property_presets_path, -) -from saxshell.saxs.solute_volume_fraction import ( - SoluteVolumeFractionEstimate, - SoluteVolumeFractionSettings, - calculate_solute_volume_fraction_estimate, -) - -SOLUTION_MODE_ITEMS = ( - ("Masses", "mass"), - ("Mass Percent", "mass_percent"), - ("Molarity (per liter)", "molarity_per_liter"), -) - -SOLUTE_VOLUME_FRACTION_HELP_TEXT = ( - "Solute volume fraction estimator\n\n" - "This calculator uses the current solution-composition model to recover " - "solute and solvent masses, then estimates the solute fraction in the " - "measured solution volume using a SAXS-style concentration x specific-" - "volume relation:\n\n" - "c_solute = m_solute / V_solution\n" - "vbar_solute ~= 1 / rho_solute\n" - "phi_solute ~= c_solute * vbar_solute\n" - " = V_solute / V_solution\n\n" - "The widget still reports additive component volumes as a diagnostic, " - "but the main fitted fraction now uses the measured solution volume in " - "the denominator rather than V_solute + V_solvent.\n\n" - "Additive-volume check:\n" - "V_solute = m_solute / rho_solute\n" - "V_solvent = m_solvent / rho_solvent\n" - "V_additive = V_solute + V_solvent\n\n" - "The solution presets come from the fullrmc Solution Properties tool. " - "Those presets populate the composition fields, but the relevant pure-" - "component densities should still be reviewed for this estimate. In " - "molarity mode, SAXSShell uses solution-density plus solvent-density " - "closure, so solvent density stays active and solute density is not " - "required.\n\n" - "Citation link:\n" - "Hajizadeh et al. (2018), concentration-dependent SAXS mass estimates " - "require calibrated intensity, accurate solute concentration, and partial " - "specific volume.\n" - "https://www.nature.com/articles/s41598-018-25355-2" -) - -SOLUTE_VOLUME_FRACTION_CITATION_URL = ( - "https://www.nature.com/articles/s41598-018-25355-2" -) - - -class SoluteVolumeFractionWidget(QWidget): - estimate_calculated = Signal(object) - estimate_failed = Signal(str) - - def __init__(self, parent: QWidget | None = None) -> None: - super().__init__(parent) - self._solution_presets: dict[str, SolutionPropertiesPreset] = {} - self._updating_solution_preset_selection = False - self._current_estimate: SoluteVolumeFractionEstimate | None = None - self._build_ui() - self._reload_presets() - - def _build_ui(self) -> None: - layout = QVBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.setSpacing(8) - - self.target_label = QLabel( - "This calculator is not currently linked to a Prefit " - "solute/solvent fraction parameter." - ) - self.target_label.setWordWrap(True) - layout.addWidget(self.target_label) - - preset_group = QGroupBox("Solution Presets") - preset_layout = QVBoxLayout(preset_group) - preset_row = QHBoxLayout() - self.solution_preset_combo = QComboBox() - preset_row.addWidget(self.solution_preset_combo, stretch=1) - self.load_solution_preset_button = QPushButton("Load") - self.load_solution_preset_button.clicked.connect( - self._load_selected_solution_preset - ) - preset_row.addWidget(self.load_solution_preset_button) - preset_layout.addLayout(preset_row) - preset_hint = QLabel( - "These presets reuse the fullrmc Solution Properties inputs. " - "The density fields relevant to the selected estimator mode " - "remain editable below.\n" - f"Preset file: {solution_property_presets_path()}" - ) - preset_hint.setWordWrap(True) - preset_layout.addWidget(preset_hint) - layout.addWidget(preset_group) - - self.solution_mode_combo = QComboBox() - for label, value in SOLUTION_MODE_ITEMS: - self.solution_mode_combo.addItem(label, userData=value) - self.solution_mode_combo.currentIndexChanged.connect( - self._on_solution_mode_changed - ) - self.solution_mode_combo.currentIndexChanged.connect( - self._on_solution_settings_changed - ) - - self.solution_density_spin = self._new_float_spin( - maximum=100.0, - step=0.01, - decimals=6, - value=1.0, - ) - self.solution_density_spin.valueChanged.connect( - self._on_solution_settings_changed - ) - - self.solute_stoich_edit = QLineEdit() - self.solute_stoich_edit.setPlaceholderText("e.g. Cs1Pb1I3") - self.solute_stoich_edit.textChanged.connect( - self._on_solution_settings_changed - ) - - self.solvent_stoich_edit = QLineEdit() - self.solvent_stoich_edit.setPlaceholderText("e.g. H2O or C3H7NO") - self.solvent_stoich_edit.textChanged.connect( - self._on_solution_settings_changed - ) - - self.molar_mass_solute_spin = self._new_float_spin( - maximum=1_000_000.0, - step=1.0, - decimals=6, - value=0.0, - ) - self.molar_mass_solute_spin.valueChanged.connect( - self._on_solution_settings_changed - ) - - self.molar_mass_solvent_spin = self._new_float_spin( - maximum=1_000_000.0, - step=1.0, - decimals=6, - value=0.0, - ) - self.molar_mass_solvent_spin.valueChanged.connect( - self._on_solution_settings_changed - ) - - self.solute_density_spin = self._new_float_spin( - maximum=100.0, - step=0.01, - decimals=6, - value=1.0, - ) - self.solute_density_spin.valueChanged.connect( - self._on_solution_settings_changed - ) - - self.solvent_density_spin = self._new_float_spin( - maximum=100.0, - step=0.01, - decimals=6, - value=1.0, - ) - self.solvent_density_spin.valueChanged.connect( - self._on_solution_settings_changed - ) - - fields_layout = QGridLayout() - fields_layout.setColumnStretch(1, 1) - fields_layout.setColumnStretch(3, 1) - self.solution_mode_label = QLabel("Input mode") - fields_layout.addWidget(self.solution_mode_label, 0, 0) - fields_layout.addWidget(self.solution_mode_combo, 0, 1) - self.solution_density_label = QLabel("Solution density (g/mL)") - fields_layout.addWidget(self.solution_density_label, 0, 2) - fields_layout.addWidget(self.solution_density_spin, 0, 3) - self.solute_stoich_label = QLabel("Solute stoichiometry") - fields_layout.addWidget(self.solute_stoich_label, 1, 0) - fields_layout.addWidget(self.solute_stoich_edit, 1, 1) - self.solvent_stoich_label = QLabel("Solvent stoichiometry") - fields_layout.addWidget(self.solvent_stoich_label, 1, 2) - fields_layout.addWidget(self.solvent_stoich_edit, 1, 3) - self.molar_mass_solute_label = QLabel("Solute molar mass (g/mol)") - fields_layout.addWidget(self.molar_mass_solute_label, 2, 0) - fields_layout.addWidget(self.molar_mass_solute_spin, 2, 1) - self.molar_mass_solvent_label = QLabel("Solvent molar mass (g/mol)") - fields_layout.addWidget(self.molar_mass_solvent_label, 2, 2) - fields_layout.addWidget(self.molar_mass_solvent_spin, 2, 3) - self.solute_density_label = QLabel("Solute density (g/mL)") - fields_layout.addWidget(self.solute_density_label, 3, 0) - fields_layout.addWidget(self.solute_density_spin, 3, 1) - self.solvent_density_label = QLabel("Solvent density (g/mL)") - fields_layout.addWidget(self.solvent_density_label, 3, 2) - fields_layout.addWidget(self.solvent_density_spin, 3, 3) - layout.addLayout(fields_layout) - self.solution_mode_hint_label = QLabel() - self.solution_mode_hint_label.setWordWrap(True) - layout.addWidget(self.solution_mode_hint_label) - - self.solution_mode_stack = QStackedWidget() - - mass_page = QWidget() - mass_form = QFormLayout(mass_page) - self.mass_solute_spin = self._new_float_spin( - maximum=1_000_000.0, - step=0.1, - decimals=6, - value=0.0, - ) - self.mass_solute_spin.valueChanged.connect( - self._on_solution_settings_changed - ) - mass_form.addRow("Mass solute (g)", self.mass_solute_spin) - self.mass_solvent_spin = self._new_float_spin( - maximum=1_000_000.0, - step=0.1, - decimals=6, - value=0.0, - ) - self.mass_solvent_spin.valueChanged.connect( - self._on_solution_settings_changed - ) - mass_form.addRow("Mass solvent (g)", self.mass_solvent_spin) - self.solution_mode_stack.addWidget(mass_page) - - mass_percent_page = QWidget() - mass_percent_form = QFormLayout(mass_percent_page) - self.mass_percent_solute_spin = self._new_float_spin( - maximum=100.0, - step=0.1, - decimals=4, - value=0.0, - ) - self.mass_percent_solute_spin.valueChanged.connect( - self._on_solution_settings_changed - ) - mass_percent_form.addRow( - "Mass percent solute (%)", - self.mass_percent_solute_spin, - ) - self.total_mass_solution_spin = self._new_float_spin( - maximum=1_000_000.0, - step=0.1, - decimals=6, - value=0.0, - ) - self.total_mass_solution_spin.valueChanged.connect( - self._on_solution_settings_changed - ) - mass_percent_form.addRow( - "Total solution mass (g)", - self.total_mass_solution_spin, - ) - self.solution_mode_stack.addWidget(mass_percent_page) - - molarity_page = QWidget() - molarity_form = QFormLayout(molarity_page) - self.molarity_spin = self._new_float_spin( - maximum=100_000.0, - step=0.01, - decimals=6, - value=0.0, - ) - self.molarity_spin.valueChanged.connect( - self._on_solution_settings_changed - ) - molarity_form.addRow("Molarity (mol/L)", self.molarity_spin) - self.molarity_element_edit = QLineEdit() - self.molarity_element_edit.setPlaceholderText("e.g. Pb") - self.molarity_element_edit.textChanged.connect( - self._on_solution_settings_changed - ) - molarity_form.addRow( - "Molarity element", - self.molarity_element_edit, - ) - self.solution_mode_stack.addWidget(molarity_page) - - layout.addWidget(self.solution_mode_stack) - - button_row = QHBoxLayout() - self.calculate_button = QPushButton("Calculate Volume Fraction") - self.calculate_button.clicked.connect(self._calculate_estimate) - button_row.addWidget(self.calculate_button) - button_row.addStretch(1) - layout.addLayout(button_row) - - output_header = QHBoxLayout() - self.output_toggle_button = QToolButton() - self.output_toggle_button.setToolButtonStyle( - Qt.ToolButtonStyle.ToolButtonTextBesideIcon - ) - self.output_toggle_button.setAutoRaise(True) - self.output_toggle_button.clicked.connect( - self._toggle_output_collapsed - ) - output_header.addWidget(self.output_toggle_button) - output_header.addStretch(1) - layout.addLayout(output_header) - - self.output_box = QTextEdit() - self.output_box.setReadOnly(True) - self.output_box.setMinimumHeight(180) - layout.addWidget(self.output_box) - self.set_output_collapsed(True) - self._update_solution_mode_widgets() - - @staticmethod - def _new_float_spin( - *, - maximum: float, - step: float, - decimals: int, - value: float, - ) -> QDoubleSpinBox: - spin = QDoubleSpinBox() - spin.setRange(0.0, maximum) - spin.setDecimals(decimals) - spin.setSingleStep(step) - spin.setValue(value) - return spin - - def set_target_parameter( - self, - parameter_name: str | None, - fraction_kind: str | None, - ) -> None: - if parameter_name and fraction_kind: - label = "solute" if fraction_kind == "solute" else "solvent" - self.target_label.setText( - "Active Prefit target: " - f"{parameter_name} ({label} volume fraction)." - ) - else: - self.target_label.setText( - "This calculator is not currently linked to a Prefit " - "solute/solvent fraction parameter." - ) - - def append_application_note(self, message: str) -> None: - self.set_output_collapsed(False) - text = self.output_box.toPlainText().strip() - if not text: - self.output_box.setPlainText(message.strip()) - return - self.output_box.setPlainText(text + "\n\n" + message.strip()) - - def current_estimate(self) -> SoluteVolumeFractionEstimate | None: - return self._current_estimate - - def _reload_presets(self, *, selected_name: str | None = None) -> None: - previous_name = selected_name or self._selected_solution_preset_name() - self._solution_presets = load_solution_property_presets() - self.solution_preset_combo.blockSignals(True) - self.solution_preset_combo.clear() - self.solution_preset_combo.addItem("Current values", None) - selected_index = 0 - for index, name in enumerate( - ordered_solution_property_preset_names(self._solution_presets), - start=1, - ): - preset = self._solution_presets[name] - label = f"{name} (Built-in)" if preset.builtin else name - self.solution_preset_combo.addItem(label, name) - if name == previous_name: - selected_index = index - self.solution_preset_combo.setCurrentIndex(selected_index) - self.solution_preset_combo.blockSignals(False) - - def _selected_solution_preset_name(self) -> str | None: - payload = self.solution_preset_combo.currentData() - if payload is None: - return None - return str(payload) - - def _load_selected_solution_preset(self) -> None: - preset_name = self._selected_solution_preset_name() - if preset_name is None: - self.append_application_note("Select a solution preset to load.") - return - preset = self._solution_presets.get(preset_name) - if preset is None: - self.append_application_note( - f"Unknown solution preset: {preset_name}" - ) - return - self._apply_solution_preset(preset) - self._select_solution_preset_name(preset.name) - - def _apply_solution_preset( - self, - preset: SolutionPropertiesPreset, - ) -> None: - self._apply_solution_settings(preset.settings) - if preset.solute_density_g_per_ml is not None: - self.solute_density_spin.setValue(preset.solute_density_g_per_ml) - if preset.solvent_density_g_per_ml is not None: - self.solvent_density_spin.setValue(preset.solvent_density_g_per_ml) - - def _apply_solution_settings( - self, - settings: SolutionPropertiesSettings, - ) -> None: - previous_updating = self._updating_solution_preset_selection - self._updating_solution_preset_selection = True - try: - self._set_combo_value(self.solution_mode_combo, settings.mode) - self.solution_density_spin.setValue(settings.solution_density) - self.solute_stoich_edit.setText(settings.solute_stoich) - self.solvent_stoich_edit.setText(settings.solvent_stoich) - self.molar_mass_solute_spin.setValue(settings.molar_mass_solute) - self.molar_mass_solvent_spin.setValue(settings.molar_mass_solvent) - self.mass_solute_spin.setValue(settings.mass_solute) - self.mass_solvent_spin.setValue(settings.mass_solvent) - self.mass_percent_solute_spin.setValue( - settings.mass_percent_solute - ) - self.total_mass_solution_spin.setValue( - settings.total_mass_solution - ) - self.molarity_spin.setValue(settings.molarity) - self.molarity_element_edit.setText(settings.molarity_element) - self._update_solution_mode_widgets() - finally: - self._updating_solution_preset_selection = previous_updating - - @staticmethod - def _set_combo_value(combo: QComboBox, value: str) -> None: - index = combo.findData(value) - if index >= 0: - combo.setCurrentIndex(index) - - def _selected_solution_mode(self) -> str: - return str(self.solution_mode_combo.currentData() or "mass") - - def _current_solution_settings(self) -> SolutionPropertiesSettings: - return SolutionPropertiesSettings( - mode=self._selected_solution_mode(), - solution_density=float(self.solution_density_spin.value()), - solute_stoich=self.solute_stoich_edit.text().strip(), - solvent_stoich=self.solvent_stoich_edit.text().strip(), - molar_mass_solute=float(self.molar_mass_solute_spin.value()), - molar_mass_solvent=float(self.molar_mass_solvent_spin.value()), - mass_solute=float(self.mass_solute_spin.value()), - mass_solvent=float(self.mass_solvent_spin.value()), - mass_percent_solute=float(self.mass_percent_solute_spin.value()), - total_mass_solution=float(self.total_mass_solution_spin.value()), - molarity=float(self.molarity_spin.value()), - molarity_element=self.molarity_element_edit.text().strip(), - ) - - def current_estimator_settings(self) -> SoluteVolumeFractionSettings: - mode = self._selected_solution_mode() - return SoluteVolumeFractionSettings( - solution=self._current_solution_settings(), - solute_density_g_per_ml=( - None - if mode == "molarity_per_liter" - else float(self.solute_density_spin.value()) - ), - solvent_density_g_per_ml=float(self.solvent_density_spin.value()), - ) - - def _on_solution_mode_changed(self) -> None: - self._update_solution_mode_widgets() - - def _update_solution_mode_widgets(self) -> None: - selected_mode = self._selected_solution_mode() - mode_to_index = { - "mass": 0, - "mass_percent": 1, - "molarity_per_liter": 2, - } - self.solution_mode_stack.setCurrentIndex( - mode_to_index.get(selected_mode, 0) - ) - show_solute_density = selected_mode != "molarity_per_liter" - show_solvent_density = True - self.solute_density_label.setVisible(show_solute_density) - self.solute_density_spin.setVisible(show_solute_density) - self.solvent_density_label.setVisible(show_solvent_density) - self.solvent_density_spin.setVisible(show_solvent_density) - self.solution_mode_hint_label.setText( - self._estimator_mode_hint_text(selected_mode) - ) - - @staticmethod - def _estimator_mode_hint_text(mode: str) -> str: - base = solution_properties_mode_hint_text(mode) - if mode == "molarity_per_liter": - return ( - f"{base} For this volume-fraction estimate, molarity mode " - "uses solvent-density closure: solvent density stays visible " - "so SAXSShell can estimate V_solvent = m_solvent / " - "rho_solvent and then V_solute ~= V_solution - V_solvent. " - "Solute density is hidden in molarity mode." - ) - return ( - f"{base} In these modes, both pure-component densities remain " - "visible so the additive solute and solvent volumes can be " - "reported alongside the fitted fraction estimate." - ) - - def _on_solution_settings_changed(self, *_args: object) -> None: - if self._updating_solution_preset_selection: - return - self._select_solution_preset_name( - self._matching_solution_preset_name( - self._current_solution_settings() - ) - ) - - def _select_solution_preset_name(self, preset_name: str | None) -> None: - target_index = 0 - if preset_name is not None: - for index in range(self.solution_preset_combo.count()): - if self.solution_preset_combo.itemData(index) == preset_name: - target_index = index - break - previous_updating = self._updating_solution_preset_selection - self._updating_solution_preset_selection = True - try: - self.solution_preset_combo.setCurrentIndex(target_index) - finally: - self._updating_solution_preset_selection = previous_updating - - def _matching_solution_preset_name( - self, - settings: SolutionPropertiesSettings, - ) -> str | None: - for name in ordered_solution_property_preset_names( - self._solution_presets - ): - preset = self._solution_presets.get(name) - if preset is None: - continue - if self._solution_settings_match(settings, preset.settings): - return name - return None - - @staticmethod - def _solution_settings_match( - left: SolutionPropertiesSettings, - right: SolutionPropertiesSettings, - ) -> bool: - float_fields = ( - "solution_density", - "molar_mass_solute", - "molar_mass_solvent", - "mass_solute", - "mass_solvent", - "mass_percent_solute", - "total_mass_solution", - "molarity", - ) - text_fields = ( - "mode", - "solute_stoich", - "solvent_stoich", - "molarity_element", - ) - for field_name in float_fields: - if ( - abs( - float(getattr(left, field_name)) - - float(getattr(right, field_name)) - ) - > 1e-9 - ): - return False - for field_name in text_fields: - if str(getattr(left, field_name)) != str( - getattr(right, field_name) - ): - return False - return True - - def _calculate_estimate(self) -> None: - try: - estimate = calculate_solute_volume_fraction_estimate( - self.current_estimator_settings() - ) - except Exception as exc: - message = f"Unable to estimate the volume fraction: {exc}" - self.set_output_collapsed(False) - self.output_box.setPlainText(message) - self._current_estimate = None - self.estimate_failed.emit(str(exc)) - return - self._current_estimate = estimate - self.set_output_collapsed(False) - self.output_box.setPlainText(estimate.summary_text()) - self.estimate_calculated.emit(estimate) - - def output_is_collapsed(self) -> bool: - return self.output_box.isHidden() - - def set_output_collapsed(self, collapsed: bool) -> None: - is_collapsed = bool(collapsed) - self.output_box.setVisible(not is_collapsed) - self.output_toggle_button.setArrowType( - Qt.ArrowType.RightArrow if is_collapsed else Qt.ArrowType.DownArrow - ) - self.output_toggle_button.setText( - "Show Output" if is_collapsed else "Hide Output" - ) - - def _toggle_output_collapsed(self) -> None: - self.set_output_collapsed(not self.output_box.isHidden()) - - -class SoluteVolumeFractionToolWindow(QWidget): - def __init__(self, parent: QWidget | None = None) -> None: - super().__init__(parent) - self.setWindowTitle("Volume Fraction Estimate") - self._build_ui() - - def _build_ui(self) -> None: - layout = QVBoxLayout(self) - citation_label = QLabel( - "Estimate the physical solute or solvent volume fraction from " - "solution composition plus pure-component densities using a " - "SAXS-style concentration x specific-volume estimate. " - f'' - "Citation: Hajizadeh et al. (2018)" - ) - citation_label.setWordWrap(True) - citation_label.setOpenExternalLinks(True) - layout.addWidget(citation_label) - self.estimator_widget = SoluteVolumeFractionWidget(self) - layout.addWidget(self.estimator_widget) - self.resize(720, 760) - - __all__ = [ + "AttenuationEstimateToolWindow", + "FluorescenceEstimateToolWindow", + "NumberDensityEstimateToolWindow", "SOLUTE_VOLUME_FRACTION_CITATION_URL", "SOLUTE_VOLUME_FRACTION_HELP_TEXT", + "SOLUTION_SCATTERING_HELP_TEXT", + "SolutionScatteringEstimatorWidget", + "SolutionScatteringToolWindow", "SoluteVolumeFractionToolWindow", "SoluteVolumeFractionWidget", ] diff --git a/src/saxshell/saxs/ui/solution_scattering_widget.py b/src/saxshell/saxs/ui/solution_scattering_widget.py new file mode 100644 index 0000000..c20d3e0 --- /dev/null +++ b/src/saxshell/saxs/ui/solution_scattering_widget.py @@ -0,0 +1,1354 @@ +from __future__ import annotations + +from PySide6.QtCore import QSize, Qt, Signal +from PySide6.QtWidgets import ( + QApplication, + QCheckBox, + QComboBox, + QDialog, + QDialogButtonBox, + QDoubleSpinBox, + QFormLayout, + QFrame, + QGridLayout, + QGroupBox, + QHBoxLayout, + QInputDialog, + QLabel, + QLineEdit, + QMessageBox, + QPushButton, + QScrollArea, + QStackedWidget, + QTextEdit, + QToolButton, + QVBoxLayout, + QWidget, +) + +from saxshell.fullrmc.solution_properties import ( + SolutionPropertiesSettings, + solution_properties_mode_hint_text, +) +from saxshell.fullrmc.solution_property_presets import ( + SolutionPropertiesPreset, + load_solution_property_presets, + ordered_solution_property_preset_names, +) +from saxshell.saxs.beam_geometry_presets import ( + DEFAULT_BEAM_GEOMETRY_PRESET_NAME, + BeamGeometryPreset, + delete_custom_beam_geometry_preset, + load_beam_geometry_presets, + ordered_beam_geometry_preset_names, + save_custom_beam_geometry_preset, +) +from saxshell.saxs.solution_scattering_estimator import ( + BEAM_PROFILE_ITEMS, + CAPILLARY_GEOMETRY_ITEMS, + DEFAULT_BEAM_FOOTPRINT_HEIGHT_MM, + DEFAULT_BEAM_FOOTPRINT_WIDTH_MM, + DEFAULT_BEAM_PROFILE, + DEFAULT_CAPILLARY_GEOMETRY, + DEFAULT_CAPILLARY_SIZE_MM, + DEFAULT_INCIDENT_ENERGY_KEV, + BeamGeometrySettings, + SolutionScatteringEstimate, + SolutionScatteringEstimatorSettings, + calculate_solution_scattering_estimate, + wavelength_angstrom_from_energy_kev, +) + +SOLUTION_MODE_ITEMS = ( + ("Masses", "mass"), + ("Mass Percent", "mass_percent"), + ("Molarity (per liter)", "molarity_per_liter"), +) + +SOLUTE_VOLUME_FRACTION_CITATION_URL = ( + "https://www.nature.com/articles/s41598-018-25355-2" +) +XRAYDB_REFERENCE_URL = "https://scikit-beam.github.io/XrayDB/python.html" +NIST_ATTENUATION_REFERENCE_URL = "https://doi.org/10.6028/NBS.NSRDS.29" +XRF_FORWARD_MODEL_REFERENCE_URL = ( + "https://pmc.ncbi.nlm.nih.gov/articles/PMC12871215/" +) +SELF_ABSORPTION_REFERENCE_URL = ( + "https://pmc.ncbi.nlm.nih.gov/articles/PMC6608621/" +) + +SOLUTION_SCATTERING_HELP_TEXT = ( + "Solution scattering estimators\n\n" + "This widget combines five related calculations that all start from the " + "same solution-composition inputs.\n\n" + "1. Number density\n" + " n = N_atoms / V_solution\n" + " SAXSShell reports the result in atoms/A^3.\n\n" + "2. Physical solute-associated volume fraction\n" + " phi_phys ~= c_solute * vbar_solute\n" + " = (m_solute / V_solution) * (1 / rho_solute)\n" + " The physical solvent-associated fraction is reported as\n" + " 1 - phi_phys. This bulk-density estimate stays in the output\n" + " console for reference.\n\n" + "3. SAXS-effective interaction contrast ratio at energy E\n" + " rho_eff(E) = rho_mass * N_A / M * sum_i n_i [Z_i + f'_i(E)]\n" + " C(E) = ((rho_eff,solute(E) - rho_eff,solvent(E))\n" + " / rho_eff,solvent(E))^2\n" + " V_eff,solute(E) = C(E) * V_solute,phys\n" + " R_saxs(E) = V_eff,solute(E)\n" + " / (V_eff,solute(E) + V_solvent,phys)\n" + " This contrast-weighted ratio is the default model-facing\n" + " solute fraction for phi_solute / phi_solvent.\n\n" + "4. Attenuation and solvent contribution scaling\n" + " mu(E) ~= c_solute * (mu/rho)_solute(E)\n" + " + c_solvent * (mu/rho)_solvent(E)\n" + " T(E, L) = exp(-mu(E) * L)\n" + " For SAXS transmission geometry, SAXSShell estimates the solvent " + " scattering scale factor from the ratio of beam-profile-averaged " + " L * exp(-mu * L) terms for the solvent in the sample versus the " + " neat solvent reference. If a template only exposes a single\n" + " solvent-weight parameter, SAXSShell recommends\n" + " w_model = (1 - R_saxs(E)) * w_att.\n\n" + "5. Fluorescence background proxy\n" + " SAXSShell estimates primary fluorescence from element-resolved " + " photoelectric attenuation, edge jump-ratio partitioning, " + " fluorescence yields, and line branching. A first-order secondary " + " fluorescence pass is then added from re-absorption of the primary " + " fluorescent lines inside the sample. This is a screening estimate, " + " not a full Monte Carlo transport calculation.\n\n" + "Key assumptions in the current implementation:\n" + "- the beam profile is uniform\n" + "- the beam footprint is centered on the capillary\n" + "- cylindrical capillaries are treated as transmission through a round " + "cross-section, so the footprint width controls the path-length average\n" + "- fluorescence escape is modeled with a first-order self-absorption / " + "secondary-emission approximation\n\n" + "References:\n" + f"- Hajizadeh et al. (2018): {SOLUTE_VOLUME_FRACTION_CITATION_URL}\n" + f"- Hubbell / NIST attenuation reference: {NIST_ATTENUATION_REFERENCE_URL}\n" + f"- XrayDB Python reference and Elam-based atomic data: {XRAYDB_REFERENCE_URL}\n" + f"- Roter et al. XRF forward-model discussion and jump-ratio considerations: {XRF_FORWARD_MODEL_REFERENCE_URL}\n" + f"- Trevorah et al. self-absorption discussion: {SELF_ABSORPTION_REFERENCE_URL}" +) + +SOLUTE_VOLUME_FRACTION_HELP_TEXT = SOLUTION_SCATTERING_HELP_TEXT + + +class BeamEnergyWavelengthDialog(QDialog): + def __init__( + self, + parent: QWidget | None = None, + *, + energy_kev: float, + ) -> None: + super().__init__(parent) + self.setWindowTitle("Beam Energy and Wavelength") + layout = QVBoxLayout(self) + explanation = QLabel( + "The X-ray wavelength is computed from the incident energy using " + "lambda (Å) = 12.3984198433 / E(keV)." + ) + explanation.setWordWrap(True) + layout.addWidget(explanation) + + form_layout = QFormLayout() + self.energy_value_label = QLabel() + self.wavelength_value_label = QLabel() + form_layout.addRow("Energy (keV)", self.energy_value_label) + form_layout.addRow("Wavelength (Å)", self.wavelength_value_label) + layout.addLayout(form_layout) + + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Close) + buttons.rejected.connect(self.close) + buttons.accepted.connect(self.close) + layout.addWidget(buttons) + + self.set_energy_kev(energy_kev) + self.resize(340, 140) + + def set_energy_kev(self, energy_kev: float) -> None: + wavelength = wavelength_angstrom_from_energy_kev(energy_kev) + self.energy_value_label.setText(f"{float(energy_kev):.12g}") + self.wavelength_value_label.setText(f"{float(wavelength):.12g}") + + +class SolutionScatteringEstimatorWidget(QWidget): + estimate_calculated = Signal(object) + estimate_failed = Signal(str) + + def __init__( + self, + parent: QWidget | None = None, + *, + default_number_density: bool = True, + default_volume_fraction: bool = True, + default_attenuation: bool = True, + default_fluorescence: bool = False, + ) -> None: + super().__init__(parent) + self._solution_presets: dict[str, SolutionPropertiesPreset] = {} + self._beam_presets: dict[str, BeamGeometryPreset] = {} + self._updating_solution_preset_selection = False + self._updating_beam_preset_selection = False + self._current_estimate: SolutionScatteringEstimate | None = None + self._wavelength_dialog: BeamEnergyWavelengthDialog | None = None + self._default_number_density = bool(default_number_density) + self._default_volume_fraction = bool(default_volume_fraction) + self._default_attenuation = bool(default_attenuation) + self._default_fluorescence = bool(default_fluorescence) + self._build_ui() + self._reload_solution_presets() + self._reload_beam_presets( + selected_name=DEFAULT_BEAM_GEOMETRY_PRESET_NAME + ) + self._select_beam_preset_name(DEFAULT_BEAM_GEOMETRY_PRESET_NAME) + self._load_selected_beam_preset() + + def _build_ui(self) -> None: + root_layout = QVBoxLayout(self) + root_layout.setContentsMargins(0, 0, 0, 0) + root_layout.setSpacing(8) + + controls_widget = QWidget(self) + layout = QVBoxLayout(controls_widget) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(8) + + self.target_label = QLabel( + "This calculator is not currently linked to an automatic Prefit " + "parameter update." + ) + self.target_label.setWordWrap(True) + layout.addWidget(self.target_label) + + preset_group = QGroupBox("Solution Presets") + preset_layout = QVBoxLayout(preset_group) + preset_row = QHBoxLayout() + self.solution_preset_combo = QComboBox() + preset_row.addWidget(self.solution_preset_combo, stretch=1) + self.load_solution_preset_button = QPushButton("Load") + self.load_solution_preset_button.clicked.connect( + self._load_selected_solution_preset + ) + preset_row.addWidget(self.load_solution_preset_button) + preset_layout.addLayout(preset_row) + preset_hint = QLabel( + "These presets populate the composition inputs. Density and " + "beam/capillary settings remain editable below." + ) + preset_hint.setWordWrap(True) + preset_layout.addWidget(preset_hint) + layout.addWidget(preset_group) + + self.solution_mode_combo = QComboBox() + for label, value in SOLUTION_MODE_ITEMS: + self.solution_mode_combo.addItem(label, userData=value) + self.solution_mode_combo.currentIndexChanged.connect( + self._on_solution_mode_changed + ) + self.solution_mode_combo.currentIndexChanged.connect( + self._on_solution_settings_changed + ) + + self.solution_density_spin = self._new_float_spin( + minimum=0.0, + maximum=100.0, + step=0.01, + decimals=6, + value=1.0, + ) + self.solution_density_spin.valueChanged.connect( + self._on_solution_settings_changed + ) + + self.solute_stoich_edit = QLineEdit() + self.solute_stoich_edit.setPlaceholderText("e.g. Cs1Pb1I3") + self.solute_stoich_edit.textChanged.connect( + self._on_solution_settings_changed + ) + + self.solvent_stoich_edit = QLineEdit() + self.solvent_stoich_edit.setPlaceholderText("e.g. H2O or C3H7NO") + self.solvent_stoich_edit.textChanged.connect( + self._on_solution_settings_changed + ) + + self.molar_mass_solute_spin = self._new_float_spin( + minimum=0.0, + maximum=1_000_000.0, + step=1.0, + decimals=6, + value=0.0, + ) + self.molar_mass_solute_spin.valueChanged.connect( + self._on_solution_settings_changed + ) + + self.molar_mass_solvent_spin = self._new_float_spin( + minimum=0.0, + maximum=1_000_000.0, + step=1.0, + decimals=6, + value=0.0, + ) + self.molar_mass_solvent_spin.valueChanged.connect( + self._on_solution_settings_changed + ) + + self.solute_density_spin = self._new_float_spin( + minimum=0.0, + maximum=100.0, + step=0.01, + decimals=6, + value=1.0, + ) + self.solute_density_spin.valueChanged.connect( + self._on_solution_settings_changed + ) + + self.solvent_density_spin = self._new_float_spin( + minimum=0.0, + maximum=100.0, + step=0.01, + decimals=6, + value=1.0, + ) + self.solvent_density_spin.valueChanged.connect( + self._on_solution_settings_changed + ) + + fields_layout = QGridLayout() + fields_layout.setColumnStretch(1, 1) + fields_layout.setColumnStretch(3, 1) + self.solution_mode_label = QLabel("Input mode") + fields_layout.addWidget(self.solution_mode_label, 0, 0) + fields_layout.addWidget(self.solution_mode_combo, 0, 1) + self.solution_density_label = QLabel("Solution density (g/mL)") + fields_layout.addWidget(self.solution_density_label, 0, 2) + fields_layout.addWidget(self.solution_density_spin, 0, 3) + self.solute_stoich_label = QLabel("Solute stoichiometry") + fields_layout.addWidget(self.solute_stoich_label, 1, 0) + fields_layout.addWidget(self.solute_stoich_edit, 1, 1) + self.solvent_stoich_label = QLabel("Solvent stoichiometry") + fields_layout.addWidget(self.solvent_stoich_label, 1, 2) + fields_layout.addWidget(self.solvent_stoich_edit, 1, 3) + self.molar_mass_solute_label = QLabel("Solute molar mass (g/mol)") + fields_layout.addWidget(self.molar_mass_solute_label, 2, 0) + fields_layout.addWidget(self.molar_mass_solute_spin, 2, 1) + self.molar_mass_solvent_label = QLabel("Solvent molar mass (g/mol)") + fields_layout.addWidget(self.molar_mass_solvent_label, 2, 2) + fields_layout.addWidget(self.molar_mass_solvent_spin, 2, 3) + self.solute_density_label = QLabel("Solute density (g/mL)") + fields_layout.addWidget(self.solute_density_label, 3, 0) + fields_layout.addWidget(self.solute_density_spin, 3, 1) + self.solvent_density_label = QLabel("Solvent density (g/mL)") + fields_layout.addWidget(self.solvent_density_label, 3, 2) + fields_layout.addWidget(self.solvent_density_spin, 3, 3) + layout.addLayout(fields_layout) + + self.solution_mode_hint_label = QLabel() + self.solution_mode_hint_label.setWordWrap(True) + layout.addWidget(self.solution_mode_hint_label) + + self.solution_mode_stack = QStackedWidget() + + mass_page = QWidget() + mass_form = QFormLayout(mass_page) + self.mass_solute_spin = self._new_float_spin( + minimum=0.0, + maximum=1_000_000.0, + step=0.1, + decimals=6, + value=0.0, + ) + self.mass_solute_spin.valueChanged.connect( + self._on_solution_settings_changed + ) + mass_form.addRow("Mass solute (g)", self.mass_solute_spin) + self.mass_solvent_spin = self._new_float_spin( + minimum=0.0, + maximum=1_000_000.0, + step=0.1, + decimals=6, + value=0.0, + ) + self.mass_solvent_spin.valueChanged.connect( + self._on_solution_settings_changed + ) + mass_form.addRow("Mass solvent (g)", self.mass_solvent_spin) + self.solution_mode_stack.addWidget(mass_page) + + mass_percent_page = QWidget() + mass_percent_form = QFormLayout(mass_percent_page) + self.mass_percent_solute_spin = self._new_float_spin( + minimum=0.0, + maximum=100.0, + step=0.1, + decimals=4, + value=0.0, + ) + self.mass_percent_solute_spin.valueChanged.connect( + self._on_solution_settings_changed + ) + mass_percent_form.addRow( + "Mass percent solute (%)", + self.mass_percent_solute_spin, + ) + self.total_mass_solution_spin = self._new_float_spin( + minimum=0.0, + maximum=1_000_000.0, + step=0.1, + decimals=6, + value=0.0, + ) + self.total_mass_solution_spin.valueChanged.connect( + self._on_solution_settings_changed + ) + mass_percent_form.addRow( + "Total solution mass (g)", + self.total_mass_solution_spin, + ) + self.solution_mode_stack.addWidget(mass_percent_page) + + molarity_page = QWidget() + molarity_form = QFormLayout(molarity_page) + self.molarity_spin = self._new_float_spin( + minimum=0.0, + maximum=100_000.0, + step=0.01, + decimals=6, + value=0.0, + ) + self.molarity_spin.valueChanged.connect( + self._on_solution_settings_changed + ) + molarity_form.addRow("Molarity (mol/L)", self.molarity_spin) + self.molarity_element_edit = QLineEdit() + self.molarity_element_edit.setPlaceholderText("e.g. Pb") + self.molarity_element_edit.textChanged.connect( + self._on_solution_settings_changed + ) + molarity_form.addRow( + "Molarity element", + self.molarity_element_edit, + ) + self.solution_mode_stack.addWidget(molarity_page) + layout.addWidget(self.solution_mode_stack) + + calculations_group = QGroupBox("Calculations") + calculations_layout = QHBoxLayout(calculations_group) + self.calculate_number_density_checkbox = QCheckBox("Number Density") + self.calculate_number_density_checkbox.setChecked( + self._default_number_density + ) + self.calculate_volume_fraction_checkbox = QCheckBox( + "Solute Volume Fraction" + ) + self.calculate_volume_fraction_checkbox.setChecked( + self._default_volume_fraction + ) + self.calculate_attenuation_checkbox = QCheckBox( + "Solvent Scattering Contribution" + ) + self.calculate_attenuation_checkbox.setChecked( + self._default_attenuation + ) + self.calculate_fluorescence_checkbox = QCheckBox( + "Sample Fluorescence Yield" + ) + self.calculate_fluorescence_checkbox.setChecked( + self._default_fluorescence + ) + for checkbox in ( + self.calculate_number_density_checkbox, + self.calculate_volume_fraction_checkbox, + self.calculate_attenuation_checkbox, + self.calculate_fluorescence_checkbox, + ): + checkbox.toggled.connect(self._on_solution_settings_changed) + calculations_layout.addWidget(checkbox) + calculations_layout.addStretch(1) + layout.addWidget(calculations_group) + + beam_preset_group = QGroupBox("Beam and Capillary Presets") + beam_preset_layout = QVBoxLayout(beam_preset_group) + beam_preset_row = QHBoxLayout() + self.beam_preset_combo = QComboBox() + beam_preset_row.addWidget(self.beam_preset_combo, stretch=1) + self.load_beam_preset_button = QPushButton("Load") + self.load_beam_preset_button.clicked.connect( + self._load_selected_beam_preset + ) + beam_preset_row.addWidget(self.load_beam_preset_button) + self.save_beam_preset_button = QPushButton("Save Current") + self.save_beam_preset_button.clicked.connect( + self._save_current_beam_preset + ) + beam_preset_row.addWidget(self.save_beam_preset_button) + self.delete_beam_preset_button = QPushButton("Delete") + self.delete_beam_preset_button.clicked.connect( + self._delete_selected_beam_preset + ) + beam_preset_row.addWidget(self.delete_beam_preset_button) + beam_preset_layout.addLayout(beam_preset_row) + beam_preset_hint = QLabel( + "Presets include beam energy, capillary size, geometry, beam " + "profile, and beam footprint. Custom presets are saved to a " + "JSON file for reuse." + ) + beam_preset_hint.setWordWrap(True) + beam_preset_layout.addWidget(beam_preset_hint) + layout.addWidget(beam_preset_group) + + beam_group = QGroupBox("Beam and Capillary") + beam_layout = QGridLayout(beam_group) + beam_layout.setColumnStretch(1, 1) + beam_layout.setColumnStretch(3, 1) + self.incident_energy_spin = self._new_float_spin( + minimum=0.0, + maximum=200.0, + step=0.1, + decimals=6, + value=DEFAULT_INCIDENT_ENERGY_KEV, + ) + self.incident_energy_spin.valueChanged.connect( + self._on_beam_settings_changed + ) + self.incident_energy_spin.valueChanged.connect( + self._update_wavelength_dialog_energy + ) + beam_layout.addWidget(QLabel("X-ray energy (keV)"), 0, 0) + incident_energy_row = QHBoxLayout() + incident_energy_row.addWidget(self.incident_energy_spin, stretch=1) + self.show_wavelength_button = QPushButton("Wavelength...") + self.show_wavelength_button.clicked.connect( + self._show_wavelength_dialog + ) + incident_energy_row.addWidget(self.show_wavelength_button) + beam_layout.addLayout(incident_energy_row, 0, 1) + self.capillary_size_spin = self._new_float_spin( + minimum=0.0, + maximum=100.0, + step=0.1, + decimals=6, + value=DEFAULT_CAPILLARY_SIZE_MM, + ) + self.capillary_size_spin.valueChanged.connect( + self._on_beam_settings_changed + ) + beam_layout.addWidget(QLabel("Capillary size (mm)"), 0, 2) + beam_layout.addWidget(self.capillary_size_spin, 0, 3) + self.capillary_geometry_combo = QComboBox() + for label, value in CAPILLARY_GEOMETRY_ITEMS: + self.capillary_geometry_combo.addItem(label, userData=value) + self.capillary_geometry_combo.currentIndexChanged.connect( + self._on_beam_settings_changed + ) + beam_layout.addWidget(QLabel("Capillary geometry"), 1, 0) + beam_layout.addWidget(self.capillary_geometry_combo, 1, 1) + self.beam_profile_combo = QComboBox() + for label, value in BEAM_PROFILE_ITEMS: + self.beam_profile_combo.addItem(label, userData=value) + self.beam_profile_combo.currentIndexChanged.connect( + self._on_beam_settings_changed + ) + beam_layout.addWidget(QLabel("Beam profile"), 1, 2) + beam_layout.addWidget(self.beam_profile_combo, 1, 3) + self.beam_footprint_width_spin = self._new_float_spin( + minimum=0.0, + maximum=100.0, + step=0.1, + decimals=6, + value=DEFAULT_BEAM_FOOTPRINT_WIDTH_MM, + ) + self.beam_footprint_width_spin.valueChanged.connect( + self._on_beam_settings_changed + ) + beam_layout.addWidget(QLabel("Beam footprint width (mm)"), 2, 0) + beam_layout.addWidget(self.beam_footprint_width_spin, 2, 1) + self.beam_footprint_height_spin = self._new_float_spin( + minimum=0.0, + maximum=100.0, + step=0.1, + decimals=6, + value=DEFAULT_BEAM_FOOTPRINT_HEIGHT_MM, + ) + self.beam_footprint_height_spin.valueChanged.connect( + self._on_beam_settings_changed + ) + beam_layout.addWidget(QLabel("Beam footprint height (mm)"), 2, 2) + beam_layout.addWidget(self.beam_footprint_height_spin, 2, 3) + layout.addWidget(beam_group) + + button_row = QHBoxLayout() + self.calculate_button = QPushButton("Run Selected Calculations") + self.calculate_button.clicked.connect(self._calculate_estimate) + button_row.addWidget(self.calculate_button) + button_row.addStretch(1) + self.output_toggle_button = QToolButton() + self.output_toggle_button.setToolButtonStyle( + Qt.ToolButtonStyle.ToolButtonTextBesideIcon + ) + self.output_toggle_button.setAutoRaise(True) + self.output_toggle_button.clicked.connect( + self._toggle_output_collapsed + ) + button_row.addWidget(self.output_toggle_button) + layout.addLayout(button_row) + + self.output_panel = QWidget(controls_widget) + output_layout = QVBoxLayout(self.output_panel) + output_layout.setContentsMargins(0, 0, 0, 0) + output_layout.setSpacing(8) + self.output_title_label = QLabel("Calculation Output") + output_layout.addWidget(self.output_title_label) + self.output_box = QTextEdit() + self.output_box.setReadOnly(True) + self.output_box.setMinimumHeight(220) + self.output_box.setPlaceholderText( + "Run the selected calculations to populate this pane." + ) + output_layout.addWidget(self.output_box, stretch=1) + layout.addWidget(self.output_panel) + layout.addStretch(1) + + controls_scroll = QScrollArea(self) + controls_scroll.setWidgetResizable(True) + controls_scroll.setFrameShape(QFrame.Shape.NoFrame) + controls_scroll.setWidget(controls_widget) + root_layout.addWidget(controls_scroll) + + self.set_output_collapsed(True) + self._update_solution_mode_widgets() + + @staticmethod + def _new_float_spin( + *, + minimum: float, + maximum: float, + step: float, + decimals: int, + value: float, + ) -> QDoubleSpinBox: + spin = QDoubleSpinBox() + spin.setRange(float(minimum), float(maximum)) + spin.setDecimals(decimals) + spin.setSingleStep(step) + spin.setValue(value) + return spin + + def set_target_parameter( + self, + parameter_name: str | None, + fraction_kind: str | None, + solvent_weight_parameter: str | None = None, + ) -> None: + messages: list[str] = [] + if parameter_name and fraction_kind: + label = "solute" if fraction_kind == "solute" else "solvent" + messages.append( + f"{parameter_name} ({label} SAXS-effective interaction fraction)" + ) + if solvent_weight_parameter: + if parameter_name and fraction_kind: + messages.append( + f"{solvent_weight_parameter} (attenuation solvent scale)" + ) + else: + messages.append( + f"{solvent_weight_parameter} (combined solvent background multiplier)" + ) + if messages: + self.target_label.setText( + "Active Prefit targets: " + "; ".join(messages) + "." + ) + else: + self.target_label.setText( + "This calculator is not currently linked to an automatic " + "Prefit parameter update." + ) + + def set_calculation_selection( + self, + *, + number_density: bool | None = None, + volume_fraction: bool | None = None, + attenuation: bool | None = None, + fluorescence: bool | None = None, + ) -> None: + if number_density is not None: + self.calculate_number_density_checkbox.setChecked( + bool(number_density) + ) + if volume_fraction is not None: + self.calculate_volume_fraction_checkbox.setChecked( + bool(volume_fraction) + ) + if attenuation is not None: + self.calculate_attenuation_checkbox.setChecked(bool(attenuation)) + if fluorescence is not None: + self.calculate_fluorescence_checkbox.setChecked(bool(fluorescence)) + + def append_application_note(self, message: str) -> None: + self.set_output_collapsed(False) + text = self.output_box.toPlainText().strip() + if not text: + self.output_box.setPlainText(message.strip()) + return + self.output_box.setPlainText(text + "\n\n" + message.strip()) + + def current_estimate(self) -> SolutionScatteringEstimate | None: + return self._current_estimate + + def _reload_solution_presets( + self, + *, + selected_name: str | None = None, + ) -> None: + previous_name = selected_name or self._selected_solution_preset_name() + self._solution_presets = load_solution_property_presets() + self.solution_preset_combo.blockSignals(True) + self.solution_preset_combo.clear() + self.solution_preset_combo.addItem("Current values", None) + selected_index = 0 + for index, name in enumerate( + ordered_solution_property_preset_names(self._solution_presets), + start=1, + ): + preset = self._solution_presets[name] + label = f"{name} (Built-in)" if preset.builtin else name + self.solution_preset_combo.addItem(label, name) + if name == previous_name: + selected_index = index + self.solution_preset_combo.setCurrentIndex(selected_index) + self.solution_preset_combo.blockSignals(False) + + def _selected_solution_preset_name(self) -> str | None: + payload = self.solution_preset_combo.currentData() + if payload is None: + return None + return str(payload) + + def _load_selected_solution_preset(self) -> None: + preset_name = self._selected_solution_preset_name() + if preset_name is None: + self.append_application_note("Select a solution preset to load.") + return + preset = self._solution_presets.get(preset_name) + if preset is None: + self.append_application_note( + f"Unknown solution preset: {preset_name}" + ) + return + self._apply_solution_preset(preset) + self._select_solution_preset_name(preset.name) + + def _apply_solution_preset( + self, + preset: SolutionPropertiesPreset, + ) -> None: + self._apply_solution_settings(preset.settings) + if preset.solute_density_g_per_ml is not None: + self.solute_density_spin.setValue(preset.solute_density_g_per_ml) + if preset.solvent_density_g_per_ml is not None: + self.solvent_density_spin.setValue(preset.solvent_density_g_per_ml) + + def _apply_solution_settings( + self, + settings: SolutionPropertiesSettings, + ) -> None: + previous_updating = self._updating_solution_preset_selection + self._updating_solution_preset_selection = True + try: + self._set_combo_value(self.solution_mode_combo, settings.mode) + self.solution_density_spin.setValue(settings.solution_density) + self.solute_stoich_edit.setText(settings.solute_stoich) + self.solvent_stoich_edit.setText(settings.solvent_stoich) + self.molar_mass_solute_spin.setValue(settings.molar_mass_solute) + self.molar_mass_solvent_spin.setValue(settings.molar_mass_solvent) + self.mass_solute_spin.setValue(settings.mass_solute) + self.mass_solvent_spin.setValue(settings.mass_solvent) + self.mass_percent_solute_spin.setValue( + settings.mass_percent_solute + ) + self.total_mass_solution_spin.setValue( + settings.total_mass_solution + ) + self.molarity_spin.setValue(settings.molarity) + self.molarity_element_edit.setText(settings.molarity_element) + self._update_solution_mode_widgets() + finally: + self._updating_solution_preset_selection = previous_updating + + def _reload_beam_presets( + self, + *, + selected_name: str | None = None, + ) -> None: + previous_name = selected_name or self._selected_beam_preset_name() + self._beam_presets = load_beam_geometry_presets() + self.beam_preset_combo.blockSignals(True) + self.beam_preset_combo.clear() + self.beam_preset_combo.addItem("Current values", None) + selected_index = 0 + for index, name in enumerate( + ordered_beam_geometry_preset_names(self._beam_presets), + start=1, + ): + preset = self._beam_presets[name] + label = f"{name} (Built-in)" if preset.builtin else name + self.beam_preset_combo.addItem(label, name) + if name == previous_name: + selected_index = index + self.beam_preset_combo.setCurrentIndex(selected_index) + self.beam_preset_combo.blockSignals(False) + + def _selected_beam_preset_name(self) -> str | None: + payload = self.beam_preset_combo.currentData() + if payload is None: + return None + return str(payload) + + def _load_selected_beam_preset(self) -> None: + preset_name = self._selected_beam_preset_name() + if preset_name is None: + self.append_application_note("Select a beam preset to load.") + return + preset = self._beam_presets.get(preset_name) + if preset is None: + self.append_application_note(f"Unknown beam preset: {preset_name}") + return + self._apply_beam_preset(preset) + self._select_beam_preset_name(preset.name) + + def _save_current_beam_preset(self) -> None: + suggested_name = self._selected_beam_preset_name() or "" + preset_name, accepted = QInputDialog.getText( + self, + "Save Beam Preset", + ( + "Enter a name for this beam/capillary preset.\n" + "The preset stores energy, capillary size, geometry, beam " + "profile, and footprint." + ), + text=suggested_name, + ) + if not accepted: + return + normalized_name = preset_name.strip() + if not normalized_name: + self.append_application_note("Beam preset name cannot be empty.") + return + preset = BeamGeometryPreset( + name=normalized_name, + beam=self._current_beam_settings(), + ) + save_custom_beam_geometry_preset(preset) + self._reload_beam_presets(selected_name=normalized_name) + self._select_beam_preset_name(normalized_name) + self.append_application_note(f"Saved beam preset {normalized_name!r}.") + + def _delete_selected_beam_preset(self) -> None: + preset_name = self._selected_beam_preset_name() + if preset_name is None: + self.append_application_note( + "Select a custom beam preset to delete." + ) + return + preset = self._beam_presets.get(preset_name) + if preset is None: + self.append_application_note(f"Unknown beam preset: {preset_name}") + return + if preset.builtin: + QMessageBox.information( + self, + "Built-in beam preset", + ( + "Built-in beam presets cannot be deleted. Save a custom " + "preset with the same name if you want to override it." + ), + ) + return + if not delete_custom_beam_geometry_preset(preset_name): + self.append_application_note( + f"No custom beam preset named {preset_name!r} was found." + ) + return + self._reload_beam_presets( + selected_name=DEFAULT_BEAM_GEOMETRY_PRESET_NAME + ) + self.append_application_note(f"Deleted beam preset {preset_name!r}.") + + def _apply_beam_preset(self, preset: BeamGeometryPreset) -> None: + self._apply_beam_settings(preset.beam) + + def _apply_beam_settings(self, settings: BeamGeometrySettings) -> None: + previous_updating = self._updating_beam_preset_selection + self._updating_beam_preset_selection = True + try: + self.incident_energy_spin.setValue(settings.incident_energy_kev) + self.capillary_size_spin.setValue(settings.capillary_size_mm) + self._set_combo_value( + self.capillary_geometry_combo, + settings.capillary_geometry, + ) + self._set_combo_value( + self.beam_profile_combo, + settings.beam_profile, + ) + self.beam_footprint_width_spin.setValue( + settings.beam_footprint_width_mm + ) + self.beam_footprint_height_spin.setValue( + settings.beam_footprint_height_mm + ) + finally: + self._updating_beam_preset_selection = previous_updating + self._update_wavelength_dialog_energy() + + def _current_beam_settings(self) -> BeamGeometrySettings: + return BeamGeometrySettings( + incident_energy_kev=float(self.incident_energy_spin.value()), + capillary_size_mm=float(self.capillary_size_spin.value()), + capillary_geometry=str( + self.capillary_geometry_combo.currentData() + or DEFAULT_CAPILLARY_GEOMETRY + ), + beam_profile=str( + self.beam_profile_combo.currentData() or DEFAULT_BEAM_PROFILE + ), + beam_footprint_width_mm=float( + self.beam_footprint_width_spin.value() + ), + beam_footprint_height_mm=float( + self.beam_footprint_height_spin.value() + ), + ) + + @staticmethod + def _set_combo_value(combo: QComboBox, value: str) -> None: + index = combo.findData(value) + if index >= 0: + combo.setCurrentIndex(index) + + def _selected_solution_mode(self) -> str: + return str(self.solution_mode_combo.currentData() or "mass") + + def _current_solution_settings(self) -> SolutionPropertiesSettings: + return SolutionPropertiesSettings( + mode=self._selected_solution_mode(), + solution_density=float(self.solution_density_spin.value()), + solute_stoich=self.solute_stoich_edit.text().strip(), + solvent_stoich=self.solvent_stoich_edit.text().strip(), + molar_mass_solute=float(self.molar_mass_solute_spin.value()), + molar_mass_solvent=float(self.molar_mass_solvent_spin.value()), + mass_solute=float(self.mass_solute_spin.value()), + mass_solvent=float(self.mass_solvent_spin.value()), + mass_percent_solute=float(self.mass_percent_solute_spin.value()), + total_mass_solution=float(self.total_mass_solution_spin.value()), + molarity=float(self.molarity_spin.value()), + molarity_element=self.molarity_element_edit.text().strip(), + ) + + def current_estimator_settings( + self, + ) -> SolutionScatteringEstimatorSettings: + mode = self._selected_solution_mode() + return SolutionScatteringEstimatorSettings( + solution=self._current_solution_settings(), + solute_density_g_per_ml=( + None + if mode == "molarity_per_liter" + else float(self.solute_density_spin.value()) + ), + solvent_density_g_per_ml=float(self.solvent_density_spin.value()), + calculate_number_density=( + self.calculate_number_density_checkbox.isChecked() + ), + calculate_solute_volume_fraction=( + self.calculate_volume_fraction_checkbox.isChecked() + ), + calculate_solvent_scattering_contribution=( + self.calculate_attenuation_checkbox.isChecked() + ), + calculate_sample_fluorescence_yield=( + self.calculate_fluorescence_checkbox.isChecked() + ), + beam=self._current_beam_settings(), + ) + + def _on_solution_mode_changed(self) -> None: + self._update_solution_mode_widgets() + + def _update_solution_mode_widgets(self) -> None: + selected_mode = self._selected_solution_mode() + mode_to_index = { + "mass": 0, + "mass_percent": 1, + "molarity_per_liter": 2, + } + self.solution_mode_stack.setCurrentIndex( + mode_to_index.get(selected_mode, 0) + ) + show_solute_density = selected_mode != "molarity_per_liter" + self.solute_density_label.setVisible(show_solute_density) + self.solute_density_spin.setVisible(show_solute_density) + self.solvent_density_label.setVisible(True) + self.solvent_density_spin.setVisible(True) + self.solution_mode_hint_label.setText( + self._estimator_mode_hint_text(selected_mode) + ) + + @staticmethod + def _estimator_mode_hint_text(mode: str) -> str: + base = solution_properties_mode_hint_text(mode) + if mode == "molarity_per_liter": + return ( + f"{base} In molarity mode the solute density is hidden, but " + "the solvent density remains active for attenuation and " + "volume-closure calculations." + ) + return ( + f"{base} In these modes, both component densities remain " + "available for the volume-fraction and attenuation estimates." + ) + + def _on_solution_settings_changed(self, *_args: object) -> None: + if self._updating_solution_preset_selection: + return + self._select_solution_preset_name( + self._matching_solution_preset_name( + self._current_solution_settings() + ) + ) + + def _on_beam_settings_changed(self, *_args: object) -> None: + if self._updating_beam_preset_selection: + return + self._select_beam_preset_name( + self._matching_beam_preset_name(self._current_beam_settings()) + ) + + def _select_solution_preset_name(self, preset_name: str | None) -> None: + target_index = 0 + if preset_name is not None: + for index in range(self.solution_preset_combo.count()): + if self.solution_preset_combo.itemData(index) == preset_name: + target_index = index + break + previous_updating = self._updating_solution_preset_selection + self._updating_solution_preset_selection = True + try: + self.solution_preset_combo.setCurrentIndex(target_index) + finally: + self._updating_solution_preset_selection = previous_updating + + def _select_beam_preset_name(self, preset_name: str | None) -> None: + target_index = 0 + if preset_name is not None: + for index in range(self.beam_preset_combo.count()): + if self.beam_preset_combo.itemData(index) == preset_name: + target_index = index + break + previous_updating = self._updating_beam_preset_selection + self._updating_beam_preset_selection = True + try: + self.beam_preset_combo.setCurrentIndex(target_index) + finally: + self._updating_beam_preset_selection = previous_updating + + def _matching_solution_preset_name( + self, + settings: SolutionPropertiesSettings, + ) -> str | None: + for name in ordered_solution_property_preset_names( + self._solution_presets + ): + preset = self._solution_presets.get(name) + if preset is None: + continue + if self._solution_settings_match(settings, preset.settings): + return name + return None + + def _matching_beam_preset_name( + self, + settings: BeamGeometrySettings, + ) -> str | None: + for name in ordered_beam_geometry_preset_names(self._beam_presets): + preset = self._beam_presets.get(name) + if preset is None: + continue + if self._beam_settings_match(settings, preset.beam): + return name + return None + + @staticmethod + def _solution_settings_match( + left: SolutionPropertiesSettings, + right: SolutionPropertiesSettings, + ) -> bool: + float_fields = ( + "solution_density", + "molar_mass_solute", + "molar_mass_solvent", + "mass_solute", + "mass_solvent", + "mass_percent_solute", + "total_mass_solution", + "molarity", + ) + text_fields = ( + "mode", + "solute_stoich", + "solvent_stoich", + "molarity_element", + ) + for field_name in float_fields: + if ( + abs( + float(getattr(left, field_name)) + - float(getattr(right, field_name)) + ) + > 1e-9 + ): + return False + for field_name in text_fields: + if str(getattr(left, field_name)) != str( + getattr(right, field_name) + ): + return False + return True + + @staticmethod + def _beam_settings_match( + left: BeamGeometrySettings, + right: BeamGeometrySettings, + ) -> bool: + float_fields = ( + "incident_energy_kev", + "capillary_size_mm", + "beam_footprint_width_mm", + "beam_footprint_height_mm", + ) + text_fields = ( + "capillary_geometry", + "beam_profile", + ) + for field_name in float_fields: + if ( + abs( + float(getattr(left, field_name)) + - float(getattr(right, field_name)) + ) + > 1e-9 + ): + return False + for field_name in text_fields: + if str(getattr(left, field_name)) != str( + getattr(right, field_name) + ): + return False + return True + + def _show_wavelength_dialog(self) -> None: + if self._wavelength_dialog is None: + dialog = BeamEnergyWavelengthDialog( + self, + energy_kev=float(self.incident_energy_spin.value()), + ) + dialog.destroyed.connect(self._clear_wavelength_dialog) + self._wavelength_dialog = dialog + self._update_wavelength_dialog_energy() + self._wavelength_dialog.show() + self._wavelength_dialog.raise_() + self._wavelength_dialog.activateWindow() + + def _update_wavelength_dialog_energy(self, *_args: object) -> None: + if self._wavelength_dialog is None: + return + self._wavelength_dialog.set_energy_kev( + float(self.incident_energy_spin.value()) + ) + + def _clear_wavelength_dialog(self, *_args: object) -> None: + self._wavelength_dialog = None + + def _calculate_estimate(self) -> None: + settings = self.current_estimator_settings() + if not ( + settings.calculate_number_density + or settings.calculate_solute_volume_fraction + or settings.calculate_solvent_scattering_contribution + or settings.calculate_sample_fluorescence_yield + ): + message = "Select at least one calculation before running." + self.set_output_collapsed(False) + self.output_box.setPlainText(message) + self._current_estimate = None + self.estimate_failed.emit(message) + return + try: + estimate = calculate_solution_scattering_estimate(settings) + except Exception as exc: + message = f"Unable to run the selected calculations: {exc}" + self.set_output_collapsed(False) + self.output_box.setPlainText(message) + self._current_estimate = None + self.estimate_failed.emit(str(exc)) + return + self._current_estimate = estimate + self.set_output_collapsed(False) + self.output_box.setPlainText(estimate.summary_text()) + self.estimate_calculated.emit(estimate) + + def output_is_collapsed(self) -> bool: + return self.output_panel.isHidden() + + def set_output_collapsed(self, collapsed: bool) -> None: + is_collapsed = bool(collapsed) + if is_collapsed: + self.output_panel.hide() + else: + self.output_panel.show() + self.output_toggle_button.setArrowType( + Qt.ArrowType.DownArrow if is_collapsed else Qt.ArrowType.UpArrow + ) + self.output_toggle_button.setText( + "Show Output" if is_collapsed else "Hide Output" + ) + + def _toggle_output_collapsed(self) -> None: + self.set_output_collapsed(not self.output_is_collapsed()) + + +class SolutionScatteringToolWindow(QWidget): + def __init__( + self, + title: str, + subtitle_html: str, + *, + default_number_density: bool, + default_volume_fraction: bool, + default_attenuation: bool, + default_fluorescence: bool, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self.setWindowTitle(title) + layout = QVBoxLayout(self) + citation_label = QLabel(subtitle_html) + citation_label.setWordWrap(True) + citation_label.setOpenExternalLinks(True) + layout.addWidget(citation_label) + self.estimator_widget = SolutionScatteringEstimatorWidget( + self, + default_number_density=default_number_density, + default_volume_fraction=default_volume_fraction, + default_attenuation=default_attenuation, + default_fluorescence=default_fluorescence, + ) + layout.addWidget(self.estimator_widget) + self.resize(self._default_window_size()) + + @staticmethod + def _default_window_size() -> QSize: + app = QApplication.instance() + screen = app.primaryScreen() if app is not None else None + if screen is None: + return QSize(1100, 720) + + available = screen.availableGeometry() + target_width = min(1120, max(960, available.width() - 180)) + target_height = min(760, max(640, available.height() - 180)) + target_width = min(target_width, max(760, available.width() - 48)) + target_height = min(target_height, max(560, available.height() - 72)) + return QSize(target_width, target_height) + + +class SoluteVolumeFractionToolWindow(SolutionScatteringToolWindow): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__( + "Volume Fraction Estimate", + ( + "Estimate solution-scattering quantities from composition, " + "density, and beam/capillary settings. " + f'' + "Hajizadeh et al. (2018)" + ), + default_number_density=True, + default_volume_fraction=True, + default_attenuation=False, + default_fluorescence=False, + parent=parent, + ) + + +class AttenuationEstimateToolWindow(SolutionScatteringToolWindow): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__( + "Attenuation Estimate", + ( + "Estimate attenuation and the solvent scattering scale factor " + "used to map a neat-solvent reference onto the solvent " + "contribution inside the sample. " + f'NIST, ' + f'XrayDB' + ), + default_number_density=False, + default_volume_fraction=True, + default_attenuation=True, + default_fluorescence=False, + parent=parent, + ) + + +class FluorescenceEstimateToolWindow(SolutionScatteringToolWindow): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__( + "Fluorescence Estimate", + ( + "Estimate fluorescence background tendencies from the sample " + "composition, attenuation, and beam energy using a " + "first-order primary plus secondary fluorescence model. " + f'XRF forward model, ' + f'self-absorption' + ), + default_number_density=False, + default_volume_fraction=False, + default_attenuation=False, + default_fluorescence=True, + parent=parent, + ) + + +class NumberDensityEstimateToolWindow(SolutionScatteringToolWindow): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__( + "Number Density Estimate", + ( + "Estimate the total atomic number density of the solution " + "from the composition inputs and report it in atoms/A^3." + ), + default_number_density=True, + default_volume_fraction=False, + default_attenuation=False, + default_fluorescence=False, + parent=parent, + ) + + +SoluteVolumeFractionWidget = SolutionScatteringEstimatorWidget + +__all__ = [ + "AttenuationEstimateToolWindow", + "FluorescenceEstimateToolWindow", + "NumberDensityEstimateToolWindow", + "SOLUTE_VOLUME_FRACTION_CITATION_URL", + "SOLUTE_VOLUME_FRACTION_HELP_TEXT", + "SOLUTION_SCATTERING_HELP_TEXT", + "SolutionScatteringEstimatorWidget", + "SolutionScatteringToolWindow", + "SoluteVolumeFractionToolWindow", + "SoluteVolumeFractionWidget", +] diff --git a/src/saxshell/saxshell.py b/src/saxshell/saxshell.py index b151ef1..9b95d69 100644 --- a/src/saxshell/saxshell.py +++ b/src/saxshell/saxshell.py @@ -56,6 +56,30 @@ def main(argv: list[str] | None = None) -> int: nargs=argparse.REMAINDER, help="Arguments passed through to the bondanalysis command.", ) + clusterdynamics_parser = subparsers.add_parser( + "clusterdynamics", + help=( + "Analyze time-binned cluster distributions and lifetimes, or " + "launch the cluster-dynamics UI." + ), + ) + clusterdynamics_parser.add_argument( + "args", + nargs=argparse.REMAINDER, + help="Arguments passed through to the clusterdynamics command.", + ) + clusterdynamicsml_parser = subparsers.add_parser( + "clusterdynamicsml", + help=( + "Predict larger-cluster surrogate structures, stoichiometries, " + "and cluster-only SAXS traces." + ), + ) + clusterdynamicsml_parser.add_argument( + "args", + nargs=argparse.REMAINDER, + help="Arguments passed through to the clusterdynamicsml command.", + ) xyz2pdb_parser = subparsers.add_parser( "xyz2pdb", help=( @@ -119,6 +143,24 @@ def main(argv: list[str] | None = None) -> int: forwarded_args = forwarded_args[1:] return bondanalysis_main(forwarded_args) + if args.command == "clusterdynamics": + from saxshell.clusterdynamics.cli import main as clusterdynamics_main + + forwarded_args = list(args.args) + if forwarded_args[:1] == ["--"]: + forwarded_args = forwarded_args[1:] + return clusterdynamics_main(forwarded_args) + + if args.command == "clusterdynamicsml": + from saxshell.clusterdynamicsml.cli import ( + main as clusterdynamicsml_main, + ) + + forwarded_args = list(args.args) + if forwarded_args[:1] == ["--"]: + forwarded_args = forwarded_args[1:] + return clusterdynamicsml_main(forwarded_args) + if args.command == "xyz2pdb": from saxshell.xyz2pdb.cli import main as xyz2pdb_main diff --git a/tests/test_saxs_dream_runtime.py b/tests/test_saxs_dream_runtime.py index 81d957a..f4bf118 100644 --- a/tests/test_saxs_dream_runtime.py +++ b/tests/test_saxs_dream_runtime.py @@ -61,7 +61,6 @@ def _build_minimal_saxs_project(tmp_path): vol_frac=0.0, scale=5e-4, ) - experimental_path = paths.experimental_data_dir / "exp_demo.txt" np.savetxt(experimental_path, np.column_stack([q_values, experimental])) _write_component_file( @@ -152,7 +151,6 @@ def _build_poly_lma_geometry_project( offset=0.0, log_sigma=-9.21, ) - experimental_path = paths.experimental_data_dir / "exp_demo.txt" np.savetxt(experimental_path, np.column_stack([q_values, experimental])) _write_component_file( diff --git a/tests/test_saxs_model_report.py b/tests/test_saxs_model_report.py new file mode 100644 index 0000000..5c11022 --- /dev/null +++ b/tests/test_saxs_model_report.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +import numpy as np +import pytest +from matplotlib.colors import to_hex +from matplotlib.figure import Figure + +from saxshell.saxs.model_report import ( + ReportComponentPlotData, + ReportComponentSeries, + _aligned_y_limits, + _autoscale_component_plot_to_model_range, + _draw_model_fit_axis, + _draw_prefit_plot_axes, +) +from saxshell.saxs.prefit import PrefitEvaluation +from saxshell.saxs.project_manager import PowerPointExportSettings + + +def test_component_plot_autoscale_to_model_range_matches_project_setup_behavior(): + figure = Figure() + experimental_axis = figure.add_subplot(111) + component_axis = experimental_axis.twinx() + + experimental_q = np.asarray([0.05, 0.10, 0.15, 0.20, 0.25], dtype=float) + experimental_i = np.asarray([200.0, 100.0, 80.0, 60.0, 40.0], dtype=float) + model_q = np.asarray([0.10, 0.15, 0.20], dtype=float) + model_i = np.asarray([2.0, 4.0, 8.0], dtype=float) + + experimental_axis.plot( + experimental_q, + experimental_i, + color="black", + alpha=0.35, + linewidth=1.3, + label="Experimental data", + ) + experimental_axis.plot( + model_q, + experimental_i[1:4], + color="black", + linewidth=1.8, + label="Selected q-range", + ) + component_axis.plot( + model_q, + model_i, + color="tab:blue", + linewidth=1.4, + label="component", + ) + + plot_data = ReportComponentPlotData( + title="Initial SAXS traces without solvent", + selected_q_min=0.10, + selected_q_max=0.20, + use_experimental_grid=False, + log_x=False, + log_y=False, + experimental_q_values=experimental_q, + experimental_intensities=experimental_i, + solvent_q_values=None, + solvent_intensities=None, + component_series=( + ReportComponentSeries( + label="component", + q_values=model_q, + intensities=model_i, + color="#1f77b4", + ), + ), + ) + + _autoscale_component_plot_to_model_range( + experimental_axis, + component_axis, + plot_data, + ) + + assert experimental_axis.get_xlim() == pytest.approx((0.10, 0.20)) + assert component_axis.get_xlim() == pytest.approx((0.10, 0.20)) + assert experimental_axis.get_ylim() == pytest.approx((58.0, 102.0)) + + expected_component_limits = _aligned_y_limits( + experimental_axis.get_ylim(), + 60.0, + 100.0, + 2.0, + 8.0, + log_scale=False, + ) + assert component_axis.get_ylim() == pytest.approx( + expected_component_limits + ) + + +def test_prefit_plot_axes_match_prefit_window_colors_and_rescale_without_solvent(): + figure_without = Figure() + grid_without = figure_without.add_gridspec(2, 1, height_ratios=[3, 1]) + top_without = figure_without.add_subplot(grid_without[0, 0]) + bottom_without = figure_without.add_subplot( + grid_without[1, 0], + sharex=top_without, + ) + + figure_with = Figure() + grid_with = figure_with.add_gridspec(2, 1, height_ratios=[3, 1]) + top_with = figure_with.add_subplot(grid_with[0, 0]) + bottom_with = figure_with.add_subplot( + grid_with[1, 0], + sharex=top_with, + ) + + q_values = np.asarray([0.05, 0.08, 0.12, 0.18, 0.24], dtype=float) + experimental = np.asarray([120.0, 100.0, 90.0, 70.0, 55.0], dtype=float) + model = np.asarray([110.0, 98.0, 88.0, 68.0, 52.0], dtype=float) + solvent = np.asarray([8.0, 12.0, 18.0, 380.0, 520.0], dtype=float) + evaluation = PrefitEvaluation( + q_values=q_values, + experimental_intensities=experimental, + model_intensities=model, + residuals=model - experimental, + solvent_contribution=solvent, + ) + settings = PowerPointExportSettings() + + _draw_prefit_plot_axes( + top_without, + bottom_without, + evaluation, + include_solvent=False, + settings=settings, + ) + _draw_prefit_plot_axes( + top_with, + bottom_with, + evaluation, + include_solvent=True, + settings=settings, + ) + + without_lines = { + str(line.get_label()): to_hex(line.get_color(), keep_alpha=False) + for line in top_without.get_lines() + } + with_lines = { + str(line.get_label()): to_hex(line.get_color(), keep_alpha=False) + for line in top_with.get_lines() + } + + assert without_lines == { + "Experimental": "#000000", + "Model": "#d62728", + } + assert with_lines == { + "Experimental": "#000000", + "Solvent contribution": "#008000", + "Model": "#d62728", + } + assert ( + to_hex(bottom_without.get_lines()[-1].get_color(), keep_alpha=False) + == "#1f77b4" + ) + assert ( + to_hex(bottom_with.get_lines()[-1].get_color(), keep_alpha=False) + == "#1f77b4" + ) + assert top_without.get_ylim()[1] < top_with.get_ylim()[1] + + +def test_filter_fit_axis_matches_dream_output_colors_and_hides_solvent(): + figure = Figure() + axis = figure.add_subplot(111) + settings = PowerPointExportSettings() + q_values = np.asarray([0.05, 0.08, 0.12, 0.18], dtype=float) + experimental = np.asarray([120.0, 100.0, 88.0, 70.0], dtype=float) + model = np.asarray([118.0, 97.0, 84.0, 68.0], dtype=float) + structure_factor = np.asarray([1.0, 0.96, 0.91, 0.88], dtype=float) + + _draw_model_fit_axis( + axis, + q_values=q_values, + experimental=experimental, + model=model, + solvent=None, + structure_factor=structure_factor, + title="All Post-burnin", + metrics_lines=["RMSE: 0.1"], + show_legend=False, + compact=True, + dream_output_style=True, + settings=settings, + ) + + assert len(figure.axes) == 2 + left_axis, right_axis = figure.axes + left_lines = { + str(line.get_label()): to_hex(line.get_color(), keep_alpha=False) + for line in left_axis.get_lines() + } + left_collections = [ + to_hex(collection.get_facecolor()[0], keep_alpha=False) + for collection in left_axis.collections + if collection.get_offsets().size > 0 + ] + + assert left_collections == ["#000000"] + assert left_lines == {"Model": "#d62728"} + assert "Solvent contribution" not in left_lines + assert right_axis.get_ylabel() == "S(q)" + assert ( + to_hex(right_axis.get_lines()[0].get_color(), keep_alpha=False) + == "#9467bd" + ) diff --git a/tests/test_saxs_prefit.py b/tests/test_saxs_prefit.py index 048de4c..6d95256 100644 --- a/tests/test_saxs_prefit.py +++ b/tests/test_saxs_prefit.py @@ -12,7 +12,9 @@ SAXSPrefitWorkflow, compute_cluster_geometry_metadata, load_cluster_geometry_metadata, + resolve_prefit_parameter_entries, ) +from saxshell.saxs.prefit.workflow import constrained_prefit_residuals from saxshell.saxs.project_manager import ( SAXSProjectManager, build_project_paths, @@ -21,6 +23,11 @@ SoluteVolumeFractionSettings, calculate_solute_volume_fraction_estimate, ) +from saxshell.saxs.solution_scattering_estimator import ( + BeamGeometrySettings, + SolutionScatteringEstimatorSettings, + calculate_solution_scattering_estimate, +) POLY_LMA_HS_TEMPLATE = "template_pydream_poly_lma_hs" POLY_LMA_HS_MIX_TEMPLATE = "template_pydream_poly_lma_hs_mix_approx" @@ -64,7 +71,6 @@ def _build_minimal_saxs_project(tmp_path): vol_frac=0.0, scale=5e-4, ) - experimental_path = paths.experimental_data_dir / "exp_demo.txt" np.savetxt( experimental_path, @@ -165,7 +171,6 @@ def _build_poly_lma_geometry_project( offset=0.0, log_sigma=-9.21, ) - experimental_path = paths.experimental_data_dir / "exp_demo.txt" np.savetxt( experimental_path, @@ -254,6 +259,230 @@ def test_saxs_prefit_workflow_recommends_scale_from_model_difference(tmp_path): assert recommendation.recommended_scale == pytest.approx(5e-4) assert recommendation.recommended_minimum == pytest.approx(5e-5) assert recommendation.recommended_maximum == pytest.approx(5e-3) + assert recommendation.current_offset == pytest.approx(0.0) + assert recommendation.recommended_offset == pytest.approx(0.05) + assert recommendation.points_used == 8 + + +def test_constrained_prefit_residuals_penalize_non_positive_model_values(): + experimental = np.asarray([1.0, 0.8, 0.6], dtype=float) + model = np.asarray([1.1, -0.2, np.nan], dtype=float) + + residuals = constrained_prefit_residuals(experimental, model) + + assert residuals.shape == (6,) + assert residuals[0] == pytest.approx(0.1) + assert residuals[3] == pytest.approx(0.0) + assert residuals[4] > 25.0 + assert residuals[5] > residuals[4] + + +def test_saxs_prefit_workflow_resolves_and_persists_linked_parameters( + tmp_path, +): + project_dir, paths = _build_minimal_saxs_project(tmp_path) + workflow = SAXSPrefitWorkflow(project_dir) + entries = workflow.load_parameter_entries() + scale_entry = next(entry for entry in entries if entry.name == "scale") + offset_entry = next(entry for entry in entries if entry.name == "offset") + scale_entry.value = 2.5e-4 + offset_entry.initial_value_expression = "*scale" + offset_entry.vary = True + + resolved_entries = resolve_prefit_parameter_entries(entries) + resolved_offset = next( + entry for entry in resolved_entries if entry.name == "offset" + ) + assert resolved_offset.initial_value_expression == "*scale" + assert resolved_offset.value_expression is None + assert resolved_offset.value == pytest.approx(2.5e-4) + assert resolved_offset.vary is True + + fit_result = workflow.run_fit( + resolved_entries, method="leastsq", max_nfev=50 + ) + fitted_offset = next( + entry + for entry in fit_result.parameter_entries + if entry.name == "offset" + ) + assert fitted_offset.initial_value_expression == "*scale" + assert fitted_offset.value_expression is None + assert fitted_offset.vary is True + + workflow.save_fit(fit_result.parameter_entries) + state_payload = json.loads( + (paths.prefit_dir / "prefit_state.json").read_text(encoding="utf-8") + ) + state_offset = next( + entry + for entry in state_payload["parameter_entries"] + if entry["name"] == "offset" + ) + assert state_offset["initial_value_expression"] == "*scale" + + prefit_payload = json.loads( + (paths.prefit_dir / "pd_prefit_params.json").read_text( + encoding="utf-8" + ) + ) + assert ( + prefit_payload["fit_parameter_meta"]["offset"]["initial_expression"] + == "*scale" + ) + + +def test_saxs_prefit_workflow_preserves_dependent_parameter_expressions( + tmp_path, +): + project_dir, paths = _build_minimal_saxs_project(tmp_path) + workflow = SAXSPrefitWorkflow(project_dir) + entries = workflow.load_parameter_entries() + scale_entry = next(entry for entry in entries if entry.name == "scale") + offset_entry = next(entry for entry in entries if entry.name == "offset") + scale_entry.value = 2.5e-4 + offset_entry.value_expression = "*scale" + offset_entry.vary = False + + resolved_entries = resolve_prefit_parameter_entries(entries) + resolved_offset = next( + entry for entry in resolved_entries if entry.name == "offset" + ) + assert resolved_offset.value_expression == "*scale" + assert resolved_offset.initial_value_expression is None + assert resolved_offset.value == pytest.approx(2.5e-4) + assert resolved_offset.vary is False + + fit_result = workflow.run_fit( + resolved_entries, + method="leastsq", + max_nfev=50, + ) + fitted_offset = next( + entry + for entry in fit_result.parameter_entries + if entry.name == "offset" + ) + assert fitted_offset.value_expression == "*scale" + assert fitted_offset.initial_value_expression is None + assert fitted_offset.vary is False + + workflow.save_fit(fit_result.parameter_entries) + state_payload = json.loads( + (paths.prefit_dir / "prefit_state.json").read_text(encoding="utf-8") + ) + state_offset = next( + entry + for entry in state_payload["parameter_entries"] + if entry["name"] == "offset" + ) + assert state_offset["value_expression"] == "*scale" + + +def test_saxs_prefit_workflow_rejects_autoscale_for_linked_scale_or_offset( + tmp_path, +): + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + workflow = SAXSPrefitWorkflow(project_dir) + entries = workflow.load_parameter_entries() + offset_entry = next(entry for entry in entries if entry.name == "offset") + offset_entry.value_expression = "*scale" + + with pytest.raises(ValueError, match="offset is linked"): + workflow.recommend_scale_settings(entries) + + +def test_saxs_prefit_workflow_allows_autoscale_for_expression_seed_parameters( + tmp_path, +): + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + workflow = SAXSPrefitWorkflow(project_dir) + entries = workflow.load_parameter_entries() + offset_entry = next(entry for entry in entries if entry.name == "offset") + offset_entry.initial_value_expression = "*scale" + offset_entry.vary = True + + recommendation = workflow.recommend_scale_settings(entries) + + assert recommendation.recommended_scale > 0.0 + + +def test_saxs_prefit_workflow_supports_model_only_evaluation(tmp_path): + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.model_only_mode = True + settings.use_experimental_grid = False + settings.q_points = 8 + manager.save_project(settings) + + workflow = SAXSPrefitWorkflow(project_dir) + evaluation = workflow.evaluate() + + assert evaluation.experimental_intensities is None + assert evaluation.residuals is None + assert evaluation.is_model_only is True + assert workflow.can_run_prefit() is False + with pytest.raises(ValueError, match="Model Only Mode"): + workflow.run_fit(method="leastsq", max_nfev=100) + + +def test_saxs_prefit_workflow_recommends_scale_with_weighted_solvent_trace( + tmp_path, +): + project_dir, paths = _build_minimal_saxs_project(tmp_path) + solvent_q = np.linspace(0.05, 0.3, 8) + solvent_intensity = np.linspace(1.5, 2.2, 8) + solvent_path = tmp_path / "autoscale_solvent_trace.dat" + np.savetxt(solvent_path, np.column_stack([solvent_q, solvent_intensity])) + + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.solvent_data_path = str(solvent_path) + settings.copied_solvent_data_file = None + manager.save_project(settings) + + template_module = load_template_module( + "template_pd_likelihood_monosq_decoupled" + ) + q_values = np.linspace(0.05, 0.3, 8) + component = np.linspace(10.0, 17.0, 8) + experimental = template_module.lmfit_model_profile( + q_values, + solvent_intensity, + [component], + w0=0.6, + solv_w=0.5, + offset=0.05, + eff_r=9.0, + vol_frac=0.0, + scale=5e-4, + ) + experimental_path = paths.experimental_data_dir / "exp_demo.txt" + np.savetxt(experimental_path, np.column_stack([q_values, experimental])) + settings = manager.load_project(project_dir) + settings.experimental_data_path = str(experimental_path) + settings.copied_experimental_data_file = str(experimental_path) + manager.save_project(settings) + + workflow = SAXSPrefitWorkflow(project_dir) + entries = workflow.load_parameter_entries() + for entry in entries: + if entry.name == "solv_w": + entry.value = 0.5 + if entry.name == "scale": + entry.value = 1e-6 + entry.minimum = 1e-7 + entry.maximum = 1e-5 + + recommendation = workflow.recommend_scale_settings(entries) + + assert recommendation.current_scale == pytest.approx(1e-6) + assert recommendation.recommended_scale == pytest.approx(5e-4) + assert recommendation.recommended_minimum == pytest.approx(5e-5) + assert recommendation.recommended_maximum == pytest.approx(5e-3) + assert recommendation.current_offset == pytest.approx(0.0) + assert recommendation.recommended_offset == pytest.approx(0.05) assert recommendation.points_used == 8 @@ -351,6 +580,165 @@ def test_monosq_prefit_workflow_does_not_expose_solute_volume_fraction_target( assert workflow.volume_fraction_estimator_target() is None +def test_monosq_prefit_workflow_exposes_solvent_weight_target(tmp_path): + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + workflow = SAXSPrefitWorkflow(project_dir) + + assert workflow.solvent_weight_estimator_target() == "solv_w" + + +def test_poly_lma_prefit_workflow_exposes_solvent_weight_target(tmp_path): + project_dir, _paths, _radius = _build_poly_lma_geometry_project(tmp_path) + workflow = SAXSPrefitWorkflow(project_dir) + + assert workflow.solvent_weight_estimator_target() == "solvent_scale" + + +def test_poly_lma_prefit_defaults_fix_solvent_subtraction_controls(tmp_path): + project_dir, _paths, _radius = _build_poly_lma_geometry_project(tmp_path) + workflow = SAXSPrefitWorkflow(project_dir) + entries_by_name = { + entry.name: entry for entry in workflow.load_template_reset_entries() + } + + assert entries_by_name["phi_solute"].vary is False + assert entries_by_name["solvent_scale"].vary is False + + +def test_poly_lma_prefit_rejects_redundant_solvent_subtraction_fit(tmp_path): + project_dir, _paths, _radius = _build_poly_lma_geometry_project(tmp_path) + workflow = SAXSPrefitWorkflow(project_dir) + entries = workflow.load_parameter_entries() + for entry in entries: + if entry.name in {"phi_solute", "solvent_scale"}: + entry.vary = True + + with pytest.raises( + ValueError, + match="cannot both vary during fitting", + ): + workflow.run_fit(entries, method="leastsq", max_nfev=50) + + +def test_solution_scattering_estimate_reports_attenuation_and_scale(): + estimate = calculate_solution_scattering_estimate( + SolutionScatteringEstimatorSettings( + solution=SolutionPropertiesSettings( + mode="mass", + solution_density=1.0, + solute_stoich="Cs1Pb1I3", + solvent_stoich="H2O", + molar_mass_solute=620.0, + molar_mass_solvent=18.015, + mass_solute=1.0, + mass_solvent=9.0, + ), + solute_density_g_per_ml=2.0, + solvent_density_g_per_ml=1.0, + beam=BeamGeometrySettings( + incident_energy_kev=17.0, + capillary_size_mm=1.0, + capillary_geometry="cylindrical", + beam_profile="uniform", + beam_footprint_width_mm=0.4, + beam_footprint_height_mm=0.4, + ), + ) + ) + + assert estimate.number_density_estimate is not None + assert estimate.number_density_estimate.number_density_a3 > 0.0 + assert estimate.volume_fraction_estimate is not None + assert estimate.interaction_contrast_estimate is not None + assert estimate.attenuation_estimate is not None + attenuation = estimate.attenuation_estimate + assert attenuation.solvent_scattering_scale_factor > 0.0 + assert attenuation.solvent_scattering_scale_factor < 1.0 + assert attenuation.sample_linear_attenuation_inv_cm > 0.0 + assert attenuation.sample_linear_attenuation_inv_cm > ( + attenuation.sample_solvent_linear_attenuation_inv_cm + ) + assert 0.0 < attenuation.sample_transmission < 1.0 + assert 0.0 < attenuation.neat_solvent_transmission < 1.0 + assert ( + estimate.interaction_contrast_estimate.saxs_effective_solute_interaction_ratio + < estimate.volume_fraction_estimate.solute_volume_fraction + ) + + +def test_solution_scattering_estimate_reports_saxs_effective_interaction_ratio(): + estimate = calculate_solution_scattering_estimate( + SolutionScatteringEstimatorSettings( + solution=SolutionPropertiesSettings( + mode="molarity_per_liter", + solution_density=1.1, + solute_stoich="Pb1I2", + solvent_stoich="C3H7NO", + molar_mass_solute=461.0, + molar_mass_solvent=73.09, + molarity=0.5, + molarity_element="Pb", + ), + solute_density_g_per_ml=None, + solvent_density_g_per_ml=0.94, + ) + ) + + assert estimate.volume_fraction_estimate is not None + assert estimate.interaction_contrast_estimate is not None + interaction = estimate.interaction_contrast_estimate + assert ( + interaction.physical_solute_associated_volume_fraction + == pytest.approx( + estimate.volume_fraction_estimate.solute_volume_fraction + ) + ) + assert 0.0 < interaction.saxs_effective_solute_interaction_ratio < 1.0 + assert 0.0 < interaction.saxs_effective_solvent_background_ratio < 1.0 + assert ( + interaction.saxs_effective_solute_interaction_ratio + == pytest.approx( + 1.0 - interaction.saxs_effective_solvent_background_ratio + ) + ) + assert ( + interaction.saxs_effective_solute_interaction_ratio + > interaction.physical_solute_associated_volume_fraction + ) + summary = estimate.summary_text() + assert "Physical solute-associated volume fraction estimate" in summary + assert "SAXS-effective interaction contrast estimate" in summary + assert "Model-facing solvent defaults" in summary + + +def test_solution_scattering_estimate_reports_fluorescence_lines(): + estimate = calculate_solution_scattering_estimate( + SolutionScatteringEstimatorSettings( + solution=SolutionPropertiesSettings( + mode="mass", + solution_density=1.0, + solute_stoich="Cs1Pb1I3", + solvent_stoich="H2O", + molar_mass_solute=620.0, + molar_mass_solvent=18.015, + mass_solute=1.0, + mass_solvent=9.0, + ), + solute_density_g_per_ml=2.0, + solvent_density_g_per_ml=1.0, + calculate_solute_volume_fraction=False, + calculate_solvent_scattering_contribution=False, + calculate_sample_fluorescence_yield=True, + ) + ) + + assert estimate.fluorescence_estimate is not None + fluorescence = estimate.fluorescence_estimate + assert fluorescence.total_primary_detected_yield > 0.0 + assert fluorescence.total_secondary_detected_yield >= 0.0 + assert fluorescence.line_estimates + + def test_run_prefit_preserves_manual_parameter_values_outside_old_bounds( tmp_path, ): @@ -481,10 +869,95 @@ def test_saxs_prefit_workflow_evaluates_solvent_contribution(tmp_path): assert np.allclose(evaluation.solvent_intensities, solvent_intensity) assert np.allclose( evaluation.solvent_contribution, - solvent_intensity * 0.5 * 2e-3, + solvent_intensity * 0.5, ) +def test_saxs_prefit_workflow_evaluates_monosq_structure_factor_trace( + tmp_path, +): + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + workflow = SAXSPrefitWorkflow(project_dir) + entries = workflow.load_parameter_entries() + eff_r = 11.0 + vol_frac = 0.03 + for entry in entries: + if entry.name == "eff_r": + entry.value = eff_r + if entry.name == "vol_frac": + entry.value = vol_frac + + evaluation = workflow.evaluate(entries) + template_module = load_template_module( + "template_pd_likelihood_monosq_decoupled" + ) + + assert evaluation.structure_factor_trace is not None + assert np.allclose( + evaluation.structure_factor_trace, + template_module.calc_monodisperse_sq( + eff_r, + vol_frac, + evaluation.q_values, + ), + ) + + +def test_poly_lma_prefit_evaluates_structure_factor_trace(tmp_path): + project_dir, _paths, effective_radius = _build_poly_lma_geometry_project( + tmp_path, + template_name=POLY_LMA_HS_TEMPLATE, + ) + workflow = SAXSPrefitWorkflow(project_dir) + workflow.compute_cluster_geometry_table() + entries = workflow.load_parameter_entries() + phi_int = 0.02 + for entry in entries: + if entry.name == "phi_int": + entry.value = phi_int + + evaluation = workflow.evaluate(entries) + template_module = load_template_module(POLY_LMA_HS_TEMPLATE) + + assert evaluation.structure_factor_trace is not None + assert np.allclose( + evaluation.structure_factor_trace, + template_module.calc_hardsphere_sq( + effective_radius, + phi_int, + evaluation.q_values, + ), + ) + + +def test_poly_lma_prefit_clamps_saved_solvent_weight_bounds(tmp_path): + project_dir, _paths, _effective_radius = _build_poly_lma_geometry_project( + tmp_path + ) + workflow = SAXSPrefitWorkflow(project_dir) + entries = workflow.load_template_reset_entries() + for entry in entries: + if entry.name == "solvent_scale": + entry.value = 1.6 + entry.minimum = 0.0 + entry.maximum = 5.0 + workflow.settings.template_reset_parameter_entries = [ + entry.to_dict() for entry in entries + ] + workflow.project_manager.save_project(workflow.settings) + + reloaded = SAXSPrefitWorkflow(project_dir) + solvent_entry = next( + entry + for entry in reloaded.load_template_reset_entries() + if entry.name == "solvent_scale" + ) + + assert solvent_entry.value == pytest.approx(1.0) + assert solvent_entry.minimum == pytest.approx(0.0) + assert solvent_entry.maximum == pytest.approx(1.0) + + def test_saxs_prefit_workflow_persists_template_reset_and_best_prefit( tmp_path, ): diff --git a/tests/test_saxs_template_installation.py b/tests/test_saxs_template_installation.py index 6ced72c..8edfb23 100644 --- a/tests/test_saxs_template_installation.py +++ b/tests/test_saxs_template_installation.py @@ -372,7 +372,7 @@ def test_validate_template_candidate_rejects_cluster_geometry_binding_missing_fr # cluster_geometry_metadata: true # param: phi_solute,0.02,True,0.0,0.5 # param: phi_int,0.02,True,0.0,0.4 - # param: solvent_scale,1.0,True,0.0,5.0 + # param: solvent_scale,1.0,True,0.0,1.0 # param: scale,1.0,True,1e-8,1e8 # param: offset,0.0,True,-1.0,1.0 # param: log_sigma,-9.21,True,-20.0,5.0 diff --git a/tests/test_saxs_ui.py b/tests/test_saxs_ui.py index 93eebdd..9965ab0 100644 --- a/tests/test_saxs_ui.py +++ b/tests/test_saxs_ui.py @@ -12,10 +12,12 @@ import numpy as np import pytest from matplotlib import colormaps +from matplotlib.backends.backend_qtagg import NavigationToolbar2QT from matplotlib.collections import LineCollection, PolyCollection from matplotlib.colors import to_hex from PySide6.QtCore import Qt from PySide6.QtGui import QColor, QTextOption +from PySide6.QtTest import QTest from PySide6.QtWidgets import ( QAbstractScrollArea, QApplication, @@ -28,7 +30,9 @@ QMessageBox, QScrollArea, QSizePolicy, + QSplitter, ) +from scipy import stats import saxshell.saxs.project_manager.project as project_module from saxshell.saxs._model_templates import ( @@ -43,20 +47,29 @@ SAXSDreamWorkflow, load_dream_settings, ) +from saxshell.saxs.model_report import export_dream_model_report_pptx from saxshell.saxs.prefit import ( + PrefitEvaluation, + PrefitParameterEntry, SAXSPrefitWorkflow, compute_cluster_geometry_metadata, ) from saxshell.saxs.project_manager import ( ClusterImportResult, ExperimentalDataSummary, + PowerPointExportSettings, ProjectSettings, SAXSProjectManager, build_project_paths, load_experimental_data_file, ) +from saxshell.saxs.solution_scattering_estimator import ( + SolutionScatteringEstimatorSettings, + calculate_solution_scattering_estimate, +) from saxshell.saxs.template_installation import install_template_candidate from saxshell.saxs.ui.distribution_window import DistributionSetupWindow +from saxshell.saxs.ui.dream_tab import DreamTab from saxshell.saxs.ui.experimental_data_loader import ( ExperimentalDataHeaderDialog, ) @@ -65,12 +78,14 @@ RuntimeBundleOpener, SAXSMainWindow, TemplateInstallRequest, + launch_saxs_ui, ) -from saxshell.saxs.ui.prefit_tab import TableCellComboBox +from saxshell.saxs.ui.prefit_tab import PrefitTab, TableCellComboBox from saxshell.saxs.ui.prior_histogram_window import PriorHistogramWindow from saxshell.saxs.ui.project_setup_tab import ProjectSetupTab from saxshell.saxs.ui.solute_volume_fraction_widget import ( SOLUTE_VOLUME_FRACTION_CITATION_URL, + SOLUTE_VOLUME_FRACTION_HELP_TEXT, SoluteVolumeFractionWidget, ) from saxshell.version import __version__ @@ -87,6 +102,10 @@ def _table_column_index(table, label: str) -> int: raise AssertionError(f"Column {label!r} was not found.") +def _plot_lines_by_gid(axis, gid: str): + return [line for line in axis.get_lines() if line.get_gid() == gid] + + def _write_component_file(path, q_values, intensities): data = np.column_stack( [ @@ -419,6 +438,225 @@ def _write_weight_order_dream_results(tmp_path): return run_dir +def _write_violin_mode_split_dream_results(tmp_path): + run_dir = tmp_path / "dream_violin_mode_split_test" + run_dir.mkdir(parents=True) + metadata = { + "settings": {"burnin_percent": 0}, + "template_name": "template_pd_likelihood_monosq_decoupled", + "parameter_map": [ + { + "structure": "A", + "motif": "m1", + "param_type": "Both", + "param": "w0", + "value": 0.35, + "vary": True, + "distribution": "lognorm", + "dist_params": {"loc": 0.0, "scale": 0.35, "s": 0.1}, + }, + { + "structure": "A", + "motif": "m1", + "param_type": "Both", + "param": "r_eff_w0", + "value": 9.0, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 9.0, "scale": 0.5}, + }, + { + "structure": "B", + "motif": "m2", + "param_type": "Both", + "param": "a_eff_w1", + "value": 8.0, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 8.0, "scale": 0.5}, + }, + { + "structure": "B", + "motif": "m2", + "param_type": "Both", + "param": "b_eff_w1", + "value": 10.0, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 10.0, "scale": 0.5}, + }, + { + "structure": "B", + "motif": "m2", + "param_type": "Both", + "param": "c_eff_w1", + "value": 12.0, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 12.0, "scale": 0.5}, + }, + { + "structure": "", + "motif": "", + "param_type": "SAXS", + "param": "scale", + "value": 1.0, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 1.0, "scale": 0.1}, + }, + { + "structure": "", + "motif": "", + "param_type": "SAXS", + "param": "offset", + "value": 0.05, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 0.05, "scale": 0.01}, + }, + { + "structure": "", + "motif": "", + "param_type": "SAXS", + "param": "phi_int", + "value": 0.12, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 0.12, "scale": 0.01}, + }, + ], + "active_parameter_entries": [ + { + "structure": "A", + "motif": "m1", + "param_type": "Both", + "param": "w0", + "value": 0.35, + "vary": True, + "distribution": "lognorm", + "dist_params": {"loc": 0.0, "scale": 0.35, "s": 0.1}, + }, + { + "structure": "A", + "motif": "m1", + "param_type": "Both", + "param": "r_eff_w0", + "value": 9.0, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 9.0, "scale": 0.5}, + }, + { + "structure": "B", + "motif": "m2", + "param_type": "Both", + "param": "a_eff_w1", + "value": 8.0, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 8.0, "scale": 0.5}, + }, + { + "structure": "B", + "motif": "m2", + "param_type": "Both", + "param": "b_eff_w1", + "value": 10.0, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 10.0, "scale": 0.5}, + }, + { + "structure": "B", + "motif": "m2", + "param_type": "Both", + "param": "c_eff_w1", + "value": 12.0, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 12.0, "scale": 0.5}, + }, + { + "structure": "", + "motif": "", + "param_type": "SAXS", + "param": "scale", + "value": 1.0, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 1.0, "scale": 0.1}, + }, + { + "structure": "", + "motif": "", + "param_type": "SAXS", + "param": "offset", + "value": 0.05, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 0.05, "scale": 0.01}, + }, + { + "structure": "", + "motif": "", + "param_type": "SAXS", + "param": "phi_int", + "value": 0.12, + "vary": True, + "distribution": "norm", + "dist_params": {"loc": 0.12, "scale": 0.01}, + }, + ], + "active_parameter_indices": list(range(8)), + "full_parameter_names": [ + "w0", + "r_eff_w0", + "a_eff_w1", + "b_eff_w1", + "c_eff_w1", + "scale", + "offset", + "phi_int", + ], + "fixed_parameter_values": [ + 0.35, + 9.0, + 8.0, + 10.0, + 12.0, + 1.0, + 0.05, + 0.12, + ], + "q_values": [0.1, 0.2], + "experimental_intensities": [1.0, 0.8], + "theoretical_intensities": [[1.0, 0.9]], + "solvent_intensities": [0.0, 0.0], + } + (run_dir / "dream_runtime_metadata.json").write_text( + json.dumps(metadata, indent=2) + "\n", + encoding="utf-8", + ) + np.save( + run_dir / "dream_sampled_params.npy", + np.asarray( + [ + [ + [0.35, 9.0, 8.0, 10.0, 12.0, 1.0, 0.05, 0.12], + [0.36, 9.2, 8.1, 10.1, 12.1, 1.1, 0.07, 0.14], + ] + ], + dtype=float, + ), + ) + np.save( + run_dir / "dream_log_ps.npy", + np.asarray([[1.0, 2.0]], dtype=float), + ) + return run_dir + + @pytest.fixture(scope="module") def qapp(): os.environ.setdefault("QT_QPA_PLATFORM", "offscreen") @@ -428,11 +666,25 @@ def qapp(): yield app +def _wait_for_dream_refresh(qapp, delay_ms: int = 120) -> None: + qapp.processEvents() + QTest.qWait(delay_ms) + qapp.processEvents() + + def test_saxs_main_window_loads_project_prefit_and_dream_tabs(qapp, tmp_path): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) window = SAXSMainWindow(initial_project_dir=project_dir) + brand_widget = window.tabs.cornerWidget(Qt.Corner.TopLeftCorner) + assert window.windowTitle() == "SAXSShell" + assert not window.windowIcon().isNull() + assert brand_widget is not None + brand_labels = [ + label.text() for label in brand_widget.findChildren(QLabel) + ] + assert "SAXSShell" in brand_labels assert window.project_setup_tab.forward_model_group.isEnabled() assert window.project_setup_tab.model_group.isEnabled() assert window.project_setup_tab.template_combo.count() >= 1 @@ -479,6 +731,9 @@ def test_saxs_main_window_loads_project_prefit_and_dream_tabs(qapp, tmp_path): assert window.prefit_tab.show_experimental_trace_checkbox.isChecked() assert window.prefit_tab.show_model_trace_checkbox.isChecked() assert not window.prefit_tab.show_solvent_trace_checkbox.isChecked() + assert ( + not window.prefit_tab.show_structure_factor_trace_checkbox.isChecked() + ) assert window.prefit_tab.log_x_checkbox.isChecked() assert window.prefit_tab.log_y_checkbox.isChecked() assert window.prefit_tab.plot_toolbar is not None @@ -510,6 +765,59 @@ def test_saxs_main_window_loads_project_prefit_and_dream_tabs(qapp, tmp_path): in window.prefit_tab.summary_box.toPlainText() ) assert "Prefit Console" in window.prefit_tab.summary_box.toPlainText() + assert window.dream_tab.export_model_report_button is not None + assert ( + window.dream_tab.export_model_report_button.text() + == "Export Model Report (PPTX)" + ) + assert window.dream_tab.recycle_button is not None + assert window.dream_tab.recycle_button.text() == "Recycle" + + +def test_launch_saxs_ui_shows_and_finishes_startup_splash(qapp, monkeypatch): + calls: list[object] = [] + + class _Splash: + def show(self): + calls.append("splash_show") + + def finish(self, window): + calls.append(("splash_finish", window)) + + def close(self): + calls.append("splash_close") + + class _Window: + def __init__(self, initial_project_dir=None): + self.initial_project_dir = initial_project_dir + calls.append(("window_init", initial_project_dir)) + + def show(self): + calls.append("window_show") + + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.configure_saxshell_application", + lambda app: calls.append(("configure", app)), + ) + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.create_saxshell_startup_splash", + lambda: _Splash(), + ) + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.SAXSMainWindow", + _Window, + ) + + project_path = Path("/tmp/brand_test_project") + result = launch_saxs_ui(project_path) + + assert result == 0 + assert calls[0][0] == "configure" + assert calls[1] == "splash_show" + assert calls[2] == ("window_init", project_path) + assert calls[3] == "window_show" + assert calls[4][0] == "splash_finish" + assert isinstance(calls[4][1], _Window) def test_prefit_ionic_radius_help_button_shows_citation( @@ -558,8 +866,10 @@ def test_prefit_solute_volume_fraction_help_button_shows_citation( window.prefit_tab.solute_volume_fraction_help_button.click() assert messages - assert messages[-1][0] == "Solute Volume Fraction Estimate Help" - assert "phi_solute ~= c_solute * vbar_solute" in messages[-1][1] + assert messages[-1][0] == "Solution Scattering Estimate Help" + assert "Physical solute-associated volume fraction" in messages[-1][1] + assert "SAXS-effective interaction contrast ratio" in messages[-1][1] + assert "Fluorescence background proxy" in messages[-1][1] assert SOLUTE_VOLUME_FRACTION_CITATION_URL in messages[-1][1] window.close() @@ -574,7 +884,10 @@ def test_prefit_solute_volume_fraction_estimator_is_template_aware_and_applies( ) base_window = SAXSMainWindow(initial_project_dir=base_project_dir) - assert base_window.prefit_tab._solute_volume_fraction_group.isHidden() + assert not base_window.prefit_tab._solute_volume_fraction_group.isHidden() + assert "solv_w" in ( + base_window.prefit_tab.solute_volume_fraction_status_label.text() + ) poly_project_dir, _poly_paths = _build_poly_lma_geometry_project( tmp_path / "poly" @@ -585,6 +898,9 @@ def test_prefit_solute_volume_fraction_estimator_is_template_aware_and_applies( assert "phi_solute" in ( poly_window.prefit_tab.solute_volume_fraction_status_label.text() ) + assert "solvent_scale" in ( + poly_window.prefit_tab.solute_volume_fraction_status_label.text() + ) assert poly_window.prefit_tab.solute_volume_fraction_is_collapsed() poly_window.prefit_tab.solute_volume_fraction_collapse_button.click() @@ -601,14 +917,56 @@ def test_prefit_solute_volume_fraction_estimator_is_template_aware_and_applies( widget.solute_density_spin.setValue(2.0) widget.solvent_density_spin.setValue(1.0) + expected_settings = widget.current_estimator_settings() + expected_estimate = calculate_solution_scattering_estimate( + SolutionScatteringEstimatorSettings( + solution=expected_settings.solution, + solute_density_g_per_ml=expected_settings.solute_density_g_per_ml, + solvent_density_g_per_ml=expected_settings.solvent_density_g_per_ml, + calculate_number_density=expected_settings.calculate_number_density, + calculate_solute_volume_fraction=( + expected_settings.calculate_solute_volume_fraction + ), + calculate_solvent_scattering_contribution=( + expected_settings.calculate_solvent_scattering_contribution + ), + calculate_sample_fluorescence_yield=( + expected_settings.calculate_sample_fluorescence_yield + ), + beam=expected_settings.beam, + ) + ) + widget.calculate_button.click() phi_row = poly_window.prefit_tab.find_parameter_row("phi_solute") assert phi_row >= 0 assert float( poly_window.prefit_tab.parameter_table.item(phi_row, 3).text() - ) == pytest.approx(0.5 / 10.0, rel=1e-3) - assert "Applied estimate to phi_solute" in widget.output_box.toPlainText() + ) == pytest.approx( + expected_estimate.interaction_contrast_estimate.saxs_effective_solute_interaction_ratio, + rel=1e-3, + ) + assert ( + poly_window.prefit_tab.parameter_table.item(phi_row, 4).checkState() + == Qt.CheckState.Unchecked + ) + solvent_row = poly_window.prefit_tab.find_parameter_row("solvent_scale") + assert solvent_row >= 0 + assert float( + poly_window.prefit_tab.parameter_table.item(solvent_row, 3).text() + ) == pytest.approx( + expected_estimate.attenuation_estimate.solvent_scattering_scale_factor, + rel=1e-3, + ) + assert ( + poly_window.prefit_tab.parameter_table.item( + solvent_row, 4 + ).checkState() + == Qt.CheckState.Unchecked + ) + assert "Applied phi_solute =" in widget.output_box.toPlainText() + assert "Applied solvent_scale =" in widget.output_box.toPlainText() assert ( "Estimated solute volume: 0.500 cm^3" in widget.output_box.toPlainText() @@ -618,7 +976,11 @@ def test_prefit_solute_volume_fraction_estimator_is_template_aware_and_applies( in widget.output_box.toPlainText() ) assert ( - "Applied estimate to phi_solute = 0.050000." + "Recommended solvent scattering scale factor:" + in widget.output_box.toPlainText() + ) + assert ( + "SAXS-effective solute interaction ratio:" in widget.output_box.toPlainText() ) assert not widget.output_is_collapsed() @@ -631,6 +993,61 @@ def test_prefit_solute_volume_fraction_estimator_is_template_aware_and_applies( poly_window.close() +def test_prefit_single_solvent_weight_uses_combined_saxs_effective_multiplier( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + + window.prefit_tab.solute_volume_fraction_collapse_button.click() + widget = window.prefit_tab.solute_volume_fraction_widget + widget.solution_density_spin.setValue(1.0) + widget.solute_stoich_edit.setText("Cs1Pb1I3") + widget.solvent_stoich_edit.setText("H2O") + widget.molar_mass_solute_spin.setValue(620.0) + widget.molar_mass_solvent_spin.setValue(18.015) + widget.mass_solute_spin.setValue(1.0) + widget.mass_solvent_spin.setValue(9.0) + widget.solute_density_spin.setValue(2.0) + widget.solvent_density_spin.setValue(1.0) + + expected_settings = widget.current_estimator_settings() + expected_estimate = calculate_solution_scattering_estimate( + SolutionScatteringEstimatorSettings( + solution=expected_settings.solution, + solute_density_g_per_ml=expected_settings.solute_density_g_per_ml, + solvent_density_g_per_ml=expected_settings.solvent_density_g_per_ml, + calculate_number_density=expected_settings.calculate_number_density, + calculate_solute_volume_fraction=( + expected_settings.calculate_solute_volume_fraction + ), + calculate_solvent_scattering_contribution=( + expected_settings.calculate_solvent_scattering_contribution + ), + calculate_sample_fluorescence_yield=( + expected_settings.calculate_sample_fluorescence_yield + ), + beam=expected_settings.beam, + ) + ) + + widget.calculate_button.click() + + solvent_row = window.prefit_tab.find_parameter_row("solv_w") + assert solvent_row >= 0 + assert float( + window.prefit_tab.parameter_table.item(solvent_row, 3).text() + ) == pytest.approx( + expected_estimate.attenuation_estimate.solvent_scattering_scale_factor + * expected_estimate.interaction_contrast_estimate.saxs_effective_solvent_background_ratio, + rel=1e-3, + ) + assert "Applied solv_w =" in widget.output_box.toPlainText() + window.close() + + def test_prefit_solute_volume_fraction_widget_hides_solute_density_in_molarity_mode( qapp, tmp_path, @@ -655,7 +1072,7 @@ def test_prefit_solute_volume_fraction_widget_hides_solute_density_in_molarity_m widget.current_estimator_settings().solvent_density_g_per_ml == pytest.approx(widget.solvent_density_spin.value()) ) - assert "solute density is hidden in molarity mode" in ( + assert "solvent density remains active" in ( widget.solution_mode_hint_label.text().lower() ) window.close() @@ -680,25 +1097,144 @@ def test_solute_volume_fraction_widget_loads_builtin_solvent_density_presets( assert widget.solvent_density_spin.value() == pytest.approx(1.10) -def test_prefit_fallback_preserves_selected_template_when_workflow_not_ready( - qapp, - tmp_path, -): +def test_solution_scattering_widget_loads_builtin_beam_presets(qapp): del qapp - project_dir, paths = _build_poly_lma_geometry_project(tmp_path) - (paths.project_dir / "md_prior_weights.json").unlink() - - window = SAXSMainWindow(initial_project_dir=project_dir) + widget = SoluteVolumeFractionWidget() - assert window.prefit_workflow is None assert ( - window.prefit_tab.selected_template_name() == POLY_LMA_HS_MIX_TEMPLATE + widget.beam_preset_combo.currentData() == "NSLS-II 28-ID-1 (default)" ) - assert not window.prefit_tab._solute_volume_fraction_group.isHidden() - assert not window.prefit_tab._cluster_geometry_group.isHidden() - assert window.prefit_tab.parameter_table.rowCount() == 0 - assert "Prefit workflow is not ready yet." in ( - window.prefit_tab.output_box.toPlainText() + assert widget.incident_energy_spin.value() == pytest.approx(74.0) + assert widget.capillary_size_spin.value() == pytest.approx(1.0) + assert widget.beam_footprint_width_spin.value() == pytest.approx(0.4) + assert widget.beam_footprint_height_spin.value() == pytest.approx(0.4) + + focused_index = widget.beam_preset_combo.findData( + "APS 5-IDD (default - focused)" + ) + assert focused_index >= 0 + widget.beam_preset_combo.setCurrentIndex(focused_index) + widget._load_selected_beam_preset() + + assert widget.incident_energy_spin.value() == pytest.approx(17.0) + assert widget.beam_footprint_width_spin.value() == pytest.approx(0.05) + assert widget.beam_footprint_height_spin.value() == pytest.approx(1.0) + + unfocused_index = widget.beam_preset_combo.findData( + "APS 5-IDD (default - unfocused)" + ) + assert unfocused_index >= 0 + widget.beam_preset_combo.setCurrentIndex(unfocused_index) + widget._load_selected_beam_preset() + + assert widget.incident_energy_spin.value() == pytest.approx(17.5) + assert widget.beam_footprint_width_spin.value() == pytest.approx(1.0) + assert widget.beam_footprint_height_spin.value() == pytest.approx(1.0) + + +def test_solution_scattering_widget_saves_and_deletes_custom_beam_presets( + qapp, + tmp_path, + monkeypatch, +): + del qapp + preset_path = tmp_path / "beam_presets.json" + monkeypatch.setenv( + "SAXSHELL_BEAM_GEOMETRY_PRESETS_PATH", + str(preset_path), + ) + widget = SoluteVolumeFractionWidget() + + widget.incident_energy_spin.setValue(12.4) + widget.capillary_size_spin.setValue(2.0) + widget.beam_footprint_width_spin.setValue(0.25) + widget.beam_footprint_height_spin.setValue(0.75) + + monkeypatch.setattr( + "saxshell.saxs.ui.solution_scattering_widget.QInputDialog.getText", + lambda *args, **kwargs: ("Custom Beam", True), + ) + + widget._save_current_beam_preset() + + saved_payload = json.loads(preset_path.read_text(encoding="utf-8")) + assert "Custom Beam" in saved_payload["presets"] + assert saved_payload["presets"]["Custom Beam"][ + "incident_energy_kev" + ] == pytest.approx(12.4) + assert widget.beam_preset_combo.findData("Custom Beam") >= 0 + + widget.beam_preset_combo.setCurrentIndex( + widget.beam_preset_combo.findData("Custom Beam") + ) + widget._delete_selected_beam_preset() + + reloaded_payload = json.loads(preset_path.read_text(encoding="utf-8")) + assert "Custom Beam" not in reloaded_payload["presets"] + assert widget.beam_preset_combo.findData("Custom Beam") < 0 + + +def test_solution_scattering_widget_wavelength_dialog_tracks_energy(qapp): + del qapp + widget = SoluteVolumeFractionWidget() + + widget.incident_energy_spin.setValue(17.0) + widget._show_wavelength_dialog() + + dialog = widget._wavelength_dialog + assert dialog is not None + assert dialog.windowTitle() == "Beam Energy and Wavelength" + assert dialog.energy_value_label.text() == "17" + assert float(dialog.wavelength_value_label.text()) == pytest.approx( + 12.398419843320026 / 17.0, + rel=1e-6, + ) + + widget.incident_energy_spin.setValue(74.0) + QApplication.processEvents() + + assert dialog.energy_value_label.text() == "74" + assert float(dialog.wavelength_value_label.text()) == pytest.approx( + 12.398419843320026 / 74.0, + rel=1e-6, + ) + dialog.close() + + +def test_solution_scattering_widget_places_output_below_run_button(qapp): + del qapp + widget = SoluteVolumeFractionWidget() + widget.resize(960, 720) + widget.set_output_collapsed(False) + widget.show() + QApplication.processEvents() + + assert ( + widget.output_panel.parentWidget() + is widget.calculate_button.parentWidget() + ) + assert widget.output_panel.y() > widget.calculate_button.y() + + +def test_prefit_fallback_preserves_selected_template_when_workflow_not_ready( + qapp, + tmp_path, +): + del qapp + project_dir, paths = _build_poly_lma_geometry_project(tmp_path) + (paths.project_dir / "md_prior_weights.json").unlink() + + window = SAXSMainWindow(initial_project_dir=project_dir) + + assert window.prefit_workflow is None + assert ( + window.prefit_tab.selected_template_name() == POLY_LMA_HS_MIX_TEMPLATE + ) + assert not window.prefit_tab._solute_volume_fraction_group.isHidden() + assert not window.prefit_tab._cluster_geometry_group.isHidden() + assert window.prefit_tab.parameter_table.rowCount() == 0 + assert "Prefit workflow is not ready yet." in ( + window.prefit_tab.output_box.toPlainText() ) window.close() @@ -1485,6 +2021,11 @@ def test_poly_lma_dream_parameter_map_tracks_cluster_geometry_shape( sphere_names = [entry.param for entry in sphere_entries] assert "r_eff_w0" in sphere_names assert "a_eff_w0" not in sphere_names + sphere_radius_entry = next( + entry for entry in sphere_entries if entry.param == "r_eff_w0" + ) + assert sphere_radius_entry.distribution == "lognorm" + assert set(sphere_radius_entry.dist_params) == {"loc", "scale", "s"} ellipsoid_index = sf_combo.findData("ellipsoid") assert ellipsoid_index >= 0 @@ -1496,6 +2037,10 @@ def test_poly_lma_dream_parameter_map_tracks_cluster_geometry_shape( assert "a_eff_w0" in ellipsoid_names assert "b_eff_w0" in ellipsoid_names assert "c_eff_w0" in ellipsoid_names + ellipsoid_a_entry = next( + entry for entry in ellipsoid_entries if entry.param == "a_eff_w0" + ) + assert ellipsoid_a_entry.distribution == "norm" window.close() @@ -1518,16 +2063,21 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): assert window.cluster_action.text() == "Open Cluster Extraction" assert window.xyz2pdb_action.text() == "Open xyz2pdb Conversion" assert window.bondanalysis_action.text() == "Open Bond Analysis" + assert window.clusterdynamics_action.text() == "Open Cluster Dynamics" + assert window.clusterdynamicsml_action.text() == "Open Cluster Dynamics ML" assert window.fullrmc_action.text() == "Open fullrmc Setup" assert ( window.volume_fraction_action.text() == "Open Volume Fraction Estimate" ) + assert ( + window.number_density_action.text() == "Open Number Density Estimate" + ) assert window.settings_menu.title() == "Settings" assert ( - window.dream_output_settings_action.text() - == "DREAM Output Settings..." + window.console_autoscroll_action.text() == "Autoscroll Console Output" ) + assert window.dream_output_settings_action.text() == "Main UI Settings..." assert window.help_menu.title() == "Help" assert window.version_info_action.text() == "Version Information" @@ -1542,6 +2092,112 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): assert "keith.white@colorado.edu" in version_text +def test_console_autoscroll_setting_controls_tab_output_scroll( + qapp, tmp_path, monkeypatch +): + class _FakeSettings: + def __init__(self): + self.values: dict[str, object] = {} + + def value(self, key, default=None): + return self.values.get(key, default) + + def setValue(self, key, value): + self.values[key] = value + + settings_store = _FakeSettings() + monkeypatch.setattr( + SAXSMainWindow, + "_recent_projects_settings", + lambda self: settings_store, + ) + + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + window.resize(920, 720) + window.project_setup_tab.summary_box.setFixedHeight(90) + window.prefit_tab.output_box.setFixedHeight(90) + window.dream_tab.output_box.setFixedHeight(90) + window.show() + qapp.processEvents() + + assert window.console_autoscroll_action.isChecked() + + def _activate_tab(tab): + window.tabs.setCurrentWidget(tab) + qapp.processEvents() + + def _append_project_lines(start: int, stop: int): + _activate_tab(window.project_setup_tab) + for index in range(start, stop): + window.project_setup_tab.append_summary( + f"project summary line {index}" + ) + qapp.processEvents() + return window.project_setup_tab.summary_box.verticalScrollBar() + + def _append_prefit_lines(start: int, stop: int): + _activate_tab(window.prefit_tab) + for index in range(start, stop): + window.prefit_tab.append_log(f"prefit log line {index}") + qapp.processEvents() + return window.prefit_tab.output_box.verticalScrollBar() + + def _append_dream_lines(start: int, stop: int): + _activate_tab(window.dream_tab) + for index in range(start, stop): + window.dream_tab.append_log(f"dream log line {index}") + qapp.processEvents() + return window.dream_tab.output_box.verticalScrollBar() + + def _is_near_bottom(scrollbar) -> bool: + return scrollbar.maximum() - scrollbar.value() <= max( + int(scrollbar.pageStep()), 4 + ) + + project_scrollbar = _append_project_lines(0, 80) + prefit_scrollbar = _append_prefit_lines(0, 80) + dream_scrollbar = _append_dream_lines(0, 80) + + assert project_scrollbar.maximum() > 0 + assert prefit_scrollbar.maximum() > 0 + assert dream_scrollbar.maximum() > 0 + assert _is_near_bottom(project_scrollbar) + assert _is_near_bottom(prefit_scrollbar) + assert _is_near_bottom(dream_scrollbar) + + window.console_autoscroll_action.trigger() + qapp.processEvents() + + assert settings_store.values["console_autoscroll_enabled"] is False + assert not window.console_autoscroll_action.isChecked() + + project_scrollbar.setValue(0) + prefit_scrollbar.setValue(0) + dream_scrollbar.setValue(0) + qapp.processEvents() + + project_scrollbar = _append_project_lines(80, 100) + prefit_scrollbar = _append_prefit_lines(80, 100) + dream_scrollbar = _append_dream_lines(80, 100) + + assert not _is_near_bottom(project_scrollbar) + assert not _is_near_bottom(prefit_scrollbar) + assert not _is_near_bottom(dream_scrollbar) + + window.console_autoscroll_action.trigger() + qapp.processEvents() + + assert settings_store.values["console_autoscroll_enabled"] is True + project_scrollbar = _append_project_lines(100, 101) + prefit_scrollbar = _append_prefit_lines(100, 101) + dream_scrollbar = _append_dream_lines(100, 101) + assert _is_near_bottom(project_scrollbar) + assert _is_near_bottom(prefit_scrollbar) + assert _is_near_bottom(dream_scrollbar) + window.close() + + def test_volume_fraction_tool_window_opens_with_citation_and_target( qapp, tmp_path, @@ -1558,6 +2214,10 @@ def test_volume_fraction_tool_window_opens_with_citation_and_target( labels = [label.text() for label in tool_window.findChildren(QLabel)] assert any(SOLUTE_VOLUME_FRACTION_CITATION_URL in text for text in labels) assert "phi_solute" in tool_window.estimator_widget.target_label.text() + assert "solvent_scale" in tool_window.estimator_widget.target_label.text() + assert ( + tool_window.estimator_widget.calculate_number_density_checkbox.isChecked() + ) molarity_index = tool_window.estimator_widget.solution_mode_combo.findData( "molarity_per_liter" ) @@ -1569,14 +2229,102 @@ def test_volume_fraction_tool_window_opens_with_citation_and_target( assert tool_window.estimator_widget.solute_density_spin.isHidden() assert not tool_window.estimator_widget.solvent_density_label.isHidden() assert not tool_window.estimator_widget.solvent_density_spin.isHidden() - assert "solvent-density closure" in ( + assert "volume-closure calculations" in ( tool_window.estimator_widget.solution_mode_hint_label.text().lower() ) + tool_window.estimator_widget.solution_density_spin.setValue(1.0) + tool_window.estimator_widget.solute_stoich_edit.setText("Cs1Pb1I3") + tool_window.estimator_widget.solvent_stoich_edit.setText("H2O") + tool_window.estimator_widget.molar_mass_solute_spin.setValue(620.0) + tool_window.estimator_widget.molar_mass_solvent_spin.setValue(18.015) + tool_window.estimator_widget.molarity_spin.setValue(0.5) + tool_window.estimator_widget.molarity_element_edit.setText("Pb") + tool_window.estimator_widget.solvent_density_spin.setValue(1.0) + tool_window.estimator_widget.calculate_button.click() + + assert "Number density estimate" in ( + tool_window.estimator_widget.output_box.toPlainText() + ) + assert "atoms/A^3" in tool_window.estimator_widget.output_box.toPlainText() + tool_window.close() window.close() +def test_number_density_attenuation_and_fluorescence_tool_windows_open_with_defaults( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_poly_lma_geometry_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + + window._open_number_density_tool() + number_density_window = window._number_density_tool_window + assert number_density_window is not None + assert number_density_window.windowTitle() == "Number Density Estimate" + assert ( + number_density_window.estimator_widget.calculate_number_density_checkbox.isChecked() + ) + assert ( + not number_density_window.estimator_widget.calculate_volume_fraction_checkbox.isChecked() + ) + assert ( + not number_density_window.estimator_widget.calculate_attenuation_checkbox.isChecked() + ) + assert ( + not number_density_window.estimator_widget.calculate_fluorescence_checkbox.isChecked() + ) + + window._open_attenuation_tool() + attenuation_window = window._attenuation_tool_window + assert attenuation_window is not None + assert attenuation_window.windowTitle() == "Attenuation Estimate" + assert ( + not attenuation_window.estimator_widget.calculate_number_density_checkbox.isChecked() + ) + assert ( + attenuation_window.estimator_widget.calculate_volume_fraction_checkbox.isChecked() + ) + assert ( + attenuation_window.estimator_widget.calculate_attenuation_checkbox.isChecked() + ) + assert ( + not attenuation_window.estimator_widget.calculate_fluorescence_checkbox.isChecked() + ) + assert ( + "solvent_scale" + in attenuation_window.estimator_widget.target_label.text() + ) + + window._open_fluorescence_tool() + fluorescence_window = window._fluorescence_tool_window + assert fluorescence_window is not None + assert fluorescence_window.windowTitle() == "Fluorescence Estimate" + assert ( + not fluorescence_window.estimator_widget.calculate_number_density_checkbox.isChecked() + ) + assert ( + not fluorescence_window.estimator_widget.calculate_volume_fraction_checkbox.isChecked() + ) + assert ( + not fluorescence_window.estimator_widget.calculate_attenuation_checkbox.isChecked() + ) + assert ( + fluorescence_window.estimator_widget.calculate_fluorescence_checkbox.isChecked() + ) + assert ( + "solvent_scale" + in fluorescence_window.estimator_widget.target_label.text() + ) + + number_density_window.close() + attenuation_window.close() + fluorescence_window.close() + window.close() + + def test_contact_action_opens_developer_contact_window(qapp, monkeypatch): del qapp window = SAXSMainWindow() @@ -1661,6 +2409,94 @@ def raise_(self): assert launched["instance"] in window._child_tool_windows +def test_cluster_dynamics_tool_uses_active_project_dir( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + launched: dict[str, object] = {} + + class FakeClusterDynamicsWindow: + def __init__( + self, + initial_frames_dir=None, + initial_energy_file=None, + initial_project_dir=None, + ): + launched["frames_dir"] = initial_frames_dir + launched["energy_file"] = initial_energy_file + launched["project_dir"] = initial_project_dir + launched["instance"] = self + + def show(self): + launched["shown"] = True + + def raise_(self): + launched["raised"] = True + + monkeypatch.setattr( + "saxshell.clusterdynamics.ui.main_window.ClusterDynamicsMainWindow", + FakeClusterDynamicsWindow, + ) + + window._open_clusterdynamics_tool() + + assert ( + launched["project_dir"] + == Path(window.current_settings.project_dir).resolve() + ) + assert launched["shown"] is True + assert launched["raised"] is True + assert launched["instance"] in window._child_tool_windows + + +def test_cluster_dynamics_ml_tool_uses_active_project_dir( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + launched: dict[str, object] = {} + + class FakeClusterDynamicsMLWindow: + def __init__( + self, + initial_frames_dir=None, + initial_energy_file=None, + initial_project_dir=None, + initial_clusters_dir=None, + initial_experimental_data_file=None, + ): + launched["frames_dir"] = initial_frames_dir + launched["energy_file"] = initial_energy_file + launched["project_dir"] = initial_project_dir + launched["clusters_dir"] = initial_clusters_dir + launched["experimental_data_file"] = initial_experimental_data_file + launched["instance"] = self + + def show(self): + launched["shown"] = True + + def raise_(self): + launched["raised"] = True + + monkeypatch.setattr( + "saxshell.clusterdynamicsml.ui.main_window.ClusterDynamicsMLMainWindow", + FakeClusterDynamicsMLWindow, + ) + + window._open_clusterdynamicsml_tool() + + assert ( + launched["project_dir"] + == Path(window.current_settings.project_dir).resolve() + ) + assert launched["shown"] is True + assert launched["raised"] is True + assert launched["instance"] in window._child_tool_windows + + def test_save_project_as_copies_project_and_rewrites_internal_paths( qapp, tmp_path, monkeypatch ): @@ -1825,6 +2661,9 @@ def test_dream_progress_label_wrap_does_not_resize_left_pane(qapp): assert window.dream_tab.show_experimental_trace_checkbox.isChecked() assert window.dream_tab.show_model_trace_checkbox.isChecked() assert not window.dream_tab.show_solvent_trace_checkbox.isChecked() + assert ( + not window.dream_tab.show_structure_factor_trace_checkbox.isChecked() + ) assert window.dream_tab.model_log_x_checkbox.isChecked() assert window.dream_tab.model_log_y_checkbox.isChecked() assert ( @@ -1851,6 +2690,28 @@ def test_dream_progress_label_wrap_does_not_resize_left_pane(qapp): window.dream_tab.violin_value_scale_combo.currentData() == "parameter_value" ) + assert ( + window.dream_tab.violin_value_scale_combo.findData( + "effective_radii_only" + ) + >= 0 + ) + assert ( + window.dream_tab.violin_value_scale_combo.findData( + "additional_parameters_only" + ) + >= 0 + ) + assert ( + window.dream_tab.violin_mode_combo.findData("effective_radii_only") + >= 0 + ) + assert ( + window.dream_tab.violin_mode_combo.findData( + "additional_parameters_only" + ) + >= 0 + ) assert window.dream_tab.violin_palette_combo.currentData() == "Blues" assert window.dream_tab.selected_violin_point_color() == to_hex( "tab:red", @@ -1911,6 +2772,83 @@ def test_apply_dream_output_settings_updates_verbose_controls(qapp): ) +def test_apply_powerpoint_export_settings_updates_project_state( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + + window._apply_powerpoint_export_settings( + PowerPointExportSettings( + font_family="Courier New", + component_color_map="plasma", + prior_histogram_color_map="cividis", + solvent_sort_histogram_color_map="magma", + include_output_summary=False, + generate_manifest=False, + export_figure_assets=False, + ) + ) + + settings = window.current_settings.powerpoint_export_settings + assert settings.font_family == "Courier New" + assert settings.component_color_map == "plasma" + assert settings.prior_histogram_color_map == "cividis" + assert settings.solvent_sort_histogram_color_map == "magma" + assert not settings.include_output_summary + assert not settings.generate_manifest + assert not settings.export_figure_assets + assert "Updated PowerPoint export settings." in ( + window.dream_tab.output_box.toPlainText() + ) + + +def test_dream_tab_batches_runtime_output_console_updates(qapp, monkeypatch): + del qapp + tab = DreamTab() + render_calls: list[bool] = [] + original_render = tab._render_output + + def tracked_render(*, scroll_to_end: bool = False): + render_calls.append(bool(scroll_to_end)) + return original_render(scroll_to_end=scroll_to_end) + + monkeypatch.setattr(tab, "_render_output", tracked_render) + + tab.append_runtime_output("line 1") + tab.append_runtime_output("line 2") + QApplication.processEvents() + + assert "line 1" not in tab.output_box.toPlainText() + assert not render_calls + + QTest.qWait(tab.RUNTIME_OUTPUT_FLUSH_INTERVAL_MS + 75) + QApplication.processEvents() + + text = tab.output_box.toPlainText() + assert "DREAM Runtime Output" in text + assert "line 1" in text + assert "line 2" in text + assert len(render_calls) == 1 + + +def test_dream_tab_flushes_pending_runtime_output_before_regular_logs(qapp): + del qapp + tab = DreamTab() + + tab.set_log_text("Base log") + tab.append_runtime_output("runtime line") + tab.append_log("final log line") + QApplication.processEvents() + + text = tab.output_box.toPlainText() + assert "runtime line" in text + assert "final log line" in text + assert text.index("runtime line") < text.index("final log line") + + def test_dream_search_filter_presets_update_and_fall_back_to_custom(qapp): del qapp window = SAXSMainWindow() @@ -2057,7 +2995,7 @@ def test_distribution_window_prompts_before_quitting_without_first_save( assert not window.isVisible() -def test_distribution_window_switches_lognorm_params_for_norm_and_uniform( +def test_distribution_window_layout_uses_splitters_for_plot_table_and_console( qapp, ): del qapp @@ -2075,26 +3013,27 @@ def test_distribution_window_switches_lognorm_params_for_norm_and_uniform( ) ] ) - combo = window.table.cellWidget(0, 6) - - combo.setCurrentText("norm") + window.show() QApplication.processEvents() - norm_params = json.loads(window.table.item(0, 7).text()) - assert "s" not in norm_params - assert set(norm_params) == {"loc", "scale"} - assert window.figure.axes[0].get_title() == "w0: norm" - - combo.setCurrentText("uniform") - QApplication.processEvents() + assert isinstance(window._main_splitter, QSplitter) + assert window._main_splitter.orientation() == Qt.Orientation.Horizontal + assert window._main_splitter.count() == 2 + assert window._main_splitter.widget(0) is window._left_panel + assert window._main_splitter.widget(1) is window._plot_panel + assert all(size > 0 for size in window._main_splitter.sizes()) - uniform_params = json.loads(window.table.item(0, 7).text()) - assert "s" not in uniform_params - assert set(uniform_params) == {"loc", "scale"} - assert window.figure.axes[0].get_title() == "w0: uniform" + assert isinstance(window._left_splitter, QSplitter) + assert window._left_splitter.orientation() == Qt.Orientation.Vertical + assert window._left_splitter.count() == 2 + assert window._left_splitter.widget(0) is window._editor_panel + assert window._left_splitter.widget(1) is window.console + assert all(size > 0 for size in window._left_splitter.sizes()) -def test_distribution_window_can_toggle_all_vary_flags(qapp): +def test_distribution_window_interactive_center_lock_toggles_center_handle( + qapp, +): del qapp window = DistributionSetupWindow( [ @@ -2102,33 +3041,147 @@ def test_distribution_window_can_toggle_all_vary_flags(qapp): structure="PbI2", motif="motif_A", param_type="SAXS", - param="w0", - value=0.6, - vary=True, - distribution="lognorm", - dist_params={"loc": 0.0, "scale": 0.6, "s": 0.1}, - ), + param="scale", + value=1.0, + vary=True, + distribution="norm", + dist_params={"loc": 1.0, "scale": 0.2}, + ) + ] + ) + window.show() + QApplication.processEvents() + + assert window.lock_center_checkbox.isChecked() + assert "currently locked" in window.interactive_hint_label.text() + assert window._interactive_handles is not None + assert window._interactive_handles.center is None + + window.lock_center_checkbox.setChecked(False) + QApplication.processEvents() + + assert "red center handle" in window.interactive_hint_label.text() + assert window._interactive_handles is not None + assert window._interactive_handles.center is not None + + +def test_distribution_window_plot_is_square_and_shows_reset_baseline(qapp): + del qapp + window = DistributionSetupWindow( + [ DreamParameterEntry( structure="PbI2", - motif="motif_B", + motif="motif_A", param_type="SAXS", param="scale", value=1.0, - vary=False, + vary=True, distribution="norm", - dist_params={"loc": 1.0, "scale": 0.2}, - ), + dist_params={"loc": 1.0, "scale": 1.0}, + ) ] ) + window.show() + QApplication.processEvents() - window.set_all_vary_off_button.click() - assert all(not entry.vary for entry in window.current_entries()) + axis = window.figure.axes[0] + baseline_lines = _plot_lines_by_gid(axis, "reset-baseline") + current_lines = _plot_lines_by_gid(axis, "current-distribution") - window.set_all_vary_on_button.click() - assert all(entry.vary for entry in window.current_entries()) + assert float(axis.get_box_aspect()) == pytest.approx(1.0) + assert len(baseline_lines) == 1 + assert len(current_lines) == 1 + assert np.allclose( + baseline_lines[0].get_xdata(), + current_lines[0].get_xdata(), + ) + assert np.allclose( + baseline_lines[0].get_ydata(), + current_lines[0].get_ydata(), + ) -def test_distribution_window_recommended_vary_selection_keeps_radius_params_off( +def test_distribution_window_width_drag_updates_norm_and_uniform_params(): + norm_entry = DreamParameterEntry( + structure="PbI2", + motif="motif_A", + param_type="SAXS", + param="scale", + value=1.0, + vary=True, + distribution="norm", + dist_params={"loc": 1.0, "scale": 0.2}, + ) + updated_norm = DistributionSetupWindow._width_drag_adjusted_entry( + norm_entry, + handle_kind="right_width", + target_x=1.9, + ) + assert updated_norm.value == pytest.approx(1.0) + assert updated_norm.dist_params["loc"] == pytest.approx(1.0) + assert updated_norm.dist_params["scale"] == pytest.approx(0.3) + + uniform_entry = DreamParameterEntry( + structure="PbI2", + motif="motif_A", + param_type="SAXS", + param="phi_solute", + value=2.0, + vary=True, + distribution="uniform", + dist_params={"loc": 1.6, "scale": 0.8}, + ) + updated_uniform = DistributionSetupWindow._width_drag_adjusted_entry( + uniform_entry, + handle_kind="left_width", + target_x=1.2, + ) + assert updated_uniform.value == pytest.approx(2.0) + assert updated_uniform.dist_params["scale"] == pytest.approx(1.6) + assert updated_uniform.dist_params["loc"] == pytest.approx(1.2) + + +def test_distribution_window_center_and_peak_drag_adjust_lognorm_cleanly(): + entry = DreamParameterEntry( + structure="PbI2", + motif="motif_A", + param_type="Both", + param="w0", + value=0.6, + vary=True, + distribution="lognorm", + dist_params={"loc": 0.0, "scale": 0.6, "s": 0.2}, + ) + + moved = DistributionSetupWindow._center_drag_adjusted_entry( + entry, + target_center=0.9, + ) + assert moved.value == pytest.approx(0.9) + assert moved.dist_params["scale"] == pytest.approx(0.6) + assert moved.dist_params["s"] == pytest.approx(0.2) + assert moved.dist_params["loc"] == pytest.approx(0.3) + + narrowed = DistributionSetupWindow._peak_drag_adjusted_entry( + entry, + start_y=0.8, + target_y=1.4, + y_limits=(0.0, 2.0), + ) + widened = DistributionSetupWindow._peak_drag_adjusted_entry( + entry, + start_y=0.8, + target_y=0.2, + y_limits=(0.0, 2.0), + ) + assert narrowed.value == pytest.approx(entry.value) + assert narrowed.dist_params["loc"] == pytest.approx(0.0) + assert narrowed.dist_params["s"] < entry.dist_params["s"] + assert widened.dist_params["loc"] == pytest.approx(0.0) + assert widened.dist_params["s"] > entry.dist_params["s"] + + +def test_distribution_window_switches_lognorm_params_for_norm_and_uniform( qapp, ): del qapp @@ -2137,45 +3190,110 @@ def test_distribution_window_recommended_vary_selection_keeps_radius_params_off( DreamParameterEntry( structure="PbI2", motif="motif_A", - param_type="Both", + param_type="SAXS", param="w0", value=0.6, - vary=False, + vary=True, distribution="lognorm", dist_params={"loc": 0.0, "scale": 0.6, "s": 0.1}, - ), + ) + ] + ) + combo = window.table.cellWidget(0, 6) + + combo.setCurrentText("norm") + QApplication.processEvents() + + norm_params = json.loads(window.table.item(0, 7).text()) + low_column = _table_column_index(window.table, "Guide Low") + high_column = _table_column_index(window.table, "Guide High") + assert "s" not in norm_params + assert set(norm_params) == {"loc", "scale"} + assert float(window.table.item(0, low_column).text()) == pytest.approx( + -1.8 + ) + assert float(window.table.item(0, high_column).text()) == pytest.approx( + 1.8 + ) + assert window.figure.axes[0].get_title() == "w0: norm" + + combo.setCurrentText("uniform") + QApplication.processEvents() + + uniform_params = json.loads(window.table.item(0, 7).text()) + assert "s" not in uniform_params + assert set(uniform_params) == {"loc", "scale"} + assert float(window.table.item(0, low_column).text()) == pytest.approx( + 0.0, + abs=1e-12, + ) + assert float(window.table.item(0, high_column).text()) == pytest.approx( + 0.6 + ) + assert window.figure.axes[0].get_title() == "w0: uniform" + + +def test_distribution_window_displays_distribution_guide_bounds(qapp): + del qapp + window = DistributionSetupWindow( + [ DreamParameterEntry( - structure="", - motif="", + structure="PbI2", + motif="motif_A", param_type="SAXS", param="scale", value=1.0, - vary=False, + vary=True, distribution="norm", dist_params={"loc": 1.0, "scale": 0.2}, ), DreamParameterEntry( structure="PbI2", - motif="motif_A", + motif="motif_B", param_type="SAXS", - param="r_eff_w0", - value=4.0, + param="phi_solute", + value=0.45, vary=True, - distribution="norm", - dist_params={"loc": 4.0, "scale": 0.2}, + distribution="uniform", + dist_params={"loc": 0.2, "scale": 0.5}, + ), + DreamParameterEntry( + structure="PbI2", + motif="motif_C", + param_type="Both", + param="w0", + value=0.6, + vary=True, + distribution="lognorm", + dist_params={"loc": 0.0, "scale": 0.6, "s": 0.1}, ), ] ) - window.select_recommended_vary_button.click() - entries = window.current_entries() + low_column = _table_column_index(window.table, "Guide Low") + high_column = _table_column_index(window.table, "Guide High") + q_low = stats.norm.cdf(-3.0) + q_high = stats.norm.cdf(3.0) - assert entries[0].vary is True - assert entries[1].vary is True - assert entries[2].vary is False + assert float(window.table.item(0, low_column).text()) == pytest.approx(0.4) + assert float(window.table.item(0, high_column).text()) == pytest.approx( + 1.6 + ) + assert float(window.table.item(1, low_column).text()) == pytest.approx(0.2) + assert float(window.table.item(1, high_column).text()) == pytest.approx( + 0.7 + ) + assert float(window.table.item(2, low_column).text()) == pytest.approx( + stats.lognorm.ppf(q_low, s=0.1, loc=0.0, scale=0.6), + rel=1e-6, + ) + assert float(window.table.item(2, high_column).text()) == pytest.approx( + stats.lognorm.ppf(q_high, s=0.1, loc=0.0, scale=0.6), + rel=1e-6, + ) -def test_distribution_window_smart_prior_preset_tightens_and_relaxes_spreads( +def test_distribution_window_updates_distribution_guides_after_param_edit( qapp, ): del qapp @@ -2184,345 +3302,287 @@ def test_distribution_window_smart_prior_preset_tightens_and_relaxes_spreads( DreamParameterEntry( structure="PbI2", motif="motif_A", - param_type="Both", - param="w0", - value=0.6, - vary=True, - distribution="lognorm", - dist_params={"loc": 0.0, "scale": 0.6, "s": 0.2}, - ), - DreamParameterEntry( - structure="", - motif="", param_type="SAXS", param="scale", value=1.0, vary=True, distribution="norm", - dist_params={"loc": 1.0, "scale": 0.3}, - ), + dist_params={"loc": 1.0, "scale": 0.2}, + ) ] ) - window.smart_prior_preset_combo.setCurrentText("Strict") - window.apply_smart_prior_preset_button.click() - strict_entries = window.current_entries() - assert strict_entries[0].dist_params["s"] == pytest.approx(0.13) - assert strict_entries[1].dist_params["scale"] == pytest.approx(0.195) + low_column = _table_column_index(window.table, "Guide Low") + high_column = _table_column_index(window.table, "Guide High") - window.smart_prior_preset_combo.setCurrentText("Lenient") - window.apply_smart_prior_preset_button.click() - lenient_entries = window.current_entries() - assert lenient_entries[0].dist_params["s"] == pytest.approx(0.195) - assert lenient_entries[1].dist_params["scale"] == pytest.approx(0.2925) + window.table.item(0, 7).setText( + json.dumps({"loc": 1.0, "scale": 0.4}, sort_keys=True) + ) + QApplication.processEvents() + + assert float(window.table.item(0, low_column).text()) == pytest.approx( + -0.2 + ) + assert float(window.table.item(0, high_column).text()) == pytest.approx( + 2.2 + ) -def test_distribution_window_smart_prior_preset_can_target_selected_structures( - qapp, -): - del qapp +def test_distribution_window_tracks_current_row_and_rescales_plot(qapp): window = DistributionSetupWindow( [ DreamParameterEntry( - structure="Small", - motif="m1", - param_type="Both", - param="w0", - value=0.3, - vary=True, - distribution="norm", - dist_params={"loc": 0.3, "scale": 0.1}, - ), - DreamParameterEntry( - structure="Small", - motif="m1", + structure="PbI2", + motif="motif_A", param_type="SAXS", - param="r_eff_w0", - value=3.0, - vary=False, + param="scale", + value=1.0, + vary=True, distribution="norm", - dist_params={"loc": 3.0, "scale": 0.2}, + dist_params={"loc": 1.0, "scale": 0.2}, ), DreamParameterEntry( - structure="Large", - motif="m2", + structure="PbI2", + motif="motif_B", param_type="Both", - param="w1", - value=0.7, + param="w0", + value=1.0, vary=True, - distribution="norm", - dist_params={"loc": 0.7, "scale": 0.1}, + distribution="lognorm", + dist_params={"loc": 0.0, "scale": 1.0, "s": 2.0}, ), ] ) - scope_index = window.smart_prior_apply_scope_combo.findData("selected") - assert scope_index >= 0 - window.smart_prior_apply_scope_combo.setCurrentIndex(scope_index) - window.table.selectRow(0) - window.smart_prior_preset_combo.setCurrentText("Strict") - window.apply_smart_prior_preset_button.click() - entries = window.current_entries() + axis = window.figure.axes[0] + assert axis.get_title() == "scale: norm" + assert axis.get_xscale() == "linear" - assert entries[0].dist_params["scale"] == pytest.approx(0.065) - assert entries[1].dist_params["scale"] == pytest.approx(0.13) - assert entries[2].dist_params["scale"] == pytest.approx(0.1) - assert window.table.cellWidget(0, 8).currentText() == "Strict" - assert window.table.cellWidget(1, 8).currentText() == "Strict" - assert window.table.cellWidget(2, 8).currentText() == "Custom / Manual" + window.table.setCurrentCell(1, 0) + qapp.processEvents() + axis = window.figure.axes[0] + assert axis.get_title() == "w0: lognorm" + assert axis.get_xscale() == "log" + assert axis.get_xlim()[0] > 0.0 + assert axis.get_xlim()[1] < 100.0 -def test_distribution_window_size_aware_prior_preset_uses_effective_radii( + +def test_distribution_window_rescales_plot_after_distribution_mode_change( qapp, ): - del qapp window = DistributionSetupWindow( [ DreamParameterEntry( - structure="Small", - motif="m1", + structure="PbI2", + motif="motif_A", param_type="Both", param="w0", - value=0.3, - vary=True, - distribution="norm", - dist_params={"loc": 0.3, "scale": 0.1}, - ), - DreamParameterEntry( - structure="Large", - motif="m2", - param_type="Both", - param="w1", - value=0.7, + value=1.0, vary=True, - distribution="norm", - dist_params={"loc": 0.7, "scale": 0.1}, - ), - DreamParameterEntry( - structure="Small", - motif="m1", - param_type="SAXS", - param="r_eff_w0", - value=3.0, - vary=False, - distribution="norm", - dist_params={"loc": 3.0, "scale": 0.2}, - ), - DreamParameterEntry( - structure="Large", - motif="m2", - param_type="SAXS", - param="r_eff_w1", - value=8.0, - vary=False, - distribution="norm", - dist_params={"loc": 8.0, "scale": 0.2}, - ), + distribution="lognorm", + dist_params={"loc": 0.0, "scale": 1.0, "s": 2.0}, + ) ] ) - window.smart_prior_preset_combo.setCurrentText( - "Strict Small / Lenient Large" - ) - window.apply_smart_prior_preset_button.click() - entries = window.current_entries() + axis = window.figure.axes[0] + assert axis.get_xscale() == "log" + combo = window.table.cellWidget(0, 6) - assert entries[0].dist_params["scale"] == pytest.approx(0.065) - assert entries[1].dist_params["scale"] == pytest.approx(0.15) - assert entries[2].dist_params["scale"] == pytest.approx(0.13) - assert entries[3].dist_params["scale"] == pytest.approx(0.3) + combo.setCurrentText("uniform") + qapp.processEvents() + axis = window.figure.axes[0] + assert axis.get_title() == "w0: uniform" + assert axis.get_xscale() == "linear" + assert axis.get_xlim()[1] - axis.get_xlim()[0] < 2.0 -def test_distribution_window_size_aware_preset_sets_row_statuses_for_all_structures( - qapp, -): - del qapp + +def test_distribution_window_rescales_plot_after_param_edit(qapp): window = DistributionSetupWindow( [ DreamParameterEntry( - structure="Small", - motif="m1", - param_type="Both", - param="w0", - value=0.3, - vary=True, - distribution="norm", - dist_params={"loc": 0.3, "scale": 0.1}, - ), - DreamParameterEntry( - structure="Large", - motif="m2", - param_type="Both", - param="w1", - value=0.7, - vary=True, - distribution="norm", - dist_params={"loc": 0.7, "scale": 0.1}, - ), - DreamParameterEntry( - structure="Small", - motif="m1", - param_type="SAXS", - param="r_eff_w0", - value=3.0, - vary=False, - distribution="norm", - dist_params={"loc": 3.0, "scale": 0.2}, - ), - DreamParameterEntry( - structure="Large", - motif="m2", - param_type="SAXS", - param="r_eff_w1", - value=8.0, - vary=False, - distribution="norm", - dist_params={"loc": 8.0, "scale": 0.2}, - ), - DreamParameterEntry( - structure="", - motif="", + structure="PbI2", + motif="motif_A", param_type="SAXS", param="scale", value=1.0, vary=True, distribution="norm", - dist_params={"loc": 1.0, "scale": 0.3}, - ), + dist_params={"loc": 1.0, "scale": 1.0}, + ) ] ) - scope_index = window.smart_prior_apply_scope_combo.findData("selected") - assert scope_index >= 0 - window.smart_prior_apply_scope_combo.setCurrentIndex(scope_index) - window.table.selectRow(0) - window.smart_prior_preset_combo.setCurrentText( - "Strict Small / Lenient Large" + axis = window.figure.axes[0] + initial_xlim = axis.get_xlim() + initial_ylim = axis.get_ylim() + initial_baseline_y = _plot_lines_by_gid(axis, "reset-baseline")[ + 0 + ].get_ydata() + + window.table.item(0, 7).setText( + json.dumps({"loc": 1.0, "scale": 1.1}, sort_keys=True) ) - window.apply_smart_prior_preset_button.click() + qapp.processEvents() - assert window.table.cellWidget(0, 8).currentText() == "Strict" - assert window.table.cellWidget(1, 8).currentText() == "Lenient" - assert window.table.cellWidget(2, 8).currentText() == "Strict" - assert window.table.cellWidget(3, 8).currentText() == "Lenient" - assert window.table.cellWidget(4, 8).currentText() == "Proportional" + axis = window.figure.axes[0] + updated_xlim = axis.get_xlim() + updated_ylim = axis.get_ylim() + baseline_line = _plot_lines_by_gid(axis, "reset-baseline")[0] + current_line = _plot_lines_by_gid(axis, "current-distribution")[0] + assert axis.get_title() == "scale: norm" + assert axis.get_xscale() == "linear" + assert updated_xlim == pytest.approx(initial_xlim) + assert updated_ylim == pytest.approx(initial_ylim) + assert np.allclose(baseline_line.get_ydata(), initial_baseline_y) + assert not np.allclose( + current_line.get_ydata(), + baseline_line.get_ydata(), + ) + window.rescale_axes_button.click() + qapp.processEvents() -def test_distribution_window_row_status_can_override_single_structure_preset( + axis = window.figure.axes[0] + rescaled_xlim = axis.get_xlim() + rescaled_ylim = axis.get_ylim() + assert (rescaled_xlim[1] - rescaled_xlim[0]) > ( + initial_xlim[1] - initial_xlim[0] + ) + assert rescaled_ylim[1] < initial_ylim[1] + + +def test_distribution_window_auto_rescales_when_distribution_exits_window( qapp, ): - del qapp window = DistributionSetupWindow( [ DreamParameterEntry( - structure="Small", - motif="m1", - param_type="Both", - param="w0", - value=0.3, - vary=True, - distribution="norm", - dist_params={"loc": 0.3, "scale": 0.1}, - ), - DreamParameterEntry( - structure="Large", - motif="m2", - param_type="Both", - param="w1", - value=0.7, - vary=True, - distribution="norm", - dist_params={"loc": 0.7, "scale": 0.1}, - ), - DreamParameterEntry( - structure="Small", - motif="m1", + structure="PbI2", + motif="motif_A", param_type="SAXS", - param="r_eff_w0", - value=3.0, - vary=False, + param="scale", + value=1.0, + vary=True, distribution="norm", - dist_params={"loc": 3.0, "scale": 0.2}, - ), + dist_params={"loc": 1.0, "scale": 1.0}, + ) + ] + ) + + axis = window.figure.axes[0] + initial_xlim = axis.get_xlim() + initial_ylim = axis.get_ylim() + + window.table.item(0, 7).setText( + json.dumps({"loc": 1.0, "scale": 1.35}, sort_keys=True) + ) + qapp.processEvents() + + axis = window.figure.axes[0] + assert (axis.get_xlim()[1] - axis.get_xlim()[0]) > ( + initial_xlim[1] - initial_xlim[0] + ) + assert axis.get_ylim()[1] < initial_ylim[1] + + +def test_distribution_window_row_reset_restores_loaded_prior_settings(qapp): + window = DistributionSetupWindow( + [ DreamParameterEntry( - structure="Large", - motif="m2", + structure="PbI2", + motif="motif_A", param_type="SAXS", - param="r_eff_w1", - value=8.0, + param="scale", + value=1.0, vary=False, distribution="norm", - dist_params={"loc": 8.0, "scale": 0.2}, - ), + dist_params={"loc": 1.0, "scale": 0.2}, + smart_preset_status="strict", + ) ] ) - window.smart_prior_preset_combo.setCurrentText( - "Strict Small / Lenient Large" + reset_col = _table_column_index(window.table, "Reset") + low_column = _table_column_index(window.table, "Guide Low") + high_column = _table_column_index(window.table, "Guide High") + distribution_combo = window.table.cellWidget(0, 6) + vary_box = window.table.cellWidget(0, 5) + reset_button = window.table.cellWidget(0, reset_col) + + window.table.item(0, 4).setText("2.5") + vary_box.setChecked(True) + distribution_combo.setCurrentText("uniform") + window.table.item(0, 7).setText( + json.dumps({"loc": 0.0, "scale": 4.0}, sort_keys=True) ) - window.apply_smart_prior_preset_button.click() + qapp.processEvents() - small_status_combo = window.table.cellWidget(0, 8) - very_lenient_index = small_status_combo.findData("very_lenient") - assert very_lenient_index >= 0 - small_status_combo.setCurrentIndex(very_lenient_index) - entries = window.current_entries() + assert reset_button is not None + reset_button.click() + qapp.processEvents() - assert entries[0].dist_params["scale"] == pytest.approx(0.14625) - assert entries[2].dist_params["scale"] == pytest.approx(0.2925) - assert entries[1].dist_params["scale"] == pytest.approx(0.15) - assert entries[3].dist_params["scale"] == pytest.approx(0.3) - assert window.table.cellWidget(0, 8).currentText() == "Very Lenient" - assert window.table.cellWidget(2, 8).currentText() == "Very Lenient" - assert window.table.cellWidget(1, 8).currentText() == "Lenient" + entry = window.current_entries()[0] + assert entry.value == pytest.approx(1.0) + assert entry.vary is False + assert entry.distribution == "norm" + assert entry.dist_params == pytest.approx({"loc": 1.0, "scale": 0.2}) + assert entry.smart_preset_status == "strict" + assert float(window.table.item(0, low_column).text()) == pytest.approx(0.4) + assert float(window.table.item(0, high_column).text()) == pytest.approx( + 1.6 + ) + assert window.figure.axes[0].get_title() == "scale: norm" -def test_distribution_window_warns_when_effective_radius_is_set_to_vary( +def test_distribution_window_row_reset_preserves_original_baseline_after_preset( qapp, - monkeypatch, ): - del qapp window = DistributionSetupWindow( [ DreamParameterEntry( structure="PbI2", motif="motif_A", param_type="SAXS", - param="r_eff_w0", - value=4.0, - vary=False, + param="scale", + value=1.0, + vary=True, distribution="norm", - dist_params={"loc": 4.0, "scale": 0.2}, + dist_params={"loc": 1.0, "scale": 0.2}, ) ] ) - warnings: list[tuple[str, str]] = [] - monkeypatch.setattr( - "saxshell.saxs.ui.distribution_window.QMessageBox.warning", - lambda _parent, title, message: warnings.append((title, message)), + reset_col = _table_column_index(window.table, "Reset") + window.smart_prior_preset_combo.setCurrentIndex( + window.smart_prior_preset_combo.findData("lenient") ) + window.apply_smart_prior_preset_button.click() + qapp.processEvents() - vary_box = window.table.cellWidget(0, 5) - assert isinstance(vary_box, QCheckBox) - vary_box.setChecked(True) + adjusted_entry = window.current_entries()[0] + assert adjusted_entry.dist_params["scale"] == pytest.approx(0.3) - assert warnings - assert warnings[-1][0] == "Effective radius variation warning" - assert "r_eff_w0" in warnings[-1][1] - assert "not recommended to vary effective-radius parameters" in ( - warnings[-1][1] - ) + reset_button = window.table.cellWidget(0, reset_col) + assert reset_button is not None + reset_button.click() + qapp.processEvents() + reset_entry = window.current_entries()[0] + assert reset_entry.dist_params["scale"] == pytest.approx(0.2) + assert reset_entry.smart_preset_status == "custom" -def test_distribution_window_previews_all_weight_priors_in_shared_plot(qapp): + +def test_distribution_window_can_toggle_all_vary_flags(qapp): del qapp window = DistributionSetupWindow( [ DreamParameterEntry( structure="PbI2", motif="motif_A", - param_type="Both", + param_type="SAXS", param="w0", value=0.6, vary=True, @@ -2532,1500 +3592,2540 @@ def test_distribution_window_previews_all_weight_priors_in_shared_plot(qapp): DreamParameterEntry( structure="PbI2", motif="motif_B", - param_type="Both", - param="w1", - value=0.4, - vary=True, - distribution="norm", - dist_params={"loc": 0.4, "scale": 0.05}, - ), - DreamParameterEntry( - structure="", - motif="", param_type="SAXS", param="scale", value=1.0, - vary=True, + vary=False, distribution="norm", dist_params={"loc": 1.0, "scale": 0.2}, ), ] ) - window.preview_weight_priors_button.click() + window.set_all_vary_off_button.click() + assert all(not entry.vary for entry in window.current_entries()) - assert window._weight_preview_window is not None - assert window._weight_preview_window.isVisible() - axis = window._weight_preview_window.figure.axes[0] - assert axis.get_title() == "Weight prior distributions" - assert axis.get_xlabel() == "Value" - assert axis.get_ylabel() == "Density" - plotted_labels = [line.get_label() for line in axis.get_lines()] - assert plotted_labels == ["w0 (PbI2)", "w1 (PbI2)"] + window.set_all_vary_on_button.click() + assert all(entry.vary for entry in window.current_entries()) -def test_template_dropdowns_use_display_names_and_tooltips(qapp): +def test_distribution_window_recommended_vary_selection_keeps_radius_params_off( + qapp, +): del qapp - tab = ProjectSetupTab() - basic_spec = load_template_spec("template_likelihood_monosq") - decoupled_spec = load_template_spec( - "template_pd_likelihood_monosq_decoupled" + window = DistributionSetupWindow( + [ + DreamParameterEntry( + structure="PbI2", + motif="motif_A", + param_type="Both", + param="w0", + value=0.6, + vary=False, + distribution="lognorm", + dist_params={"loc": 0.0, "scale": 0.6, "s": 0.1}, + ), + DreamParameterEntry( + structure="", + motif="", + param_type="SAXS", + param="scale", + value=1.0, + vary=False, + distribution="norm", + dist_params={"loc": 1.0, "scale": 0.2}, + ), + DreamParameterEntry( + structure="PbI2", + motif="motif_A", + param_type="SAXS", + param="r_eff_w0", + value=4.0, + vary=True, + distribution="norm", + dist_params={"loc": 4.0, "scale": 0.2}, + ), + ] ) - tab.set_available_templates([basic_spec, decoupled_spec], basic_spec.name) + window.select_recommended_vary_button.click() + entries = window.current_entries() - assert tab.template_combo.itemText(0).startswith("MonoSQ Basic") - assert ( - tab.template_combo.itemData(0, Qt.ItemDataRole.ToolTipRole) - == basic_spec.description - ) - assert tab.selected_template_name() == basic_spec.name + assert entries[0].vary is True + assert entries[1].vary is True + assert entries[2].vary is False -def test_template_dropdowns_hide_deprecated_by_default_but_load_selected_deprecated( +def test_distribution_window_smart_prior_preset_tightens_and_relaxes_spreads( qapp, - tmp_path, ): del qapp - project_dir, _paths = _build_minimal_saxs_project(tmp_path) - window = SAXSMainWindow(initial_project_dir=project_dir) - - project_template_names = { - str(window.project_setup_tab.template_combo.itemData(index) or "") - for index in range(window.project_setup_tab.template_combo.count()) + window = DistributionSetupWindow( + [ + DreamParameterEntry( + structure="PbI2", + motif="motif_A", + param_type="Both", + param="w0", + value=0.6, + vary=True, + distribution="lognorm", + dist_params={"loc": 0.0, "scale": 0.6, "s": 0.2}, + ), + DreamParameterEntry( + structure="", + motif="", + param_type="SAXS", + param="scale", + value=1.0, + vary=True, + distribution="norm", + dist_params={"loc": 1.0, "scale": 0.3}, + ), + ] + ) + + window.smart_prior_preset_combo.setCurrentText("Strict") + window.apply_smart_prior_preset_button.click() + strict_entries = window.current_entries() + assert strict_entries[0].dist_params["s"] == pytest.approx(0.13) + assert strict_entries[1].dist_params["scale"] == pytest.approx(0.195) + + window.smart_prior_preset_combo.setCurrentText("Lenient") + window.apply_smart_prior_preset_button.click() + lenient_entries = window.current_entries() + assert lenient_entries[0].dist_params["s"] == pytest.approx(0.195) + assert lenient_entries[1].dist_params["scale"] == pytest.approx(0.2925) + + +def test_distribution_window_smart_prior_preset_can_target_selected_structures( + qapp, +): + del qapp + window = DistributionSetupWindow( + [ + DreamParameterEntry( + structure="Small", + motif="m1", + param_type="Both", + param="w0", + value=0.3, + vary=True, + distribution="norm", + dist_params={"loc": 0.3, "scale": 0.1}, + ), + DreamParameterEntry( + structure="Small", + motif="m1", + param_type="SAXS", + param="r_eff_w0", + value=3.0, + vary=False, + distribution="norm", + dist_params={"loc": 3.0, "scale": 0.2}, + ), + DreamParameterEntry( + structure="Large", + motif="m2", + param_type="Both", + param="w1", + value=0.7, + vary=True, + distribution="norm", + dist_params={"loc": 0.7, "scale": 0.1}, + ), + ] + ) + + scope_index = window.smart_prior_apply_scope_combo.findData("selected") + assert scope_index >= 0 + window.smart_prior_apply_scope_combo.setCurrentIndex(scope_index) + window.table.selectRow(0) + window.smart_prior_preset_combo.setCurrentText("Strict") + window.apply_smart_prior_preset_button.click() + entries = window.current_entries() + + assert entries[0].dist_params["scale"] == pytest.approx(0.065) + assert entries[1].dist_params["scale"] == pytest.approx(0.13) + assert entries[2].dist_params["scale"] == pytest.approx(0.1) + assert window.table.cellWidget(0, 8).currentText() == "Strict" + assert window.table.cellWidget(1, 8).currentText() == "Strict" + assert window.table.cellWidget(2, 8).currentText() == "Custom / Manual" + + +def test_distribution_window_size_aware_prior_preset_uses_effective_radii( + qapp, +): + del qapp + window = DistributionSetupWindow( + [ + DreamParameterEntry( + structure="Small", + motif="m1", + param_type="Both", + param="w0", + value=0.3, + vary=True, + distribution="norm", + dist_params={"loc": 0.3, "scale": 0.1}, + ), + DreamParameterEntry( + structure="Large", + motif="m2", + param_type="Both", + param="w1", + value=0.7, + vary=True, + distribution="norm", + dist_params={"loc": 0.7, "scale": 0.1}, + ), + DreamParameterEntry( + structure="Small", + motif="m1", + param_type="SAXS", + param="r_eff_w0", + value=3.0, + vary=False, + distribution="norm", + dist_params={"loc": 3.0, "scale": 0.2}, + ), + DreamParameterEntry( + structure="Large", + motif="m2", + param_type="SAXS", + param="r_eff_w1", + value=8.0, + vary=False, + distribution="norm", + dist_params={"loc": 8.0, "scale": 0.2}, + ), + ] + ) + + window.smart_prior_preset_combo.setCurrentText( + "Strict Small / Lenient Large" + ) + window.apply_smart_prior_preset_button.click() + entries = window.current_entries() + + assert entries[0].dist_params["scale"] == pytest.approx(0.065) + assert entries[1].dist_params["scale"] == pytest.approx(0.15) + assert entries[2].dist_params["scale"] == pytest.approx(0.13) + assert entries[3].dist_params["scale"] == pytest.approx(0.3) + + +def test_distribution_window_size_aware_preset_sets_row_statuses_for_all_structures( + qapp, +): + del qapp + window = DistributionSetupWindow( + [ + DreamParameterEntry( + structure="Small", + motif="m1", + param_type="Both", + param="w0", + value=0.3, + vary=True, + distribution="norm", + dist_params={"loc": 0.3, "scale": 0.1}, + ), + DreamParameterEntry( + structure="Large", + motif="m2", + param_type="Both", + param="w1", + value=0.7, + vary=True, + distribution="norm", + dist_params={"loc": 0.7, "scale": 0.1}, + ), + DreamParameterEntry( + structure="Small", + motif="m1", + param_type="SAXS", + param="r_eff_w0", + value=3.0, + vary=False, + distribution="norm", + dist_params={"loc": 3.0, "scale": 0.2}, + ), + DreamParameterEntry( + structure="Large", + motif="m2", + param_type="SAXS", + param="r_eff_w1", + value=8.0, + vary=False, + distribution="norm", + dist_params={"loc": 8.0, "scale": 0.2}, + ), + DreamParameterEntry( + structure="", + motif="", + param_type="SAXS", + param="scale", + value=1.0, + vary=True, + distribution="norm", + dist_params={"loc": 1.0, "scale": 0.3}, + ), + ] + ) + + scope_index = window.smart_prior_apply_scope_combo.findData("selected") + assert scope_index >= 0 + window.smart_prior_apply_scope_combo.setCurrentIndex(scope_index) + window.table.selectRow(0) + window.smart_prior_preset_combo.setCurrentText( + "Strict Small / Lenient Large" + ) + window.apply_smart_prior_preset_button.click() + + assert window.table.cellWidget(0, 8).currentText() == "Strict" + assert window.table.cellWidget(1, 8).currentText() == "Lenient" + assert window.table.cellWidget(2, 8).currentText() == "Strict" + assert window.table.cellWidget(3, 8).currentText() == "Lenient" + assert window.table.cellWidget(4, 8).currentText() == "Proportional" + + +def test_distribution_window_row_status_can_override_single_structure_preset( + qapp, +): + del qapp + window = DistributionSetupWindow( + [ + DreamParameterEntry( + structure="Small", + motif="m1", + param_type="Both", + param="w0", + value=0.3, + vary=True, + distribution="norm", + dist_params={"loc": 0.3, "scale": 0.1}, + ), + DreamParameterEntry( + structure="Large", + motif="m2", + param_type="Both", + param="w1", + value=0.7, + vary=True, + distribution="norm", + dist_params={"loc": 0.7, "scale": 0.1}, + ), + DreamParameterEntry( + structure="Small", + motif="m1", + param_type="SAXS", + param="r_eff_w0", + value=3.0, + vary=False, + distribution="norm", + dist_params={"loc": 3.0, "scale": 0.2}, + ), + DreamParameterEntry( + structure="Large", + motif="m2", + param_type="SAXS", + param="r_eff_w1", + value=8.0, + vary=False, + distribution="norm", + dist_params={"loc": 8.0, "scale": 0.2}, + ), + ] + ) + + window.smart_prior_preset_combo.setCurrentText( + "Strict Small / Lenient Large" + ) + window.apply_smart_prior_preset_button.click() + + small_status_combo = window.table.cellWidget(0, 8) + very_lenient_index = small_status_combo.findData("very_lenient") + assert very_lenient_index >= 0 + small_status_combo.setCurrentIndex(very_lenient_index) + entries = window.current_entries() + + assert entries[0].dist_params["scale"] == pytest.approx(0.14625) + assert entries[2].dist_params["scale"] == pytest.approx(0.2925) + assert entries[1].dist_params["scale"] == pytest.approx(0.15) + assert entries[3].dist_params["scale"] == pytest.approx(0.3) + assert window.table.cellWidget(0, 8).currentText() == "Very Lenient" + assert window.table.cellWidget(2, 8).currentText() == "Very Lenient" + assert window.table.cellWidget(1, 8).currentText() == "Lenient" + + +def test_distribution_window_warns_when_effective_radius_is_set_to_vary( + qapp, + monkeypatch, +): + del qapp + window = DistributionSetupWindow( + [ + DreamParameterEntry( + structure="PbI2", + motif="motif_A", + param_type="SAXS", + param="r_eff_w0", + value=4.0, + vary=False, + distribution="norm", + dist_params={"loc": 4.0, "scale": 0.2}, + ) + ] + ) + + warnings: list[tuple[str, str]] = [] + monkeypatch.setattr( + "saxshell.saxs.ui.distribution_window.QMessageBox.warning", + lambda _parent, title, message: warnings.append((title, message)), + ) + + vary_box = window.table.cellWidget(0, 5) + assert isinstance(vary_box, QCheckBox) + vary_box.setChecked(True) + + assert warnings + assert warnings[-1][0] == "Effective radius variation warning" + assert "r_eff_w0" in warnings[-1][1] + assert "not recommended to vary effective-radius parameters" in ( + warnings[-1][1] + ) + + +def test_distribution_window_preview_defaults_to_weight_parameters(qapp): + del qapp + window = DistributionSetupWindow( + [ + DreamParameterEntry( + structure="PbI2", + motif="motif_A", + param_type="Both", + param="w0", + value=0.6, + vary=True, + distribution="lognorm", + dist_params={"loc": 0.0, "scale": 0.6, "s": 0.1}, + ), + DreamParameterEntry( + structure="PbI2", + motif="motif_B", + param_type="Both", + param="w1", + value=0.4, + vary=True, + distribution="norm", + dist_params={"loc": 0.4, "scale": 0.05}, + ), + DreamParameterEntry( + structure="", + motif="", + param_type="SAXS", + param="scale", + value=1.0, + vary=True, + distribution="norm", + dist_params={"loc": 1.0, "scale": 0.2}, + ), + ] + ) + + assert isinstance(window.toolbar, NavigationToolbar2QT) + window.preview_weight_priors_button.click() + + assert window._weight_preview_window is not None + assert window._weight_preview_window.isVisible() + assert isinstance( + window._weight_preview_window.toolbar, NavigationToolbar2QT + ) + checkbox_states = { + entry.param: checkbox.isChecked() + for entry, checkbox in window._weight_preview_window._parameter_checkboxes + } + assert checkbox_states == {"w0": True, "w1": True, "scale": False} + axis = window._weight_preview_window.figure.axes[0] + assert axis.get_title() == "Prior distributions" + assert axis.get_xlabel() == "Value" + assert axis.get_ylabel() == "Density" + plotted_labels = [line.get_label() for line in axis.get_lines()] + assert plotted_labels == ["w0 (PbI2)", "w1 (PbI2)"] + + +def test_distribution_window_preview_can_toggle_non_weight_parameters(qapp): + window = DistributionSetupWindow( + [ + DreamParameterEntry( + structure="PbI2", + motif="motif_A", + param_type="Both", + param="w0", + value=0.6, + vary=True, + distribution="lognorm", + dist_params={"loc": 0.0, "scale": 0.6, "s": 0.1}, + ), + DreamParameterEntry( + structure="PbI2", + motif="motif_B", + param_type="Both", + param="w1", + value=0.4, + vary=True, + distribution="norm", + dist_params={"loc": 0.4, "scale": 0.05}, + ), + DreamParameterEntry( + structure="", + motif="", + param_type="SAXS", + param="scale", + value=1.0, + vary=True, + distribution="norm", + dist_params={"loc": 1.0, "scale": 0.2}, + ), + ] + ) + + window.preview_weight_priors_button.click() + + assert window._weight_preview_window is not None + scale_checkbox = next( + checkbox + for entry, checkbox in window._weight_preview_window._parameter_checkboxes + if entry.param == "scale" + ) + scale_checkbox.setChecked(True) + qapp.processEvents() + + axis = window._weight_preview_window.figure.axes[0] + plotted_labels = [line.get_label() for line in axis.get_lines()] + assert plotted_labels == ["w0 (PbI2)", "w1 (PbI2)", "scale"] + + +def test_template_dropdowns_use_display_names_and_tooltips(qapp): + del qapp + tab = ProjectSetupTab() + basic_spec = load_template_spec("template_likelihood_monosq") + decoupled_spec = load_template_spec( + "template_pd_likelihood_monosq_decoupled" + ) + + tab.set_available_templates([basic_spec, decoupled_spec], basic_spec.name) + + assert tab.template_combo.itemText(0).startswith("MonoSQ Basic") + assert ( + tab.template_combo.itemData(0, Qt.ItemDataRole.ToolTipRole) + == basic_spec.description + ) + assert tab.selected_template_name() == basic_spec.name + + +def test_template_dropdowns_hide_deprecated_by_default_but_load_selected_deprecated( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + + project_template_names = { + str(window.project_setup_tab.template_combo.itemData(index) or "") + for index in range(window.project_setup_tab.template_combo.count()) + } + prefit_template_names = { + str(window.prefit_tab.template_combo.itemData(index) or "") + for index in range(window.prefit_tab.template_combo.count()) + } + + assert not window.project_setup_tab.show_deprecated_templates() + assert not window.prefit_tab.show_deprecated_templates() + assert window.project_setup_tab.selected_template_name() == ( + "template_pd_likelihood_monosq_decoupled" + ) + assert window.prefit_tab.selected_template_name() == ( + "template_pd_likelihood_monosq_decoupled" + ) + assert "template_pd_likelihood_monosq_decoupled" in project_template_names + assert "template_pd_likelihood_monosq_decoupled" in prefit_template_names + assert "template_likelihood_monosq" not in project_template_names + assert "template_likelihood_monosq" not in prefit_template_names + assert "template_pydream_poly_lma_hs_legacy" not in project_template_names + assert "template_pydream_poly_lma_hs_legacy" not in prefit_template_names + window.close() + + +def test_template_dropdowns_can_show_deprecated_and_stay_synced( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + + window.project_setup_tab.show_deprecated_templates_checkbox.setChecked( + True + ) + + project_template_names = { + str(window.project_setup_tab.template_combo.itemData(index) or "") + for index in range(window.project_setup_tab.template_combo.count()) + } + prefit_template_names = { + str(window.prefit_tab.template_combo.itemData(index) or "") + for index in range(window.prefit_tab.template_combo.count()) + } + + assert window.project_setup_tab.show_deprecated_templates() + assert window.prefit_tab.show_deprecated_templates() + assert "template_likelihood_monosq" in project_template_names + assert "template_likelihood_monosq" in prefit_template_names + assert "template_pydream_poly_lma_hs_legacy" in project_template_names + assert "template_pydream_poly_lma_hs_legacy" in prefit_template_names + + window.prefit_tab.show_deprecated_templates_checkbox.setChecked(False) + + project_template_names = { + str(window.project_setup_tab.template_combo.itemData(index) or "") + for index in range(window.project_setup_tab.template_combo.count()) } prefit_template_names = { str(window.prefit_tab.template_combo.itemData(index) or "") for index in range(window.prefit_tab.template_combo.count()) - } + } + + assert not window.project_setup_tab.show_deprecated_templates() + assert not window.prefit_tab.show_deprecated_templates() + assert window.project_setup_tab.selected_template_name() == ( + "template_pd_likelihood_monosq_decoupled" + ) + assert window.prefit_tab.selected_template_name() == ( + "template_pd_likelihood_monosq_decoupled" + ) + assert "template_likelihood_monosq" not in project_template_names + assert "template_likelihood_monosq" not in prefit_template_names + assert "template_pd_likelihood_monosq_decoupled" in project_template_names + assert "template_pd_likelihood_monosq_decoupled" in prefit_template_names + window.close() + + +def test_install_model_dialog_collects_model_inputs(qapp, tmp_path): + del qapp + candidate_template = tmp_path / "candidate_install_model.py" + candidate_template.write_text( + "import numpy as np\n" + "# model_lmfit: lmfit_model_profile\n" + "# model_pydream: log_likelihood_candidate\n" + "# inputs_lmfit: q, solvent_data, model_data, params\n" + "# inputs_pydream: q_values, experimental_intensities, " + "solvent_intensities, theoretical_intensities, params\n" + "# param_columns: Structure, Motif, Param, Value, Vary, Min, Max\n" + "# param: scale,1.0,True,0.0,10.0\n" + "def lmfit_model_profile(q, solvent_data, model_data, **params):\n" + " del q, solvent_data\n" + " return params['scale'] * np.asarray(model_data[0], dtype=float)\n" + "def log_likelihood_candidate(params):\n" + " del params\n" + " return -1.0\n", + encoding="utf-8", + ) + dialog = InstallModelDialog() + dialog.model_name_edit.setText("Dialog Candidate Model") + dialog.template_path_edit.setText(str(candidate_template)) + dialog.description_edit.setPlainText("Dialog-installed candidate.") + + request = dialog.selected_request() + + assert request == TemplateInstallRequest( + model_name="Dialog Candidate Model", + template_path=candidate_template.resolve(), + model_description="Dialog-installed candidate.", + ) + + +def test_install_model_template_installs_and_refreshes_template_lists( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + install_dir = tmp_path / "installed_templates" + request = TemplateInstallRequest( + model_name="Installed Candidate Model", + template_path=( + Path( + "tests/template_candidates/valid_installable_model.py" + ).resolve() + ), + model_description="Installed from the Project Setup dialog.", + ) + info_messages: list[tuple[str, str]] = [] + + monkeypatch.setattr( + window, + "_prompt_template_install_request", + lambda: request, + ) + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.install_template_candidate", + lambda template_path, **kwargs: install_template_candidate( + template_path, + destination_dir=install_dir, + **kwargs, + ), + ) + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.list_template_specs", + lambda **kwargs: list_template_specs(install_dir, **kwargs), + ) + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.QMessageBox.information", + lambda _parent, title, message: info_messages.append((title, message)), + ) + + window.project_setup_tab.install_model_button.click() + + assert (install_dir / "template_installed_candidate_model.py").is_file() + assert (install_dir / "template_installed_candidate_model.json").is_file() + assert any( + window.project_setup_tab.template_combo.itemText(index) + == "Installed Candidate Model" + for index in range(window.project_setup_tab.template_combo.count()) + ) + assert any( + window.prefit_tab.template_combo.itemText(index) + == "Installed Candidate Model" + for index in range(window.prefit_tab.template_combo.count()) + ) + assert info_messages + assert info_messages[0][0] == "Model installed" + assert "Installed Candidate Model" in info_messages[0][1] + + +def test_install_model_template_surfaces_validation_failures( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + install_dir = tmp_path / "installed_templates" + request = TemplateInstallRequest( + model_name="Broken Dream Callable Model", + template_path=( + Path( + "tests/template_candidates/fail_missing_dream_callable_model.py" + ).resolve() + ), + model_description="Expected to fail in the validation step.", + ) + errors: list[tuple[str, str]] = [] + + monkeypatch.setattr( + window, + "_prompt_template_install_request", + lambda: request, + ) + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.install_template_candidate", + lambda template_path, **kwargs: install_template_candidate( + template_path, + destination_dir=install_dir, + **kwargs, + ), + ) + monkeypatch.setattr( + window, + "_show_error", + lambda title, message: errors.append((title, message)), + ) + + window.project_setup_tab.install_model_button.click() + + assert errors + assert errors[0][0] == "Install model failed" + assert "Missing callable log_likelihood_candidate" in errors[0][1] + + +def test_project_setup_empty_preview_message_is_wrapped(qapp): + del qapp + tab = ProjectSetupTab() + tab.draw_component_plot(None) + + preview_axis = tab.component_figure.axes[0] + text_labels = [text.get_text() for text in preview_axis.texts] + + assert any( + "Select experimental data and build SAXS" in label + and "averaged cluster profiles." in label + and "\n" in label + for label in text_labels + ) + + +def test_run_dream_requires_parameter_map_saved_in_session( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + prefit = SAXSPrefitWorkflow(project_dir) + prefit.save_fit(prefit.parameter_entries) + workflow = SAXSDreamWorkflow(project_dir) + workflow.create_default_parameter_map() + window = SAXSMainWindow(initial_project_dir=project_dir) + + blinked = {"value": False} + monkeypatch.setattr( + window.dream_tab, + "blink_edit_priors_button", + lambda: blinked.update({"value": True}), + ) + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.SAXSDreamWorkflow.run_bundle", + lambda self, bundle: (_ for _ in ()).throw( + AssertionError("run_bundle should not be called") + ), + ) + + window.run_dream_bundle() + + assert blinked["value"] + assert "Review the priors in Edit Priors" in ( + window.dream_tab.output_box.toPlainText() + ) + + +def test_run_dream_requires_written_runtime_bundle( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + prefit = SAXSPrefitWorkflow(project_dir) + prefit.save_fit(prefit.parameter_entries) + workflow = SAXSDreamWorkflow(project_dir) + entries = workflow.create_default_parameter_map(persist=True) + window = SAXSMainWindow(initial_project_dir=project_dir) + window._save_distribution_entries(entries) + + blinked = {"value": False} + errors: list[tuple[str, str]] = [] + monkeypatch.setattr( + window.dream_tab, + "blink_write_bundle_button", + lambda: blinked.update({"value": True}), + ) + monkeypatch.setattr( + window, + "_show_error", + lambda title, message: errors.append((title, message)), + ) + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.SAXSDreamWorkflow.run_bundle", + lambda self, bundle: (_ for _ in ()).throw( + AssertionError("run_bundle should not be called") + ), + ) + + window.run_dream_bundle() + + assert blinked["value"] + assert errors == [ + ( + "Runtime Bundle not generated", + "Runtime Bundle not generated. Click Write Runtime Bundle before running DREAM.", + ) + ] + assert "Runtime Bundle not generated" in ( + window.dream_tab.output_box.toPlainText() + ) + + +def test_preview_runtime_bundle_opens_latest_written_script( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + prefit = SAXSPrefitWorkflow(project_dir) + prefit.save_fit(prefit.parameter_entries) + workflow = SAXSDreamWorkflow(project_dir) + entries = workflow.create_default_parameter_map(persist=True) + bundle = workflow.create_runtime_bundle(entries=entries) + window = SAXSMainWindow(initial_project_dir=project_dir) + window._last_written_dream_bundle = bundle + + opener = RuntimeBundleOpener( + label="Fake Editor", + stored_value="/Applications/FakeEditor.app", + launch_target="/Applications/FakeEditor.app", + launch_mode="mac_app", + ) + launched: dict[str, object] = {} + + monkeypatch.setattr( + window, + "_available_runtime_bundle_openers", + lambda: [opener], + ) + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.QInputDialog.getItem", + lambda *args, **kwargs: ("Fake Editor", True), + ) + monkeypatch.setattr( + window, + "_launch_runtime_bundle_with_opener", + lambda script_path, selected_opener: launched.update( + { + "path": str(script_path), + "label": selected_opener.label, + "stored_value": selected_opener.stored_value, + } + ), + ) + + window.preview_dream_runtime_bundle() + + assert launched["path"] == str(bundle.runtime_script_path) + assert launched["label"] == "Fake Editor" + assert "Opened DREAM runtime bundle preview" in ( + window.dream_tab.output_box.toPlainText() + ) + reloaded_settings = SAXSProjectManager().load_project(project_dir) + assert reloaded_settings.runtime_bundle_opener == opener.stored_value - assert not window.project_setup_tab.show_deprecated_templates() - assert not window.prefit_tab.show_deprecated_templates() - assert window.project_setup_tab.selected_template_name() == ( - "template_pd_likelihood_monosq_decoupled" + +def test_preview_runtime_bundle_reuses_saved_project_opener( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + prefit = SAXSPrefitWorkflow(project_dir) + prefit.save_fit(prefit.parameter_entries) + workflow = SAXSDreamWorkflow(project_dir) + entries = workflow.create_default_parameter_map(persist=True) + bundle = workflow.create_runtime_bundle(entries=entries) + + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.runtime_bundle_opener = "/Applications/SavedEditor.app" + manager.save_project(settings) + + window = SAXSMainWindow(initial_project_dir=project_dir) + window._last_written_dream_bundle = bundle + launched: dict[str, object] = {} + + monkeypatch.setattr( + window, + "_available_runtime_bundle_openers", + lambda: [ + RuntimeBundleOpener( + label="Saved Editor", + stored_value="/Applications/SavedEditor.app", + launch_target="/Applications/SavedEditor.app", + launch_mode="mac_app", + ) + ], ) - assert window.prefit_tab.selected_template_name() == ( - "template_pd_likelihood_monosq_decoupled" + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.QInputDialog.getItem", + lambda *args, **kwargs: pytest.fail( + "The opener chooser should not appear for a saved project opener." + ), + ) + monkeypatch.setattr( + window, + "_launch_runtime_bundle_with_opener", + lambda script_path, selected_opener: launched.update( + { + "path": str(script_path), + "label": selected_opener.label, + "stored_value": selected_opener.stored_value, + } + ), ) - assert "template_pd_likelihood_monosq_decoupled" in project_template_names - assert "template_pd_likelihood_monosq_decoupled" in prefit_template_names - assert "template_likelihood_monosq" not in project_template_names - assert "template_likelihood_monosq" not in prefit_template_names - assert "template_pydream_poly_lma_hs_legacy" not in project_template_names - assert "template_pydream_poly_lma_hs_legacy" not in prefit_template_names - window.close() + window.preview_dream_runtime_bundle() -def test_template_dropdowns_can_show_deprecated_and_stay_synced( - qapp, - tmp_path, + assert launched["path"] == str(bundle.runtime_script_path) + assert launched["stored_value"] == "/Applications/SavedEditor.app" + + +def test_run_dream_shows_progress_and_popup_can_be_closed( + qapp, tmp_path, monkeypatch ): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) + prefit = SAXSPrefitWorkflow(project_dir) + prefit.save_fit(prefit.parameter_entries) + workflow = SAXSDreamWorkflow(project_dir) + entries = workflow.create_default_parameter_map(persist=True) window = SAXSMainWindow(initial_project_dir=project_dir) + window._save_distribution_entries(entries) + window.write_dream_bundle() - window.project_setup_tab.show_deprecated_templates_checkbox.setChecked( - True + def _fake_run_bundle( + self, + bundle, + *, + output_callback=None, + output_interval_seconds=None, + ): + del output_interval_seconds + if output_callback is not None: + output_callback("DREAM sampler: initialization complete") + time.sleep(0.15) + if output_callback is not None: + output_callback("DREAM sampler: collecting posterior samples") + metadata = json.loads(bundle.metadata_path.read_text(encoding="utf-8")) + active_count = len(metadata["active_parameter_entries"]) + active_values = np.asarray( + [ + float(entry["value"]) + for entry in metadata["active_parameter_entries"] + ], + dtype=float, + ) + sampled_params = [] + log_ps = [] + for chain_index in range(2): + chain_samples = [] + chain_logps = [] + for step_index in range(4): + adjustment = (chain_index + 1) * (step_index + 1) * 0.01 + chain_samples.append(active_values[:active_count] + adjustment) + chain_logps.append(-5.0 + chain_index + step_index * 0.5) + sampled_params.append(chain_samples) + log_ps.append(chain_logps) + np.save( + bundle.run_dir / "dream_sampled_params.npy", + np.asarray(sampled_params, dtype=float), + ) + np.save( + bundle.run_dir / "dream_log_ps.npy", + np.asarray(log_ps, dtype=float)[..., np.newaxis], + ) + return { + "sampled_params_path": str( + bundle.run_dir / "dream_sampled_params.npy" + ), + "log_ps_path": str(bundle.run_dir / "dream_log_ps.npy"), + } + + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.SAXSDreamWorkflow.run_bundle", + _fake_run_bundle, ) - project_template_names = { - str(window.project_setup_tab.template_combo.itemData(index) or "") - for index in range(window.project_setup_tab.template_combo.count()) - } - prefit_template_names = { - str(window.prefit_tab.template_combo.itemData(index) or "") - for index in range(window.prefit_tab.template_combo.count()) - } + window.run_dream_bundle() + QApplication.processEvents() - assert window.project_setup_tab.show_deprecated_templates() - assert window.prefit_tab.show_deprecated_templates() - assert "template_likelihood_monosq" in project_template_names - assert "template_likelihood_monosq" in prefit_template_names - assert "template_pydream_poly_lma_hs_legacy" in project_template_names - assert "template_pydream_poly_lma_hs_legacy" in prefit_template_names + assert window.dream_tab.progress_bar.minimum() == 0 + assert window.dream_tab.progress_bar.maximum() == 0 + assert not window.dream_tab.run_button.isEnabled() + assert window._dream_progress_dialog is not None + assert window._dream_progress_dialog.isVisible() - window.prefit_tab.show_deprecated_templates_checkbox.setChecked(False) + deadline = time.time() + 2.0 + while ( + "DREAM sampler: initialization complete" + not in window.dream_tab.output_box.toPlainText() + and time.time() < deadline + ): + QApplication.processEvents() + time.sleep(0.02) - project_template_names = { - str(window.project_setup_tab.template_combo.itemData(index) or "") - for index in range(window.project_setup_tab.template_combo.count()) - } - prefit_template_names = { - str(window.prefit_tab.template_combo.itemData(index) or "") - for index in range(window.prefit_tab.template_combo.count()) + assert "DREAM Runtime Output" in window.dream_tab.output_box.toPlainText() + assert ( + "DREAM sampler: initialization complete" + in window.dream_tab.output_box.toPlainText() + ) + + window._dream_progress_dialog.close() + QApplication.processEvents() + assert not window._dream_progress_dialog.isVisible() + + deadline = time.time() + 5.0 + while window._dream_task_thread is not None and time.time() < deadline: + QApplication.processEvents() + time.sleep(0.02) + + assert window._dream_task_thread is None + assert window.dream_tab.run_button.isEnabled() + assert window.dream_tab.progress_bar.maximum() == 1 + assert window.dream_tab.progress_bar.value() == 1 + assert ( + "DREAM sampler: collecting posterior samples" + in window.dream_tab.output_box.toPlainText() + ) + assert "DREAM run complete" in window.dream_tab.output_box.toPlainText() + + +def test_dream_results_loader_normalizes_singleton_logp_axis(qapp, tmp_path): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + bundle = _write_minimal_dream_results(project_dir) + log_ps_path = bundle.run_dir / "dream_log_ps.npy" + log_ps = np.load(log_ps_path) + np.save(log_ps_path, np.asarray(log_ps, dtype=float)[..., np.newaxis]) + + loader = SAXSDreamResultsLoader(bundle.run_dir, burnin_percent=0) + summary = loader.get_summary() + + assert loader.log_ps.ndim == 2 + assert loader.log_ps.shape == (2, 4) + assert summary.posterior_sample_count == 8 + + +def test_dream_workflow_normalizes_stale_distribution_params_from_saved_map( + qapp, tmp_path +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + prefit = SAXSPrefitWorkflow(project_dir) + prefit.save_fit(prefit.parameter_entries) + workflow = SAXSDreamWorkflow(project_dir) + workflow.parameter_map_path.write_text( + json.dumps( + { + "entries": [ + { + "structure": "PbI2", + "motif": "motif_A", + "param_type": "SAXS", + "param": "w0", + "value": 0.6, + "vary": True, + "distribution": "norm", + "dist_params": { + "loc": 0.6, + "scale": 0.2, + "s": 0.1, + }, + } + ] + }, + indent=2, + ) + + "\n", + encoding="utf-8", + ) + + entries = workflow.load_parameter_map() + + assert entries[0].distribution == "norm" + assert entries[0].dist_params == {"loc": 0.6, "scale": 0.2} + saved_project_map = json.loads( + workflow.parameter_map_path.read_text(encoding="utf-8") + ) + assert saved_project_map["entries"][0]["dist_params"] == { + "loc": 0.6, + "scale": 0.2, } - assert not window.project_setup_tab.show_deprecated_templates() - assert not window.prefit_tab.show_deprecated_templates() - assert window.project_setup_tab.selected_template_name() == ( - "template_pd_likelihood_monosq_decoupled" - ) - assert window.prefit_tab.selected_template_name() == ( - "template_pd_likelihood_monosq_decoupled" + bundle = workflow.create_runtime_bundle(entries=entries) + saved_map = json.loads( + bundle.parameter_map_path.read_text(encoding="utf-8") ) - assert "template_likelihood_monosq" not in project_template_names - assert "template_likelihood_monosq" not in prefit_template_names - assert "template_pd_likelihood_monosq_decoupled" in project_template_names - assert "template_pd_likelihood_monosq_decoupled" in prefit_template_names - window.close() + assert saved_map["entries"][0]["dist_params"] == { + "loc": 0.6, + "scale": 0.2, + } + metadata = json.loads(bundle.metadata_path.read_text(encoding="utf-8")) + assert metadata["active_parameter_entries"][0]["dist_params"] == { + "loc": 0.6, + "scale": 0.2, + } -def test_install_model_dialog_collects_model_inputs(qapp, tmp_path): +def test_dream_runtime_module_is_pickleable_for_parallel_execution( + qapp, tmp_path +): del qapp - candidate_template = tmp_path / "candidate_install_model.py" - candidate_template.write_text( - "import numpy as np\n" - "# model_lmfit: lmfit_model_profile\n" - "# model_pydream: log_likelihood_candidate\n" - "# inputs_lmfit: q, solvent_data, model_data, params\n" - "# inputs_pydream: q_values, experimental_intensities, " - "solvent_intensities, theoretical_intensities, params\n" - "# param_columns: Structure, Motif, Param, Value, Vary, Min, Max\n" - "# param: scale,1.0,True,0.0,10.0\n" - "def lmfit_model_profile(q, solvent_data, model_data, **params):\n" - " del q, solvent_data\n" - " return params['scale'] * np.asarray(model_data[0], dtype=float)\n" - "def log_likelihood_candidate(params):\n" - " del params\n" - " return -1.0\n", - encoding="utf-8", - ) - dialog = InstallModelDialog() - dialog.model_name_edit.setText("Dialog Candidate Model") - dialog.template_path_edit.setText(str(candidate_template)) - dialog.description_edit.setPlainText("Dialog-installed candidate.") + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + prefit = SAXSPrefitWorkflow(project_dir) + prefit.save_fit(prefit.parameter_entries) + workflow = SAXSDreamWorkflow(project_dir) + entries = workflow.create_default_parameter_map(persist=True) + bundle = workflow.create_runtime_bundle(entries=entries) - request = dialog.selected_request() + module_name, module, added_sys_path = workflow._load_runtime_module(bundle) + try: + pickled = pickle.dumps(module.active_log_likelihood) + finally: + workflow._unload_runtime_module( + module_name, + added_sys_path=added_sys_path, + ) - assert request == TemplateInstallRequest( - model_name="Dialog Candidate Model", - template_path=candidate_template.resolve(), - model_description="Dialog-installed candidate.", - ) + assert pickled -def test_install_model_template_installs_and_refreshes_template_lists( +def test_dream_runtime_module_saves_squeezed_log_ps( qapp, tmp_path, monkeypatch ): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) - window = SAXSMainWindow(initial_project_dir=project_dir) - install_dir = tmp_path / "installed_templates" - request = TemplateInstallRequest( - model_name="Installed Candidate Model", - template_path=( - Path( - "tests/template_candidates/valid_installable_model.py" - ).resolve() - ), - model_description="Installed from the Project Setup dialog.", - ) - info_messages: list[tuple[str, str]] = [] - - monkeypatch.setattr( - window, - "_prompt_template_install_request", - lambda: request, - ) - monkeypatch.setattr( - "saxshell.saxs.ui.main_window.install_template_candidate", - lambda template_path, **kwargs: install_template_candidate( - template_path, - destination_dir=install_dir, - **kwargs, - ), - ) - monkeypatch.setattr( - "saxshell.saxs.ui.main_window.list_template_specs", - lambda **kwargs: list_template_specs(install_dir, **kwargs), - ) - monkeypatch.setattr( - "saxshell.saxs.ui.main_window.QMessageBox.information", - lambda _parent, title, message: info_messages.append((title, message)), - ) + prefit = SAXSPrefitWorkflow(project_dir) + prefit.save_fit(prefit.parameter_entries) + workflow = SAXSDreamWorkflow(project_dir) + entries = workflow.create_default_parameter_map(persist=True) + bundle = workflow.create_runtime_bundle(entries=entries) - window.project_setup_tab.install_model_button.click() + module_name, module, added_sys_path = workflow._load_runtime_module(bundle) + try: + monkeypatch.setattr( + module, + "run_dream", + lambda **kwargs: ( + [ + np.asarray( + [[0.6, 0.0, 0.05, 9.0, 0.0, 5e-4]], dtype=float + ), + np.asarray( + [[0.61, 0.0, 0.05, 9.1, 0.0, 5e-4]], dtype=float + ), + ], + [ + np.asarray([[-5.0]], dtype=float), + np.asarray([[-4.8]], dtype=float), + ], + ), + ) + module.run_sampler() + finally: + workflow._unload_runtime_module( + module_name, + added_sys_path=added_sys_path, + ) - assert (install_dir / "template_installed_candidate_model.py").is_file() - assert (install_dir / "template_installed_candidate_model.json").is_file() - assert any( - window.project_setup_tab.template_combo.itemText(index) - == "Installed Candidate Model" - for index in range(window.project_setup_tab.template_combo.count()) - ) - assert any( - window.prefit_tab.template_combo.itemText(index) - == "Installed Candidate Model" - for index in range(window.prefit_tab.template_combo.count()) - ) - assert info_messages - assert info_messages[0][0] == "Model installed" - assert "Installed Candidate Model" in info_messages[0][1] + saved_log_ps = np.load(bundle.run_dir / "dream_log_ps.npy") + assert saved_log_ps.shape == (2, 1) -def test_install_model_template_surfaces_validation_failures( +def test_dream_runtime_module_handles_array_shaped_crossover_values( qapp, tmp_path, monkeypatch ): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) - window = SAXSMainWindow(initial_project_dir=project_dir) - install_dir = tmp_path / "installed_templates" - request = TemplateInstallRequest( - model_name="Broken Dream Callable Model", - template_path=( - Path( - "tests/template_candidates/fail_missing_dream_callable_model.py" - ).resolve() - ), - model_description="Expected to fail in the validation step.", - ) - errors: list[tuple[str, str]] = [] - - monkeypatch.setattr( - window, - "_prompt_template_install_request", - lambda: request, - ) - monkeypatch.setattr( - "saxshell.saxs.ui.main_window.install_template_candidate", - lambda template_path, **kwargs: install_template_candidate( - template_path, - destination_dir=install_dir, - **kwargs, - ), - ) - monkeypatch.setattr( - window, - "_show_error", - lambda title, message: errors.append((title, message)), - ) + prefit = SAXSPrefitWorkflow(project_dir) + prefit.save_fit(prefit.parameter_entries) + workflow = SAXSDreamWorkflow(project_dir) + entries = workflow.create_default_parameter_map(persist=True) + bundle = workflow.create_runtime_bundle(entries=entries) - window.project_setup_tab.install_model_button.click() + runtime_source = bundle.runtime_script_path.read_text(encoding="utf-8") + assert "def _build_runtime_mp_context()" in runtime_source + assert "mp_context=mp_context" in runtime_source - assert errors - assert errors[0][0] == "Install model failed" - assert "Missing callable log_likelihood_candidate" in errors[0][1] + module_name, module, added_sys_path = workflow._load_runtime_module(bundle) + try: + shared_vars = SimpleNamespace( + cross_probs=mp.Array("d", [0.5, 0.5]), + ncr_updates=mp.Array("d", [0.0, 0.0]), + current_positions=mp.Array("d", [1.0, 2.0, 1.2, 2.4]), + delta_m=mp.Array("d", [0.0, 0.0]), + gamma_level_probs=mp.Array("d", [1.0]), + ngamma_updates=mp.Array("d", [0.0]), + delta_m_gamma=mp.Array("d", [0.0]), + ) + monkeypatch.setattr(module, "Dream_shared_vars", shared_vars) + module._PyDreamDream.estimate_crossover_probabilities.__globals__[ + "Dream_shared_vars" + ] = shared_vars + module._PyDreamDream.estimate_gamma_level_probs.__globals__[ + "Dream_shared_vars" + ] = shared_vars + fake_dream = SimpleNamespace( + nCR=2, + nchains=2, + CR_values=np.asarray([0.5, 1.0], dtype=float), + CR_probabilities=np.asarray([0.5, 0.5], dtype=float), + ngamma=1, + gamma_level_values=np.asarray([1], dtype=int), + ) -def test_project_setup_empty_preview_message_is_wrapped(qapp): - del qapp - tab = ProjectSetupTab() - tab.draw_component_plot(None) + crossover_value = module._PyDreamDream.set_CR( + fake_dream, + np.asarray([1.0, 0.0], dtype=float), + np.asarray([0.5, 1.0], dtype=float), + ) + assert float(crossover_value) == pytest.approx(0.5) - preview_axis = tab.component_figure.axes[0] - text_labels = [text.get_text() for text in preview_axis.texts] + cross_probs = module._PyDreamDream.estimate_crossover_probabilities( + fake_dream, + 2, + np.asarray([1.0, 2.0], dtype=float), + np.asarray([1.1, 2.3], dtype=float), + np.asarray([0.5], dtype=float), + ) + assert np.asarray(cross_probs, dtype=float).shape == (2,) + assert list(shared_vars.ncr_updates[:]) == [1.0, 0.0] - assert any( - "Select experimental data and build SAXS" in label - and "averaged cluster profiles." in label - and "\n" in label - for label in text_labels - ) + gamma_probs = module._PyDreamDream.estimate_gamma_level_probs( + fake_dream, + 2, + np.asarray([1.0, 2.0], dtype=float), + np.asarray([1.1, 2.3], dtype=float), + np.asarray([1], dtype=int), + ) + assert np.asarray(gamma_probs, dtype=float).shape == (1,) + assert list(shared_vars.ngamma_updates[:]) == [1.0] + finally: + workflow._unload_runtime_module( + module_name, + added_sys_path=added_sys_path, + ) -def test_run_dream_requires_parameter_map_saved_in_session( +def test_save_dream_settings_creates_named_preset_and_restores_active_state( qapp, tmp_path, monkeypatch ): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) - prefit = SAXSPrefitWorkflow(project_dir) - prefit.save_fit(prefit.parameter_entries) - workflow = SAXSDreamWorkflow(project_dir) - workflow.create_default_parameter_map() window = SAXSMainWindow(initial_project_dir=project_dir) - blinked = {"value": False} - monkeypatch.setattr( - window.dream_tab, - "blink_edit_priors_button", - lambda: blinked.update({"value": True}), - ) + window.dream_tab.chains_spin.setValue(6) + window.dream_tab.iterations_spin.setValue(2500) + window.dream_tab.posterior_filter_combo.setCurrentIndex(1) + window.dream_tab.posterior_top_percent_spin.setValue(12.5) + window.dream_tab.credible_interval_low_spin.setValue(10.0) + window.dream_tab.credible_interval_high_spin.setValue(90.0) + window.dream_tab.violin_sample_source_combo.setCurrentIndex(1) + monkeypatch.setattr( - "saxshell.saxs.ui.main_window.SAXSDreamWorkflow.run_bundle", - lambda self, bundle: (_ for _ in ()).throw( - AssertionError("run_bundle should not be called") - ), + "saxshell.saxs.ui.main_window.QInputDialog.getText", + lambda *args, **kwargs: ("Preset A", True), ) - window.run_dream_bundle() + window.save_dream_settings() - assert blinked["value"] - assert "Review the priors in Edit Priors" in ( - window.dream_tab.output_box.toPlainText() + preset_index = window.dream_tab.settings_preset_combo.findText("Preset A") + assert preset_index >= 0 + assert ( + window.dream_tab.settings_preset_combo.currentText() + == window.dream_tab.ACTIVE_SETTINGS_LABEL ) + window.dream_tab.chains_spin.setValue(11) + window.dream_tab.iterations_spin.setValue(3333) + window.dream_tab.posterior_filter_combo.setCurrentIndex(2) + window.dream_tab.posterior_top_n_spin.setValue(7) + window.dream_tab.violin_sample_source_combo.setCurrentIndex(0) -def test_run_dream_requires_written_runtime_bundle( - qapp, tmp_path, monkeypatch -): - del qapp - project_dir, _paths = _build_minimal_saxs_project(tmp_path) - prefit = SAXSPrefitWorkflow(project_dir) - prefit.save_fit(prefit.parameter_entries) - workflow = SAXSDreamWorkflow(project_dir) - entries = workflow.create_default_parameter_map(persist=True) - window = SAXSMainWindow(initial_project_dir=project_dir) - window._save_distribution_entries(entries) + window.dream_tab.settings_preset_combo.setCurrentIndex(preset_index) + QApplication.processEvents() - blinked = {"value": False} - errors: list[tuple[str, str]] = [] - monkeypatch.setattr( - window.dream_tab, - "blink_write_bundle_button", - lambda: blinked.update({"value": True}), - ) - monkeypatch.setattr( - window, - "_show_error", - lambda title, message: errors.append((title, message)), + assert window.dream_tab.chains_spin.value() == 6 + assert window.dream_tab.iterations_spin.value() == 2500 + assert ( + window.dream_tab.posterior_filter_combo.currentData() + == "top_percent_logp" ) - monkeypatch.setattr( - "saxshell.saxs.ui.main_window.SAXSDreamWorkflow.run_bundle", - lambda self, bundle: (_ for _ in ()).throw( - AssertionError("run_bundle should not be called") - ), + assert window.dream_tab.posterior_top_percent_spin.value() == 12.5 + assert window.dream_tab.credible_interval_low_spin.value() == 10.0 + assert window.dream_tab.credible_interval_high_spin.value() == 90.0 + assert ( + window.dream_tab.violin_sample_source_combo.currentData() + == "map_chain_only" ) - window.run_dream_bundle() + active_index = window.dream_tab.settings_preset_combo.findText( + window.dream_tab.ACTIVE_SETTINGS_LABEL + ) + window.dream_tab.settings_preset_combo.setCurrentIndex(active_index) + QApplication.processEvents() - assert blinked["value"] - assert errors == [ - ( - "Runtime Bundle not generated", - "Runtime Bundle not generated. Click Write Runtime Bundle before running DREAM.", - ) - ] - assert "Runtime Bundle not generated" in ( - window.dream_tab.output_box.toPlainText() + assert window.dream_tab.chains_spin.value() == 11 + assert window.dream_tab.iterations_spin.value() == 3333 + assert ( + window.dream_tab.posterior_filter_combo.currentData() == "top_n_logp" + ) + assert window.dream_tab.posterior_top_n_spin.value() == 7 + assert ( + window.dream_tab.violin_sample_source_combo.currentData() + == "filtered_posterior" ) -def test_preview_runtime_bundle_opens_latest_written_script( - qapp, tmp_path, monkeypatch +def test_dream_posterior_filter_controls_keep_default_thresholds_editable( + qapp, tmp_path ): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) - prefit = SAXSPrefitWorkflow(project_dir) - prefit.save_fit(prefit.parameter_entries) - workflow = SAXSDreamWorkflow(project_dir) - entries = workflow.create_default_parameter_map(persist=True) - bundle = workflow.create_runtime_bundle(entries=entries) window = SAXSMainWindow(initial_project_dir=project_dir) - window._last_written_dream_bundle = bundle - opener = RuntimeBundleOpener( - label="Fake Editor", - stored_value="/Applications/FakeEditor.app", - launch_target="/Applications/FakeEditor.app", - launch_mode="mac_app", - ) - launched: dict[str, object] = {} + window.dream_tab.posterior_filter_combo.setCurrentIndex(1) + QApplication.processEvents() - monkeypatch.setattr( - window, - "_available_runtime_bundle_openers", - lambda: [opener], - ) - monkeypatch.setattr( - "saxshell.saxs.ui.main_window.QInputDialog.getItem", - lambda *args, **kwargs: ("Fake Editor", True), - ) - monkeypatch.setattr( - window, - "_launch_runtime_bundle_with_opener", - lambda script_path, selected_opener: launched.update( - { - "path": str(script_path), - "label": selected_opener.label, - "stored_value": selected_opener.stored_value, - } - ), - ) + assert window.dream_tab.posterior_top_percent_spin.isEnabled() + assert window.dream_tab.posterior_top_n_spin.isEnabled() - window.preview_dream_runtime_bundle() + window.dream_tab.posterior_filter_combo.setCurrentIndex(2) + QApplication.processEvents() - assert launched["path"] == str(bundle.runtime_script_path) - assert launched["label"] == "Fake Editor" - assert "Opened DREAM runtime bundle preview" in ( - window.dream_tab.output_box.toPlainText() - ) - reloaded_settings = SAXSProjectManager().load_project(project_dir) - assert reloaded_settings.runtime_bundle_opener == opener.stored_value + assert window.dream_tab.posterior_top_percent_spin.isEnabled() + assert window.dream_tab.posterior_top_n_spin.isEnabled() -def test_preview_runtime_bundle_reuses_saved_project_opener( - qapp, tmp_path, monkeypatch -): +def test_load_latest_dream_results_updates_both_plot_panels(qapp, tmp_path): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) - prefit = SAXSPrefitWorkflow(project_dir) - prefit.save_fit(prefit.parameter_entries) - workflow = SAXSDreamWorkflow(project_dir) - entries = workflow.create_default_parameter_map(persist=True) - bundle = workflow.create_runtime_bundle(entries=entries) - - manager = SAXSProjectManager() - settings = manager.load_project(project_dir) - settings.runtime_bundle_opener = "/Applications/SavedEditor.app" - manager.save_project(settings) - + _write_minimal_dream_results(project_dir) window = SAXSMainWindow(initial_project_dir=project_dir) - window._last_written_dream_bundle = bundle - launched: dict[str, object] = {} + window.dream_tab.bestfit_method_combo.setCurrentIndex(1) + window.dream_tab.violin_mode_combo.setCurrentIndex(2) + window.dream_tab.posterior_filter_combo.setCurrentIndex(2) + window.dream_tab.posterior_top_n_spin.setValue(1) + window.dream_tab.credible_interval_low_spin.setValue(5.0) + window.dream_tab.credible_interval_high_spin.setValue(95.0) + window.dream_tab.violin_sample_source_combo.setCurrentIndex(1) - monkeypatch.setattr( - window, - "_available_runtime_bundle_openers", - lambda: [ - RuntimeBundleOpener( - label="Saved Editor", - stored_value="/Applications/SavedEditor.app", - launch_target="/Applications/SavedEditor.app", - launch_mode="mac_app", - ) - ], + window.load_latest_results() + + assert "Best-fit method: chain_mean" in ( + window.dream_tab.output_box.toPlainText() ) - monkeypatch.setattr( - "saxshell.saxs.ui.main_window.QInputDialog.getItem", - lambda *args, **kwargs: pytest.fail( - "The opener chooser should not appear for a saved project opener." - ), + assert "Posterior filter: top_n_logp" in ( + window.dream_tab.output_box.toPlainText() ) - monkeypatch.setattr( - window, - "_launch_runtime_bundle_with_opener", - lambda script_path, selected_opener: launched.update( - { - "path": str(script_path), - "label": selected_opener.label, - "stored_value": selected_opener.stored_value, - } - ), + assert "Posterior samples kept: 1" in ( + window.dream_tab.output_box.toPlainText() ) - - window.preview_dream_runtime_bundle() - - assert launched["path"] == str(bundle.runtime_script_path) - assert launched["stored_value"] == "/Applications/SavedEditor.app" + assert "Violin data mode: weights_only" in ( + window.dream_tab.output_box.toPlainText() + ) + assert "Violin sample source: map_chain_only" in ( + window.dream_tab.output_box.toPlainText() + ) + assert "p5=" in window.dream_tab.output_box.toPlainText() + assert "p95=" in window.dream_tab.output_box.toPlainText() + assert ( + window.dream_tab.model_figure.axes[0] + .get_title() + .startswith("DREAM refinement:") + ) + metric_text = "\n".join( + text.get_text() for text in window.dream_tab.model_figure.axes[0].texts + ) + assert "RMSE:" in metric_text + assert "Mean |res|:" in metric_text + assert "R²:" in metric_text + assert ( + window.dream_tab.violin_figure.axes[0].get_title() + == "Posterior parameter distributions" + ) + tick_labels = [ + label.get_text() + for label in window.dream_tab.violin_figure.axes[0].get_xticklabels() + ] + assert "w0 (A)" in tick_labels -def test_run_dream_shows_progress_and_popup_can_be_closed( - qapp, tmp_path, monkeypatch +def test_dream_analysis_saved_run_dropdown_loads_selected_run_state( + qapp, + tmp_path, ): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) prefit = SAXSPrefitWorkflow(project_dir) prefit.save_fit(prefit.parameter_entries) workflow = SAXSDreamWorkflow(project_dir) - entries = workflow.create_default_parameter_map(persist=True) - window = SAXSMainWindow(initial_project_dir=project_dir) - window._save_distribution_entries(entries) - window.write_dream_bundle() - - def _fake_run_bundle( - self, - bundle, - *, - output_callback=None, - output_interval_seconds=None, - ): - del output_interval_seconds - if output_callback is not None: - output_callback("DREAM sampler: initialization complete") - time.sleep(0.15) - if output_callback is not None: - output_callback("DREAM sampler: collecting posterior samples") - metadata = json.loads(bundle.metadata_path.read_text(encoding="utf-8")) - active_count = len(metadata["active_parameter_entries"]) - active_values = np.asarray( - [ - float(entry["value"]) - for entry in metadata["active_parameter_entries"] - ], - dtype=float, - ) - sampled_params = [] - log_ps = [] - for chain_index in range(2): - chain_samples = [] - chain_logps = [] - for step_index in range(4): - adjustment = (chain_index + 1) * (step_index + 1) * 0.01 - chain_samples.append(active_values[:active_count] + adjustment) - chain_logps.append(-5.0 + chain_index + step_index * 0.5) - sampled_params.append(chain_samples) - log_ps.append(chain_logps) - np.save( - bundle.run_dir / "dream_sampled_params.npy", - np.asarray(sampled_params, dtype=float), - ) - np.save( - bundle.run_dir / "dream_log_ps.npy", - np.asarray(log_ps, dtype=float)[..., np.newaxis], - ) - return { - "sampled_params_path": str( - bundle.run_dir / "dream_sampled_params.npy" - ), - "log_ps_path": str(bundle.run_dir / "dream_log_ps.npy"), - } - monkeypatch.setattr( - "saxshell.saxs.ui.main_window.SAXSDreamWorkflow.run_bundle", - _fake_run_bundle, + older_entries = workflow.create_default_parameter_map() + older_entries[0] = DreamParameterEntry( + structure=older_entries[0].structure, + motif=older_entries[0].motif, + param_type=older_entries[0].param_type, + param=older_entries[0].param, + value=0.11, + vary=older_entries[0].vary, + distribution=older_entries[0].distribution, + dist_params=dict(older_entries[0].dist_params), + smart_preset_status=older_entries[0].smart_preset_status, + ) + older_settings = DreamRunSettings( + nchains=3, + niterations=1234, + burnin_percent=7, + model_name="older_model", + run_label="older", + ) + older_bundle = _write_minimal_dream_results( + project_dir, + settings=older_settings, + entries=older_entries, ) - window.run_dream_bundle() - QApplication.processEvents() - - assert window.dream_tab.progress_bar.minimum() == 0 - assert window.dream_tab.progress_bar.maximum() == 0 - assert not window.dream_tab.run_button.isEnabled() - assert window._dream_progress_dialog is not None - assert window._dream_progress_dialog.isVisible() + newer_entries = workflow.create_default_parameter_map() + newer_entries[0] = DreamParameterEntry( + structure=newer_entries[0].structure, + motif=newer_entries[0].motif, + param_type=newer_entries[0].param_type, + param=newer_entries[0].param, + value=0.77, + vary=newer_entries[0].vary, + distribution=newer_entries[0].distribution, + dist_params=dict(newer_entries[0].dist_params), + smart_preset_status=newer_entries[0].smart_preset_status, + ) + newer_settings = DreamRunSettings( + nchains=8, + niterations=4321, + burnin_percent=22, + model_name="newer_model", + run_label="newer", + ) + newer_bundle = _write_minimal_dream_results( + project_dir, + settings=newer_settings, + entries=newer_entries, + ) - deadline = time.time() + 2.0 - while ( - "DREAM sampler: initialization complete" - not in window.dream_tab.output_box.toPlainText() - and time.time() < deadline - ): - QApplication.processEvents() - time.sleep(0.02) + window = SAXSMainWindow(initial_project_dir=project_dir) - assert "DREAM Runtime Output" in window.dream_tab.output_box.toPlainText() + assert window.dream_tab.saved_runs_combo.count() == 2 assert ( - "DREAM sampler: initialization complete" - in window.dream_tab.output_box.toPlainText() + Path(window.dream_tab.saved_runs_combo.currentData()).resolve() + == newer_bundle.run_dir.resolve() ) - window._dream_progress_dialog.close() + older_index = window.dream_tab.saved_runs_combo.findData( + str(older_bundle.run_dir) + ) + assert older_index >= 0 + window.dream_tab.saved_runs_combo.setCurrentIndex(older_index) QApplication.processEvents() - assert not window._dream_progress_dialog.isVisible() - deadline = time.time() + 5.0 - while window._dream_task_thread is not None and time.time() < deadline: - QApplication.processEvents() - time.sleep(0.02) + window.load_selected_results() - assert window._dream_task_thread is None - assert window.dream_tab.run_button.isEnabled() - assert window.dream_tab.progress_bar.maximum() == 1 - assert window.dream_tab.progress_bar.value() == 1 + loaded_settings = load_dream_settings(older_bundle.settings_path) + assert window._last_results_loader is not None assert ( - "DREAM sampler: collecting posterior samples" - in window.dream_tab.output_box.toPlainText() + window._last_results_loader.run_dir == older_bundle.run_dir.resolve() + ) + assert window.dream_tab.chains_spin.value() == loaded_settings.nchains + assert ( + window.dream_tab.iterations_spin.value() == loaded_settings.niterations + ) + assert ( + window.dream_tab.burnin_spin.value() == loaded_settings.burnin_percent + ) + assert float( + window.dream_tab.parameter_map_table.item(0, 4).text() + ) == pytest.approx(older_entries[0].value) + assert ( + str(older_bundle.run_dir) in window.dream_tab.output_box.toPlainText() ) - assert "DREAM run complete" in window.dream_tab.output_box.toPlainText() -def test_dream_results_loader_normalizes_singleton_logp_axis(qapp, tmp_path): - del qapp +def test_dream_model_metrics_box_updates_with_bestfit_method(qapp, tmp_path): project_dir, _paths = _build_minimal_saxs_project(tmp_path) - bundle = _write_minimal_dream_results(project_dir) - log_ps_path = bundle.run_dir / "dream_log_ps.npy" - log_ps = np.load(log_ps_path) - np.save(log_ps_path, np.asarray(log_ps, dtype=float)[..., np.newaxis]) + _write_minimal_dream_results(project_dir) + window = SAXSMainWindow(initial_project_dir=project_dir) - loader = SAXSDreamResultsLoader(bundle.run_dir, burnin_percent=0) - summary = loader.get_summary() + window.load_latest_results() + axis = window.dream_tab.model_figure.axes[0] + first_metrics = "\n".join(text.get_text() for text in axis.texts) - assert loader.log_ps.ndim == 2 - assert loader.log_ps.shape == (2, 4) - assert summary.posterior_sample_count == 8 + window.dream_tab.bestfit_method_combo.setCurrentIndex(2) + _wait_for_dream_refresh(qapp) + + axis = window.dream_tab.model_figure.axes[0] + second_metrics = "\n".join(text.get_text() for text in axis.texts) + assert "RMSE:" in second_metrics + assert "Mean |res|:" in second_metrics + assert "R²:" in second_metrics + assert first_metrics != second_metrics -def test_dream_workflow_normalizes_stale_distribution_params_from_saved_map( - qapp, tmp_path -): + +def test_dream_model_plot_includes_residual_subplot(qapp, tmp_path): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) - prefit = SAXSPrefitWorkflow(project_dir) - prefit.save_fit(prefit.parameter_entries) - workflow = SAXSDreamWorkflow(project_dir) - workflow.parameter_map_path.write_text( - json.dumps( - { - "entries": [ - { - "structure": "PbI2", - "motif": "motif_A", - "param_type": "SAXS", - "param": "w0", - "value": 0.6, - "vary": True, - "distribution": "norm", - "dist_params": { - "loc": 0.6, - "scale": 0.2, - "s": 0.1, - }, - } - ] - }, - indent=2, - ) - + "\n", - encoding="utf-8", - ) + _write_minimal_dream_results(project_dir) + window = SAXSMainWindow(initial_project_dir=project_dir) - entries = workflow.load_parameter_map() + window.load_latest_results() - assert entries[0].distribution == "norm" - assert entries[0].dist_params == {"loc": 0.6, "scale": 0.2} - saved_project_map = json.loads( - workflow.parameter_map_path.read_text(encoding="utf-8") - ) - assert saved_project_map["entries"][0]["dist_params"] == { - "loc": 0.6, - "scale": 0.2, - } + assert len(window.dream_tab.model_figure.axes) == 2 + top_axis = window.dream_tab.model_figure.axes[0] + residual_axis = window.dream_tab.model_figure.axes[1] + assert top_axis.get_title().startswith("DREAM refinement:") + assert residual_axis.get_ylabel() == "Residual" + assert residual_axis.get_xlabel() == "q (Å⁻¹)" + assert residual_axis.get_xscale() == top_axis.get_xscale() - bundle = workflow.create_runtime_bundle(entries=entries) - saved_map = json.loads( - bundle.parameter_map_path.read_text(encoding="utf-8") + residual_line = residual_axis.get_lines()[-1] + plot_data = window.dream_tab._current_model_plot_data + assert plot_data is not None + expected = np.asarray( + plot_data.model_intensities - plot_data.experimental_intensities, + dtype=float, ) - assert saved_map["entries"][0]["dist_params"] == { - "loc": 0.6, - "scale": 0.2, - } - metadata = json.loads(bundle.metadata_path.read_text(encoding="utf-8")) - assert metadata["active_parameter_entries"][0]["dist_params"] == { - "loc": 0.6, - "scale": 0.2, - } + assert np.allclose( + np.asarray(residual_line.get_ydata(), dtype=float), + expected, + ) + window.close() -def test_dream_runtime_module_is_pickleable_for_parallel_execution( - qapp, tmp_path +def test_dream_model_plot_redraw_on_log_x_avoids_nonpositive_xlim_warning( + qapp, + tmp_path, ): - del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) - prefit = SAXSPrefitWorkflow(project_dir) - prefit.save_fit(prefit.parameter_entries) - workflow = SAXSDreamWorkflow(project_dir) - entries = workflow.create_default_parameter_map(persist=True) - bundle = workflow.create_runtime_bundle(entries=entries) + _write_minimal_dream_results(project_dir) + window = SAXSMainWindow(initial_project_dir=project_dir) - module_name, module, added_sys_path = workflow._load_runtime_module(bundle) - try: - pickled = pickle.dumps(module.active_log_likelihood) - finally: - workflow._unload_runtime_module( - module_name, - added_sys_path=added_sys_path, - ) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + window.load_latest_results() + window.dream_tab.model_log_x_checkbox.setChecked(True) + qapp.processEvents() + window.dream_tab.bestfit_method_combo.setCurrentIndex(1) + _wait_for_dream_refresh(qapp) - assert pickled + warning_messages = [str(item.message) for item in caught] + assert not any( + "Attempt to set non-positive xlim on a log-scaled axis" in message + for message in warning_messages + ) + window.close() -def test_dream_runtime_module_saves_squeezed_log_ps( - qapp, tmp_path, monkeypatch -): +def test_prefit_model_metrics_box_updates_with_model_changes(qapp, tmp_path): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) - prefit = SAXSPrefitWorkflow(project_dir) - prefit.save_fit(prefit.parameter_entries) - workflow = SAXSDreamWorkflow(project_dir) - entries = workflow.create_default_parameter_map(persist=True) - bundle = workflow.create_runtime_bundle(entries=entries) + window = SAXSMainWindow(initial_project_dir=project_dir) - module_name, module, added_sys_path = workflow._load_runtime_module(bundle) - try: - monkeypatch.setattr( - module, - "run_dream", - lambda **kwargs: ( - [ - np.asarray( - [[0.6, 0.0, 0.05, 9.0, 0.0, 5e-4]], dtype=float - ), - np.asarray( - [[0.61, 0.0, 0.05, 9.1, 0.0, 5e-4]], dtype=float - ), - ], - [ - np.asarray([[-5.0]], dtype=float), - np.asarray([[-4.8]], dtype=float), - ], - ), - ) - module.run_sampler() - finally: - workflow._unload_runtime_module( - module_name, - added_sys_path=added_sys_path, - ) + axis = window.prefit_tab.figure.axes[0] + first_metrics = "\n".join(text.get_text() for text in axis.texts) - saved_log_ps = np.load(bundle.run_dir / "dream_log_ps.npy") - assert saved_log_ps.shape == (2, 1) + assert "RMSE:" in first_metrics + assert "Mean |res|:" in first_metrics + assert "R²:" in first_metrics + window.prefit_tab.set_parameter_row("scale", value=1e-3) + window.update_prefit_model() -def test_dream_runtime_module_handles_array_shaped_crossover_values( - qapp, tmp_path, monkeypatch -): - del qapp + axis = window.prefit_tab.figure.axes[0] + second_metrics = "\n".join(text.get_text() for text in axis.texts) + + assert "RMSE:" in second_metrics + assert "Mean |res|:" in second_metrics + assert "R²:" in second_metrics + assert first_metrics != second_metrics + window.close() + + +def test_dream_violin_scale_modes_and_palette_controls(qapp, tmp_path): project_dir, _paths = _build_minimal_saxs_project(tmp_path) - prefit = SAXSPrefitWorkflow(project_dir) - prefit.save_fit(prefit.parameter_entries) - workflow = SAXSDreamWorkflow(project_dir) - entries = workflow.create_default_parameter_map(persist=True) - bundle = workflow.create_runtime_bundle(entries=entries) + _write_minimal_dream_results(project_dir) + window = SAXSMainWindow(initial_project_dir=project_dir) - runtime_source = bundle.runtime_script_path.read_text(encoding="utf-8") - assert "def _build_runtime_mp_context()" in runtime_source - assert "mp_context=mp_context" in runtime_source + palette_index = window.dream_tab.violin_palette_combo.findData("plasma") + window.load_latest_results() + window.dream_tab.violin_value_scale_combo.setCurrentIndex(1) + window.dream_tab.violin_palette_combo.setCurrentIndex(palette_index) + window.dream_tab._configure_plot_color_button( + window.dream_tab.violin_point_color_button, + "tab:blue", + label="Point", + ) + window.dream_tab.visualization_settings_changed.emit() + _wait_for_dream_refresh(qapp) - module_name, module, added_sys_path = workflow._load_runtime_module(bundle) - try: - shared_vars = SimpleNamespace( - cross_probs=mp.Array("d", [0.5, 0.5]), - ncr_updates=mp.Array("d", [0.0, 0.0]), - current_positions=mp.Array("d", [1.0, 2.0, 1.2, 2.4]), - delta_m=mp.Array("d", [0.0, 0.0]), - gamma_level_probs=mp.Array("d", [1.0]), - ngamma_updates=mp.Array("d", [0.0]), - delta_m_gamma=mp.Array("d", [0.0]), - ) - monkeypatch.setattr(module, "Dream_shared_vars", shared_vars) - module._PyDreamDream.estimate_crossover_probabilities.__globals__[ - "Dream_shared_vars" - ] = shared_vars - module._PyDreamDream.estimate_gamma_level_probs.__globals__[ - "Dream_shared_vars" - ] = shared_vars + axis = window.dream_tab.violin_figure.axes[0] + tick_labels = [label.get_text() for label in axis.get_xticklabels()] + assert tick_labels == ["w0 (A)"] + assert axis.get_ylabel() == "Weight fraction" + assert axis.get_title() == "Posterior weight distributions" + assert axis.get_ylim() == pytest.approx((0.0, 1.0)) + body = next( + collection + for collection in axis.collections + if isinstance(collection, PolyCollection) + ) + assert to_hex(body.get_facecolor()[0], keep_alpha=False) == to_hex( + colormaps["plasma"](0.72), + keep_alpha=False, + ) + assert to_hex( + axis.collections[-1].get_facecolor()[0], keep_alpha=False + ) == to_hex( + "tab:blue", + keep_alpha=False, + ) - fake_dream = SimpleNamespace( - nCR=2, - nchains=2, - CR_values=np.asarray([0.5, 1.0], dtype=float), - CR_probabilities=np.asarray([0.5, 0.5], dtype=float), - ngamma=1, - gamma_level_values=np.asarray([1], dtype=int), - ) + window.dream_tab.violin_value_scale_combo.setCurrentIndex(2) + _wait_for_dream_refresh(qapp) - crossover_value = module._PyDreamDream.set_CR( - fake_dream, - np.asarray([1.0, 0.0], dtype=float), - np.asarray([0.5, 1.0], dtype=float), - ) - assert float(crossover_value) == pytest.approx(0.5) + axis = window.dream_tab.violin_figure.axes[0] + assert axis.get_ylabel() == "Normalized parameter value" + assert axis.get_title() == "Posterior parameter distributions (normalized)" + assert axis.get_ylim() == pytest.approx((0.0, 1.0)) + normalized_labels = [ + label.get_text() + for label in axis.get_xticklabels() + if label.get_text() + ] + assert "w0 (A)" in normalized_labels + assert "solv_w" in normalized_labels - cross_probs = module._PyDreamDream.estimate_crossover_probabilities( - fake_dream, - 2, - np.asarray([1.0, 2.0], dtype=float), - np.asarray([1.1, 2.3], dtype=float), - np.asarray([0.5], dtype=float), - ) - assert np.asarray(cross_probs, dtype=float).shape == (2,) - assert list(shared_vars.ncr_updates[:]) == [1.0, 0.0] - gamma_probs = module._PyDreamDream.estimate_gamma_level_probs( - fake_dream, - 2, - np.asarray([1.0, 2.0], dtype=float), - np.asarray([1.1, 2.3], dtype=float), - np.asarray([1], dtype=int), - ) - assert np.asarray(gamma_probs, dtype=float).shape == (1,) - assert list(shared_vars.ngamma_updates[:]) == [1.0] - finally: - workflow._unload_runtime_module( - module_name, - added_sys_path=added_sys_path, +def test_dream_violin_custom_color_controls_apply_to_plot( + qapp, tmp_path, monkeypatch +): + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + _write_minimal_dream_results(project_dir) + window = SAXSMainWindow(initial_project_dir=project_dir) + + chosen_colors = iter( + [ + QColor("#123456"), + QColor("#fedcba"), + QColor("#654321"), + QColor("#abcdef"), + QColor("#111111"), + ] + ) + monkeypatch.setattr( + "saxshell.saxs.ui.dream_tab.QColorDialog.getColor", + lambda *args, **kwargs: next(chosen_colors), + ) + + palette_index = window.dream_tab.violin_palette_combo.findData( + "custom_solid" + ) + window.load_latest_results() + window.dream_tab.violin_palette_combo.setCurrentIndex(palette_index) + window.dream_tab._choose_violin_custom_color() + window.dream_tab._choose_violin_point_color() + window.dream_tab._choose_interval_color() + window.dream_tab._choose_median_color() + window.dream_tab._choose_outline_color() + window.dream_tab.violin_outline_width_spin.setValue(1.7) + _wait_for_dream_refresh(qapp) + + axis = window.dream_tab.violin_figure.axes[0] + body = next( + collection + for collection in axis.collections + if isinstance(collection, PolyCollection) + ) + assert to_hex(body.get_facecolor()[0], keep_alpha=False) == "#123456" + assert to_hex(body.get_edgecolor()[0], keep_alpha=False) == "#111111" + assert body.get_linewidths()[0] == pytest.approx(1.7) + assert ( + to_hex( + axis.collections[-1].get_facecolor()[0], + keep_alpha=False, ) + == "#fedcba" + ) + line_colors = [ + to_hex(color, keep_alpha=False) + for collection in axis.collections + if isinstance(collection, LineCollection) + for color in collection.get_colors() + ] + assert "#654321" in line_colors + assert "#abcdef" in line_colors -def test_save_dream_settings_creates_named_preset_and_restores_active_state( - qapp, tmp_path, monkeypatch +def test_dream_violin_custom_color_picker_switches_palette_and_updates_plot( + qapp, + tmp_path, + monkeypatch, ): - del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) + _write_minimal_dream_results(project_dir) window = SAXSMainWindow(initial_project_dir=project_dir) - window.dream_tab.chains_spin.setValue(6) - window.dream_tab.iterations_spin.setValue(2500) - window.dream_tab.posterior_filter_combo.setCurrentIndex(1) - window.dream_tab.posterior_top_percent_spin.setValue(12.5) - window.dream_tab.credible_interval_low_spin.setValue(10.0) - window.dream_tab.credible_interval_high_spin.setValue(90.0) - window.dream_tab.violin_sample_source_combo.setCurrentIndex(1) - monkeypatch.setattr( - "saxshell.saxs.ui.main_window.QInputDialog.getText", - lambda *args, **kwargs: ("Preset A", True), - ) - - window.save_dream_settings() - - preset_index = window.dream_tab.settings_preset_combo.findText("Preset A") - assert preset_index >= 0 - assert ( - window.dream_tab.settings_preset_combo.currentText() - == window.dream_tab.ACTIVE_SETTINGS_LABEL + "saxshell.saxs.ui.dream_tab.QColorDialog.getColor", + lambda *args, **kwargs: QColor("#224466"), ) - window.dream_tab.chains_spin.setValue(11) - window.dream_tab.iterations_spin.setValue(3333) - window.dream_tab.posterior_filter_combo.setCurrentIndex(2) - window.dream_tab.posterior_top_n_spin.setValue(7) - window.dream_tab.violin_sample_source_combo.setCurrentIndex(0) - - window.dream_tab.settings_preset_combo.setCurrentIndex(preset_index) - QApplication.processEvents() - - assert window.dream_tab.chains_spin.value() == 6 - assert window.dream_tab.iterations_spin.value() == 2500 - assert ( - window.dream_tab.posterior_filter_combo.currentData() - == "top_percent_logp" - ) - assert window.dream_tab.posterior_top_percent_spin.value() == 12.5 - assert window.dream_tab.credible_interval_low_spin.value() == 10.0 - assert window.dream_tab.credible_interval_high_spin.value() == 90.0 - assert ( - window.dream_tab.violin_sample_source_combo.currentData() - == "map_chain_only" - ) + window.load_latest_results() + assert window.dream_tab.violin_palette_combo.currentData() == "Blues" - active_index = window.dream_tab.settings_preset_combo.findText( - window.dream_tab.ACTIVE_SETTINGS_LABEL - ) - window.dream_tab.settings_preset_combo.setCurrentIndex(active_index) - QApplication.processEvents() + window.dream_tab.violin_custom_color_button.click() + _wait_for_dream_refresh(qapp) - assert window.dream_tab.chains_spin.value() == 11 - assert window.dream_tab.iterations_spin.value() == 3333 assert ( - window.dream_tab.posterior_filter_combo.currentData() == "top_n_logp" + window.dream_tab.violin_palette_combo.currentData() == "custom_solid" ) - assert window.dream_tab.posterior_top_n_spin.value() == 7 - assert ( - window.dream_tab.violin_sample_source_combo.currentData() - == "filtered_posterior" + assert window.dream_tab.selected_violin_custom_color() == "#224466" + axis = window.dream_tab.violin_figure.axes[0] + body = next( + collection + for collection in axis.collections + if isinstance(collection, PolyCollection) ) + assert to_hex(body.get_facecolor()[0], keep_alpha=False) == "#224466" -def test_dream_posterior_filter_controls_keep_default_thresholds_editable( - qapp, tmp_path +def test_dream_default_violin_palette_starts_with_higher_contrast_color( + qapp, ): del qapp - project_dir, _paths = _build_minimal_saxs_project(tmp_path) - window = SAXSMainWindow(initial_project_dir=project_dir) - - window.dream_tab.posterior_filter_combo.setCurrentIndex(1) - QApplication.processEvents() - - assert window.dream_tab.posterior_top_percent_spin.isEnabled() - assert window.dream_tab.posterior_top_n_spin.isEnabled() + window = SAXSMainWindow() - window.dream_tab.posterior_filter_combo.setCurrentIndex(2) - QApplication.processEvents() + colors = window.dream_tab._violin_body_colors(4) + first_color = colors[0] + expected = colormaps.get_cmap("Blues")(0.35) - assert window.dream_tab.posterior_top_percent_spin.isEnabled() - assert window.dream_tab.posterior_top_n_spin.isEnabled() + assert to_hex(first_color, keep_alpha=False) == to_hex( + expected, + keep_alpha=False, + ) -def test_load_latest_dream_results_updates_both_plot_panels(qapp, tmp_path): +def test_dream_tab_limits_violin_display_sample_count(qapp): del qapp - project_dir, _paths = _build_minimal_saxs_project(tmp_path) - _write_minimal_dream_results(project_dir) - window = SAXSMainWindow(initial_project_dir=project_dir) - window.dream_tab.bestfit_method_combo.setCurrentIndex(1) - window.dream_tab.violin_mode_combo.setCurrentIndex(2) - window.dream_tab.posterior_filter_combo.setCurrentIndex(2) - window.dream_tab.posterior_top_n_spin.setValue(1) - window.dream_tab.credible_interval_low_spin.setValue(5.0) - window.dream_tab.credible_interval_high_spin.setValue(95.0) - window.dream_tab.violin_sample_source_combo.setCurrentIndex(1) + samples = np.arange(12_000, dtype=float).reshape(6_000, 2) - window.load_latest_results() + limited = DreamTab._display_violin_samples(samples) - assert "Best-fit method: chain_mean" in ( - window.dream_tab.output_box.toPlainText() - ) - assert "Posterior filter: top_n_logp" in ( - window.dream_tab.output_box.toPlainText() - ) - assert "Posterior samples kept: 1" in ( - window.dream_tab.output_box.toPlainText() - ) - assert "Violin data mode: weights_only" in ( - window.dream_tab.output_box.toPlainText() - ) - assert "Violin sample source: map_chain_only" in ( - window.dream_tab.output_box.toPlainText() - ) - assert "p5=" in window.dream_tab.output_box.toPlainText() - assert "p95=" in window.dream_tab.output_box.toPlainText() - assert ( - window.dream_tab.model_figure.axes[0] - .get_title() - .startswith("DREAM refinement:") - ) - metric_text = "\n".join( - text.get_text() for text in window.dream_tab.model_figure.axes[0].texts - ) - assert "RMSE:" in metric_text - assert "Mean |res|:" in metric_text - assert "R²:" in metric_text - assert ( - window.dream_tab.violin_figure.axes[0].get_title() - == "Posterior parameter distributions" - ) - tick_labels = [ - label.get_text() - for label in window.dream_tab.violin_figure.axes[0].get_xticklabels() - ] - assert "w0 (A)" in tick_labels + assert limited.shape == (DreamTab.MAX_VIOLIN_PLOT_SAMPLES, 2) + assert np.allclose(limited[0], samples[0]) + assert np.allclose(limited[-1], samples[-1]) -def test_dream_analysis_saved_run_dropdown_loads_selected_run_state( - qapp, +def test_dream_results_loader_can_order_weight_violin_labels_by_structure( tmp_path, ): - del qapp - project_dir, _paths = _build_minimal_saxs_project(tmp_path) - prefit = SAXSPrefitWorkflow(project_dir) - prefit.save_fit(prefit.parameter_entries) - workflow = SAXSDreamWorkflow(project_dir) + run_dir = _write_weight_order_dream_results(tmp_path) + loader = SAXSDreamResultsLoader(run_dir, burnin_percent=0) - older_entries = workflow.create_default_parameter_map() - older_entries[0] = DreamParameterEntry( - structure=older_entries[0].structure, - motif=older_entries[0].motif, - param_type=older_entries[0].param_type, - param=older_entries[0].param, - value=0.11, - vary=older_entries[0].vary, - distribution=older_entries[0].distribution, - dist_params=dict(older_entries[0].dist_params), - smart_preset_status=older_entries[0].smart_preset_status, - ) - older_settings = DreamRunSettings( - nchains=3, - niterations=1234, - burnin_percent=7, - model_name="older_model", - run_label="older", + weight_index_plot = loader.build_violin_data( + mode="weights_only", + weight_order="weight_index", ) - older_bundle = _write_minimal_dream_results( - project_dir, - settings=older_settings, - entries=older_entries, + structure_order_plot = loader.build_violin_data( + mode="weights_only", + weight_order="structure_order", ) - newer_entries = workflow.create_default_parameter_map() - newer_entries[0] = DreamParameterEntry( - structure=newer_entries[0].structure, - motif=newer_entries[0].motif, - param_type=newer_entries[0].param_type, - param=newer_entries[0].param, - value=0.77, - vary=newer_entries[0].vary, - distribution=newer_entries[0].distribution, - dist_params=dict(newer_entries[0].dist_params), - smart_preset_status=newer_entries[0].smart_preset_status, - ) - newer_settings = DreamRunSettings( - nchains=8, - niterations=4321, - burnin_percent=22, - model_name="newer_model", - run_label="newer", - ) - newer_bundle = _write_minimal_dream_results( - project_dir, - settings=newer_settings, - entries=newer_entries, - ) + assert weight_index_plot.parameter_names == ["w2", "w0", "w1"] + assert weight_index_plot.display_names == [ + "w2 (PbI2O)", + "w0 (I2)", + "w1 (Pb2)", + ] + assert structure_order_plot.parameter_names == ["w0", "w1", "w2"] + assert structure_order_plot.display_names == [ + "w0 (I2)", + "w1 (Pb2)", + "w2 (PbI2O)", + ] - window = SAXSMainWindow(initial_project_dir=project_dir) - assert window.dream_tab.saved_runs_combo.count() == 2 - assert ( - Path(window.dream_tab.saved_runs_combo.currentData()).resolve() - == newer_bundle.run_dir.resolve() - ) +def test_dream_results_loader_splits_radius_and_additional_violin_modes( + tmp_path, +): + run_dir = _write_violin_mode_split_dream_results(tmp_path) + loader = SAXSDreamResultsLoader(run_dir, burnin_percent=0) - older_index = window.dream_tab.saved_runs_combo.findData( - str(older_bundle.run_dir) + radius_plot = loader.build_violin_data(mode="effective_radii_only") + additional_plot = loader.build_violin_data( + mode="additional_parameters_only" ) - assert older_index >= 0 - window.dream_tab.saved_runs_combo.setCurrentIndex(older_index) - QApplication.processEvents() + fit_plot = loader.build_violin_data(mode="fit_parameters") + + assert radius_plot.parameter_names == [ + "r_eff_w0", + "a_eff_w1", + "b_eff_w1", + "c_eff_w1", + ] + assert additional_plot.parameter_names == [ + "scale", + "offset", + "phi_int", + ] + assert fit_plot.parameter_names == [ + "r_eff_w0", + "a_eff_w1", + "b_eff_w1", + "c_eff_w1", + "scale", + "offset", + "phi_int", + ] - window.load_selected_results() - loaded_settings = load_dream_settings(older_bundle.settings_path) - assert window._last_results_loader is not None - assert ( - window._last_results_loader.run_dir == older_bundle.run_dir.resolve() +def test_effective_dream_violin_mode_honors_new_value_scale_overrides(): + radius_settings = DreamRunSettings( + violin_parameter_mode="fit_parameters", + violin_value_scale_mode="effective_radii_only", ) - assert window.dream_tab.chains_spin.value() == loaded_settings.nchains - assert ( - window.dream_tab.iterations_spin.value() == loaded_settings.niterations + additional_settings = DreamRunSettings( + violin_parameter_mode="weights_only", + violin_value_scale_mode="additional_parameters_only", ) + assert ( - window.dream_tab.burnin_spin.value() == loaded_settings.burnin_percent + SAXSMainWindow._effective_dream_violin_mode(radius_settings) + == "effective_radii_only" ) - assert float( - window.dream_tab.parameter_map_table.item(0, 4).text() - ) == pytest.approx(older_entries[0].value) assert ( - str(older_bundle.run_dir) in window.dream_tab.output_box.toPlainText() + SAXSMainWindow._effective_dream_violin_mode(additional_settings) + == "additional_parameters_only" ) -def test_dream_model_metrics_box_updates_with_bestfit_method(qapp, tmp_path): - del qapp - project_dir, _paths = _build_minimal_saxs_project(tmp_path) - _write_minimal_dream_results(project_dir) - window = SAXSMainWindow(initial_project_dir=project_dir) - - window.load_latest_results() - axis = window.dream_tab.model_figure.axes[0] - first_metrics = "\n".join(text.get_text() for text in axis.texts) - - window.dream_tab.bestfit_method_combo.setCurrentIndex(2) - QApplication.processEvents() - - axis = window.dream_tab.model_figure.axes[0] - second_metrics = "\n".join(text.get_text() for text in axis.texts) - - assert "RMSE:" in second_metrics - assert "Mean |res|:" in second_metrics - assert "R²:" in second_metrics - assert first_metrics != second_metrics - - -def test_dream_model_plot_includes_residual_subplot(qapp, tmp_path): - del qapp +def test_dream_results_loader_filters_posterior_samples(tmp_path): project_dir, _paths = _build_minimal_saxs_project(tmp_path) - _write_minimal_dream_results(project_dir) - window = SAXSMainWindow(initial_project_dir=project_dir) - - window.load_latest_results() + bundle = _write_minimal_dream_results(project_dir) + loader = SAXSDreamResultsLoader(bundle.run_dir, burnin_percent=0) - assert len(window.dream_tab.model_figure.axes) == 2 - top_axis = window.dream_tab.model_figure.axes[0] - residual_axis = window.dream_tab.model_figure.axes[1] - assert top_axis.get_title().startswith("DREAM refinement:") - assert residual_axis.get_ylabel() == "Residual" - assert residual_axis.get_xlabel() == "q (Å⁻¹)" - assert residual_axis.get_xscale() == top_axis.get_xscale() + summary = loader.get_summary( + bestfit_method="median", + posterior_filter_mode="top_n_logp", + posterior_top_n=1, + credible_interval_low=10.0, + credible_interval_high=90.0, + ) + violin_plot = loader.build_violin_data( + mode="varying_parameters", + posterior_filter_mode="top_percent_logp", + posterior_top_percent=25.0, + sample_source="map_chain_only", + ) - residual_line = residual_axis.get_lines()[-1] - plot_data = window.dream_tab._current_model_plot_data - assert plot_data is not None - expected = np.asarray( - plot_data.model_intensities - plot_data.experimental_intensities, - dtype=float, + assert summary.posterior_filter_mode == "top_n_logp" + assert summary.posterior_sample_count == 1 + assert summary.credible_interval_low == 10.0 + assert summary.credible_interval_high == 90.0 + assert np.allclose( + summary.interval_low_values, + summary.interval_high_values, ) assert np.allclose( - np.asarray(residual_line.get_ydata(), dtype=float), - expected, + summary.bestfit_params, + summary.interval_low_values, ) - window.close() + assert violin_plot.sample_source == "map_chain_only" + assert violin_plot.sample_count == 2 -def test_dream_model_plot_redraw_on_log_x_avoids_nonpositive_xlim_warning( - qapp, +def test_dream_results_loader_reuses_cached_plot_data_across_interval_changes( tmp_path, ): - del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) - _write_minimal_dream_results(project_dir) - window = SAXSMainWindow(initial_project_dir=project_dir) - - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - window.load_latest_results() - window.dream_tab.model_log_x_checkbox.setChecked(True) - QApplication.processEvents() - window.dream_tab.bestfit_method_combo.setCurrentIndex(1) - QApplication.processEvents() + bundle = _write_minimal_dream_results(project_dir) + loader = SAXSDreamResultsLoader(bundle.run_dir, burnin_percent=0) - warning_messages = [str(item.message) for item in caught] - assert not any( - "Attempt to set non-positive xlim on a log-scaled axis" in message - for message in warning_messages + first_model = loader.build_model_fit_data( + bestfit_method="median", + credible_interval_low=5.0, + credible_interval_high=95.0, + ) + second_model = loader.build_model_fit_data( + bestfit_method="median", + credible_interval_low=10.0, + credible_interval_high=90.0, + ) + first_violin = loader.build_violin_data( + mode="weights_only", + credible_interval_low=5.0, + credible_interval_high=95.0, + ) + second_violin = loader.build_violin_data( + mode="weights_only", + credible_interval_low=10.0, + credible_interval_high=90.0, ) - window.close() + assert first_model is second_model + assert first_violin is second_violin -def test_prefit_model_metrics_box_updates_with_model_changes(qapp, tmp_path): + +def test_dream_runtime_bundle_carries_selected_solvent_trace(qapp, tmp_path): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) - window = SAXSMainWindow(initial_project_dir=project_dir) - - axis = window.prefit_tab.figure.axes[0] - first_metrics = "\n".join(text.get_text() for text in axis.texts) + prefit = SAXSPrefitWorkflow(project_dir) + prefit.save_fit(prefit.parameter_entries) - assert "RMSE:" in first_metrics - assert "Mean |res|:" in first_metrics - assert "R²:" in first_metrics + solvent_q = np.linspace(0.05, 0.3, 8) + solvent_intensity = np.linspace(3.0, 4.4, 8) + solvent_path = tmp_path / "solvent_reference_trace.dat" + np.savetxt(solvent_path, np.column_stack([solvent_q, solvent_intensity])) - window.prefit_tab.set_parameter_row("scale", value=1e-3) - window.update_prefit_model() + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.solvent_data_path = str(solvent_path) + settings.copied_solvent_data_file = None + manager.save_project(settings) - axis = window.prefit_tab.figure.axes[0] - second_metrics = "\n".join(text.get_text() for text in axis.texts) + workflow = SAXSDreamWorkflow(project_dir) + entries = workflow.create_default_parameter_map(persist=True) + bundle = workflow.create_runtime_bundle(entries=entries) + metadata = json.loads(bundle.metadata_path.read_text(encoding="utf-8")) - assert "RMSE:" in second_metrics - assert "Mean |res|:" in second_metrics - assert "R²:" in second_metrics - assert first_metrics != second_metrics - window.close() + assert np.allclose(metadata["solvent_intensities"], solvent_intensity) -def test_dream_violin_scale_modes_and_palette_controls(qapp, tmp_path): +def test_dream_plot_data_exports_save_into_exported_results_data( + qapp, tmp_path +): del qapp - project_dir, _paths = _build_minimal_saxs_project(tmp_path) + project_dir, paths = _build_minimal_saxs_project(tmp_path) _write_minimal_dream_results(project_dir) window = SAXSMainWindow(initial_project_dir=project_dir) - palette_index = window.dream_tab.violin_palette_combo.findData("plasma") + class _AcceptedDialog: + def __init__(self, *args, **kwargs): + del args, kwargs + self.selected_options = SimpleNamespace( + output_dir=paths.exported_data_dir, + base_name="dream_violin_export_test", + save_csv=True, + save_pkl=True, + ) + + def exec(self): + return QDialog.DialogCode.Accepted + + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.DreamViolinExportDialog", + _AcceptedDialog, + ) + window.load_latest_results() - window.dream_tab.violin_value_scale_combo.setCurrentIndex(1) - window.dream_tab.violin_palette_combo.setCurrentIndex(palette_index) - window.dream_tab._configure_plot_color_button( - window.dream_tab.violin_point_color_button, - "tab:blue", - label="Point", + window.dream_tab.violin_mode_combo.setCurrentIndex(2) + window.save_dream_model_fit() + window.save_dream_violin_data() + monkeypatch.undo() + + model_exports = sorted( + paths.exported_data_dir.glob("dream_model_fit_*.csv") + ) + violin_csv_exports = sorted( + paths.exported_data_dir.glob("dream_violin_*.csv") + ) + violin_pkl_exports = sorted( + paths.exported_data_dir.glob("dream_violin_*.pkl") + ) + dialog_csv_export = ( + paths.exported_data_dir / "dream_violin_export_test.csv" + ) + dialog_pkl_export = ( + paths.exported_data_dir / "dream_violin_export_test.pkl" + ) + dialog_metadata_export = ( + paths.exported_data_dir / "dream_violin_export_test.metadata.json" + ) + dialog_report_export = ( + paths.exported_data_dir / "dream_violin_export_test.report.txt" ) - window.dream_tab.visualization_settings_changed.emit() - QApplication.processEvents() - axis = window.dream_tab.violin_figure.axes[0] - tick_labels = [label.get_text() for label in axis.get_xticklabels()] - assert tick_labels == ["w0 (A)"] - assert axis.get_ylabel() == "Weight fraction" - assert axis.get_title() == "Posterior weight distributions" - assert axis.get_ylim() == pytest.approx((0.0, 1.0)) - body = next( - collection - for collection in axis.collections - if isinstance(collection, PolyCollection) + assert model_exports + assert violin_csv_exports + assert violin_pkl_exports + assert dialog_csv_export.is_file() + assert dialog_pkl_export.is_file() + assert dialog_metadata_export.is_file() + assert dialog_report_export.is_file() + assert ( + "q,experimental_intensity,model_intensity," + "solvent_contribution,structure_factor" + ) in model_exports[-1].read_text(encoding="utf-8") + model_metadata = json.loads( + model_exports[-1] + .with_name(f"{model_exports[-1].stem}.metadata.json") + .read_text(encoding="utf-8") ) - assert to_hex(body.get_facecolor()[0], keep_alpha=False) == to_hex( - colormaps["plasma"](0.72), - keep_alpha=False, + model_report = model_exports[-1].with_name( + f"{model_exports[-1].stem}.report.txt" ) - assert to_hex( - axis.collections[-1].get_facecolor()[0], keep_alpha=False - ) == to_hex( - "tab:blue", - keep_alpha=False, + assert model_report.is_file() + assert model_metadata["export_kind"] == "dream_model_fit" + assert model_metadata["model_fit"]["includes_structure_factor"] is True + assert model_metadata["model_fit"]["fit_metrics"]["rmse"] >= 0.0 + assert "Model fit metrics:" in model_report.read_text(encoding="utf-8") + assert "w0 (A)" in dialog_csv_export.read_text(encoding="utf-8") + violin_payload = pickle.loads(dialog_pkl_export.read_bytes()) + violin_metadata = json.loads( + dialog_metadata_export.read_text(encoding="utf-8") + ) + assert violin_payload["violin_plot"]["display_names"] == ["w0 (A)"] + assert violin_payload["screening_metrics"]["posterior_filter_mode"] + assert violin_metadata["export_kind"] == "dream_violin" + assert violin_metadata["screening_metrics"]["posterior_filter_mode"] + assert violin_payload["plot_payload"]["ylabel"] == "Parameter value" + assert np.asarray(violin_payload["plot_payload"]["samples"]).shape[1] == 1 + assert "Posterior violin data:" in dialog_report_export.read_text( + encoding="utf-8" ) - window.dream_tab.violin_value_scale_combo.setCurrentIndex(2) - QApplication.processEvents() - - axis = window.dream_tab.violin_figure.axes[0] - assert axis.get_ylabel() == "Normalized parameter value" - assert axis.get_title() == "Posterior parameter distributions (normalized)" - assert axis.get_ylim() == pytest.approx((0.0, 1.0)) - normalized_labels = [ - label.get_text() - for label in axis.get_xticklabels() - if label.get_text() - ] - assert "w0 (A)" in normalized_labels - assert "solv_w" in normalized_labels - -def test_dream_violin_custom_color_controls_apply_to_plot( +def test_dream_model_report_export_builds_context_and_writes_pptx( qapp, tmp_path, monkeypatch ): del qapp - project_dir, _paths = _build_minimal_saxs_project(tmp_path) + project_dir, paths = _build_minimal_saxs_project(tmp_path) _write_minimal_dream_results(project_dir) window = SAXSMainWindow(initial_project_dir=project_dir) + window.load_latest_results() - chosen_colors = iter( - [ - QColor("#123456"), - QColor("#fedcba"), - QColor("#654321"), - QColor("#abcdef"), - QColor("#111111"), - ] - ) + captured: dict[str, object] = {} + errors: list[tuple[str, str]] = [] monkeypatch.setattr( - "saxshell.saxs.ui.dream_tab.QColorDialog.getColor", - lambda *args, **kwargs: next(chosen_colors), + window, + "_show_error", + lambda title, message: errors.append((title, message)), ) - palette_index = window.dream_tab.violin_palette_combo.findData( - "custom_solid" + def _fake_export(context, *, progress_callback=None): + captured["context"] = context + if progress_callback is not None: + progress_callback(1, 3, "Rendering report plots...") + progress_callback(2, 3, "Building report slides...") + progress_callback(3, 3, "Saving PowerPoint report...") + context.asset_dir.mkdir(parents=True, exist_ok=True) + context.output_path.write_bytes(b"pptx") + manifest_path = context.asset_dir / "report_manifest.json" + manifest_path.write_text("{}\n", encoding="utf-8") + return SimpleNamespace( + report_path=context.output_path, + manifest_path=manifest_path, + figure_paths=(), + ) + + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.export_dream_model_report_pptx", + _fake_export, ) - window.load_latest_results() - window.dream_tab.violin_palette_combo.setCurrentIndex(palette_index) - window.dream_tab._choose_violin_custom_color() - window.dream_tab._choose_violin_point_color() - window.dream_tab._choose_interval_color() - window.dream_tab._choose_median_color() - window.dream_tab._choose_outline_color() - window.dream_tab.violin_outline_width_spin.setValue(1.7) - QApplication.processEvents() - axis = window.dream_tab.violin_figure.axes[0] - body = next( - collection - for collection in axis.collections - if isinstance(collection, PolyCollection) + window.export_dream_model_report() + + assert not errors + report_exports = sorted( + paths.reports_dir.glob("dream_model_report_*.pptx") ) - assert to_hex(body.get_facecolor()[0], keep_alpha=False) == "#123456" - assert to_hex(body.get_edgecolor()[0], keep_alpha=False) == "#111111" - assert body.get_linewidths()[0] == pytest.approx(1.7) - assert ( - to_hex( - axis.collections[-1].get_facecolor()[0], - keep_alpha=False, - ) - == "#fedcba" + assert report_exports + assert report_exports[-1].read_bytes() == b"pptx" + assert "Generating DREAM model report PowerPoint. Please wait..." in ( + window.dream_tab.output_box.toPlainText() ) - line_colors = [ - to_hex(color, keep_alpha=False) - for collection in axis.collections - if isinstance(collection, LineCollection) - for color in collection.get_colors() - ] - assert "#654321" in line_colors - assert "#abcdef" in line_colors - + assert ( + window.dream_tab.progress_label.text() + == "DREAM model report exported." + ) + assert window.dream_tab.progress_bar.maximum() == 3 + assert window.dream_tab.progress_bar.value() == 3 + context = captured["context"] + assert context.project_name == "saxs_project" + assert context.project_dir == project_dir.resolve() + assert context.user_q_range_text == "0.05 to 0.3" + assert context.component_plot_without_solvent is not None + assert context.component_plot_with_solvent is not None + assert len(context.prior_histograms) == 4 + assert context.prefit_parameter_entries + assert context.dream_parameter_map_entries + assert context.dream_summary.posterior_sample_count > 0 + assert context.dream_filter_views + assert "Posterior filter:" in "\n".join(context.output_summary_lines) + assert (context.asset_dir / "report_manifest.json").is_file() + + +def test_dream_model_report_export_writes_real_pptx(qapp, tmp_path): + pytest.importorskip("pptx") + from pptx import Presentation -def test_dream_violin_custom_color_picker_switches_palette_and_updates_plot( - qapp, - tmp_path, - monkeypatch, -): del qapp - project_dir, _paths = _build_minimal_saxs_project(tmp_path) + project_dir, paths = _build_minimal_saxs_project(tmp_path) _write_minimal_dream_results(project_dir) window = SAXSMainWindow(initial_project_dir=project_dir) + window.load_latest_results() - monkeypatch.setattr( - "saxshell.saxs.ui.dream_tab.QColorDialog.getColor", - lambda *args, **kwargs: QColor("#224466"), + context = window._build_dream_model_report_context( + settings=window.dream_tab.settings_payload(), + output_path=paths.reports_dir / "dream_model_report_test.pptx", + asset_dir=paths.reports_dir / "dream_model_report_test_assets", ) - window.load_latest_results() - assert window.dream_tab.violin_palette_combo.currentData() == "Blues" + progress_updates: list[tuple[int, int, str]] = [] - window.dream_tab.violin_custom_color_button.click() - QApplication.processEvents() + result = export_dream_model_report_pptx( + context, + progress_callback=lambda processed, total, message: progress_updates.append( + (processed, total, message) + ), + ) + + assert result.report_path.is_file() + assert result.report_path.stat().st_size > 0 + assert result.manifest_path.is_file() + assert result.figure_paths + figure_names = {path.name for path in result.figure_paths} + assert "dream_filter_violin_comparison.png" in figure_names + assert "dream_filter_violin_comparison_weights.png" in figure_names + assert "dream_filter_violin_comparison_effective_radii.png" in figure_names + assert "prefit_model_without_solvent.png" in figure_names + if ( + context.prefit_evaluation is not None + and context.prefit_evaluation.solvent_contribution is not None + and np.any( + np.isfinite( + np.asarray( + context.prefit_evaluation.solvent_contribution, + dtype=float, + ) + ) + & ( + np.asarray( + context.prefit_evaluation.solvent_contribution, + dtype=float, + ) + > 0.0 + ) + ) + ): + assert "prefit_model_with_solvent.png" in figure_names + assert progress_updates + assert progress_updates[0][0] == 0 + assert "Please wait" in progress_updates[0][2] + assert progress_updates[-1][0] == progress_updates[-1][1] + + presentation = Presentation(str(result.report_path)) + assert presentation.slide_width / 914400 == pytest.approx(13.333, rel=1e-3) + assert presentation.slide_height / 914400 == pytest.approx(7.5, rel=1e-3) + slide_texts = [ + "\n".join( + shape.text + for shape in slide.shapes + if getattr(shape, "has_text_frame", False) + ) + for slide in presentation.slides + ] + for slide in presentation.slides: + for shape in slide.shapes: + assert shape.left >= 0 + assert shape.top >= 0 + assert shape.left + shape.width <= presentation.slide_width + assert shape.top + shape.height <= presentation.slide_height + + title_shape = next( + shape + for shape in presentation.slides[0].shapes + if getattr(shape, "has_text_frame", False) + and "SAXS Model Report" in shape.text + ) + assert title_shape.text_frame.paragraphs[0].runs[0].font.name == "Arial" assert ( - window.dream_tab.violin_palette_combo.currentData() == "custom_solid" + sum( + "Posterior Violin Comparison" in slide_text + for slide_text in slide_texts + ) + == 3 ) - assert window.dream_tab.selected_violin_custom_color() == "#224466" - axis = window.dream_tab.violin_figure.axes[0] - body = next( - collection - for collection in axis.collections - if isinstance(collection, PolyCollection) + assert any( + "Posterior Violin Comparison - Weights" in slide_text + for slide_text in slide_texts ) - assert to_hex(body.get_facecolor()[0], keep_alpha=False) == "#224466" + assert any( + "Posterior Violin Comparison - Effective Radii" in slide_text + for slide_text in slide_texts + ) + if "prefit_model_with_solvent.png" in figure_names: + assert any( + any( + getattr(shape, "has_text_frame", False) + and "Prefit Model With Solvent" in shape.text + for shape in slide.shapes + ) + for slide in presentation.slides + ) -def test_dream_default_violin_palette_starts_with_higher_contrast_color( - qapp, +def test_dream_model_report_exports_model_information_from_template_context( + qapp, tmp_path ): + pytest.importorskip("pptx") + from pptx import Presentation + del qapp - window = SAXSMainWindow() + project_dir, paths = _build_poly_lma_geometry_project( + tmp_path, + template_name=POLY_LMA_HS_TEMPLATE, + ) + window = SAXSMainWindow(initial_project_dir=project_dir) + window.compute_prefit_cluster_geometry() + _write_minimal_dream_results(project_dir) + window.load_latest_results() - colors = window.dream_tab._violin_body_colors(4) - first_color = colors[0] - expected = colormaps.get_cmap("Blues")(0.35) + context = window._build_dream_model_report_context( + settings=window.dream_tab.settings_payload(), + output_path=paths.reports_dir / "dream_model_report_model_info.pptx", + asset_dir=paths.reports_dir / "dream_model_report_model_info_assets", + ) - assert to_hex(first_color, keep_alpha=False) == to_hex( - expected, - keep_alpha=False, + assert context.template_display_name == "pyDREAM Poly LMA Hard-Sphere" + assert context.template_module_path is not None + assert ( + context.template_module_path.name == "template_pydream_poly_lma_hs.py" ) + assert context.model_equation_text is not None + assert "I_model(q)" in context.model_equation_text + assert "phi_solute" in context.model_equation_text + definition_text = "\n".join(context.model_definition_lines) + reference_text = "\n".join(context.model_reference_lines) + assert "Structure Factor:" in definition_text + assert "Parameter Definitions:" in definition_text + assert "phi_solute" in definition_text + assert "Pedersen review" in reference_text + assert "sasview.org/docs/user/models/hardsphere.html" in reference_text -def test_dream_results_loader_can_order_weight_violin_labels_by_structure( - tmp_path, -): - run_dir = _write_weight_order_dream_results(tmp_path) - loader = SAXSDreamResultsLoader(run_dir, burnin_percent=0) + result = export_dream_model_report_pptx(context) - weight_index_plot = loader.build_violin_data( - mode="weights_only", - weight_order="weight_index", - ) - structure_order_plot = loader.build_violin_data( - mode="weights_only", - weight_order="structure_order", + presentation = Presentation(str(result.report_path)) + slide_text = "\n".join( + shape.text + for slide in presentation.slides + for shape in slide.shapes + if getattr(shape, "has_text_frame", False) ) - assert weight_index_plot.parameter_names == ["w2", "w0", "w1"] - assert weight_index_plot.display_names == [ - "w2 (PbI2O)", - "w0 (I2)", - "w1 (Pb2)", - ] - assert structure_order_plot.parameter_names == ["w0", "w1", "w2"] - assert structure_order_plot.display_names == [ - "w0 (I2)", - "w1 (Pb2)", - "w2 (PbI2O)", - ] + assert "Model Information" in slide_text + assert "pyDREAM Poly LMA Hard-Sphere" in slide_text + assert "Model equation:" in slide_text + assert "I_model(q)" in slide_text + assert "Term definitions:" in slide_text + assert "References:" in slide_text + assert "Pedersen review" in slide_text + assert "Pedersen97.pdf" in slide_text -def test_dream_results_loader_filters_posterior_samples(tmp_path): - project_dir, _paths = _build_minimal_saxs_project(tmp_path) - bundle = _write_minimal_dream_results(project_dir) - loader = SAXSDreamResultsLoader(bundle.run_dir, burnin_percent=0) +def test_dream_model_report_uses_selected_secondary_atom_for_solvent_sort_prior_histograms( + qapp, + tmp_path, +): + del qapp + project_dir, paths = _build_minimal_saxs_project(tmp_path) + _write_minimal_dream_results(project_dir) + window = SAXSMainWindow(initial_project_dir=project_dir) + window.load_latest_results() - summary = loader.get_summary( - bestfit_method="median", - posterior_filter_mode="top_n_logp", - posterior_top_n=1, - credible_interval_low=10.0, - credible_interval_high=90.0, + (paths.project_dir / "md_prior_weights.json").write_text( + json.dumps( + { + "origin": "clusters", + "total_files": 5, + "available_elements": ["Pb", "I", "Br", "O"], + "structures": { + "PbI2": { + "motif_A": { + "count": 2, + "weight": 0.4, + "profile_file": "A_no_motif.txt", + "secondary_atom_distributions": { + "Br": {"0": 1, "1": 1}, + "O": {"0": 1, "2": 1}, + }, + } + }, + "Pb2I4": { + "motif_B": { + "count": 3, + "weight": 0.6, + "profile_file": "A_no_motif.txt", + "secondary_atom_distributions": { + "Br": {"0": 1, "3": 2}, + "O": {"1": 1, "2": 2}, + }, + } + }, + }, + }, + indent=2, + ) + + "\n", + encoding="utf-8", ) - violin_plot = loader.build_violin_data( - mode="varying_parameters", - posterior_filter_mode="top_percent_logp", - posterior_top_percent=25.0, - sample_source="map_chain_only", + window.project_setup_tab.apply_cluster_import_data( + ["Pb", "I", "Br", "O"], + [ + { + "structure": "PbI2", + "motif": "motif_A", + "count": 2, + "weight": 0.4, + "atom_fraction_percent": 40.0, + "structure_fraction_percent": 40.0, + }, + { + "structure": "Pb2I4", + "motif": "motif_B", + "count": 3, + "weight": 0.6, + "atom_fraction_percent": 60.0, + "structure_fraction_percent": 60.0, + }, + ], + ) + window.project_setup_tab.prior_mode_combo.setCurrentText( + "Solvent Sort - Structure Fraction" + ) + secondary_index = window.project_setup_tab.secondary_filter_combo.findText( + "O" + ) + assert secondary_index >= 0 + window.project_setup_tab.secondary_filter_combo.setCurrentIndex( + secondary_index + ) + window.project_setup_tab.prior_mode_combo.setCurrentText( + "Structure Fraction" ) - assert summary.posterior_filter_mode == "top_n_logp" - assert summary.posterior_sample_count == 1 - assert summary.credible_interval_low == 10.0 - assert summary.credible_interval_high == 90.0 - assert np.allclose( - summary.interval_low_values, - summary.interval_high_values, + context = window._build_dream_model_report_context( + settings=window.dream_tab.settings_payload(), + output_path=paths.reports_dir + / "dream_model_report_secondary_atom.pptx", + asset_dir=( + paths.reports_dir / "dream_model_report_secondary_atom_assets" + ), ) - assert np.allclose( - summary.bestfit_params, - summary.interval_low_values, + + assert ( + context.powerpoint_settings.solvent_sort_histogram_color_map + == "summer" ) - assert violin_plot.sample_source == "map_chain_only" - assert violin_plot.sample_count == 2 + assert ( + context.prior_histograms[2].mode == "solvent_sort_structure_fraction" + ) + assert context.prior_histograms[2].secondary_element == "O" + assert context.prior_histograms[2].cmap == "summer" + assert context.prior_histograms[3].mode == "solvent_sort_atom_fraction" + assert context.prior_histograms[3].secondary_element == "O" + assert context.prior_histograms[3].cmap == "summer" -def test_dream_runtime_bundle_carries_selected_solvent_trace(qapp, tmp_path): +def test_dream_model_report_context_and_export_honor_powerpoint_settings( + qapp, + tmp_path, +): + pytest.importorskip("pptx") + del qapp - project_dir, _paths = _build_minimal_saxs_project(tmp_path) - prefit = SAXSPrefitWorkflow(project_dir) - prefit.save_fit(prefit.parameter_entries) + project_dir, paths = _build_minimal_saxs_project(tmp_path) + _write_minimal_dream_results(project_dir) + window = SAXSMainWindow(initial_project_dir=project_dir) + window.load_latest_results() + window.current_settings.powerpoint_export_settings = ( + PowerPointExportSettings( + font_family="Courier New", + component_color_map="plasma", + prior_histogram_color_map="cividis", + solvent_sort_histogram_color_map="magma", + include_prior_histograms=False, + include_directory_summary=False, + generate_manifest=False, + export_figure_assets=False, + ) + ) - solvent_q = np.linspace(0.05, 0.3, 8) - solvent_intensity = np.linspace(3.0, 4.4, 8) - solvent_path = tmp_path / "solvent_reference_trace.dat" - np.savetxt(solvent_path, np.column_stack([solvent_q, solvent_intensity])) + context = window._build_dream_model_report_context( + settings=window.dream_tab.settings_payload(), + output_path=paths.reports_dir / "dream_model_report_custom.pptx", + asset_dir=paths.reports_dir / "dream_model_report_custom_assets", + ) - manager = SAXSProjectManager() - settings = manager.load_project(project_dir) - settings.solvent_data_path = str(solvent_path) - settings.copied_solvent_data_file = None - manager.save_project(settings) + assert context.powerpoint_settings.font_family == "Courier New" + assert not context.powerpoint_settings.include_prior_histograms + assert not context.powerpoint_settings.generate_manifest + assert not context.powerpoint_settings.export_figure_assets + assert not any( + "Report assets:" in line or "Report manifest:" in line + for line in context.directory_lines + ) + assert context.prior_histograms[0].cmap == "cividis" + assert context.prior_histograms[2].cmap == "magma" + expected_component_color = to_hex( + colormaps["plasma"](0.68), + keep_alpha=False, + ) + assert ( + context.component_plot_without_solvent.component_series[0].color + == expected_component_color + ) - workflow = SAXSDreamWorkflow(project_dir) - entries = workflow.create_default_parameter_map(persist=True) - bundle = workflow.create_runtime_bundle(entries=entries) - metadata = json.loads(bundle.metadata_path.read_text(encoding="utf-8")) + result = export_dream_model_report_pptx(context) - assert np.allclose(metadata["solvent_intensities"], solvent_intensity) + assert result.report_path.is_file() + assert result.manifest_path is None + assert not result.figure_paths + assert not context.asset_dir.exists() -def test_dream_plot_data_exports_save_into_exported_results_data( - qapp, tmp_path -): +def test_dream_recycle_pushes_selected_best_fit_into_prefit(qapp, tmp_path): del qapp - project_dir, paths = _build_minimal_saxs_project(tmp_path) + project_dir, _paths = _build_minimal_saxs_project(tmp_path) _write_minimal_dream_results(project_dir) window = SAXSMainWindow(initial_project_dir=project_dir) + window.load_latest_results() + update_calls: list[str] = [] + original_update_prefit_model = window.update_prefit_model + + def _record_update_prefit_model(): + update_calls.append("called") + return original_update_prefit_model() + + window.update_prefit_model = _record_update_prefit_model + + settings = window.dream_tab.settings_payload() + summary = window._last_results_loader.get_summary( + bestfit_method=settings.bestfit_method, + posterior_filter_mode=settings.posterior_filter_mode, + posterior_top_percent=settings.posterior_top_percent, + posterior_top_n=settings.posterior_top_n, + credible_interval_low=settings.credible_interval_low, + credible_interval_high=settings.credible_interval_high, + ) + expected_values = { + str(name): float(summary.bestfit_params[index]) + for index, name in enumerate(summary.full_parameter_names) + } - class _AcceptedDialog: - def __init__(self, *args, **kwargs): - del args, kwargs - self.selected_options = SimpleNamespace( - output_dir=paths.exported_data_dir, - base_name="dream_violin_export_test", - save_csv=True, - save_pkl=True, - ) - - def exec(self): - return QDialog.DialogCode.Accepted - - monkeypatch = pytest.MonkeyPatch() - monkeypatch.setattr( - "saxshell.saxs.ui.main_window.DreamViolinExportDialog", - _AcceptedDialog, - ) + scale_row = window.prefit_tab.find_parameter_row("scale") + assert scale_row >= 0 + vary_item = window.prefit_tab.parameter_table.item(scale_row, 4) + assert vary_item is not None + vary_item.setCheckState(Qt.CheckState.Unchecked) + window.prefit_tab.set_parameter_row("scale", value=2e-3) + window.prefit_tab.set_parameter_row("offset", value=0.333) - window.load_latest_results() - window.dream_tab.violin_mode_combo.setCurrentIndex(2) - window.save_dream_model_fit() - window.save_dream_violin_data() - monkeypatch.undo() + window.recycle_dream_output_to_prefit() - model_exports = sorted( - paths.exported_data_dir.glob("dream_model_fit_*.csv") - ) - violin_csv_exports = sorted( - paths.exported_data_dir.glob("dream_violin_*.csv") - ) - violin_pkl_exports = sorted( - paths.exported_data_dir.glob("dream_violin_*.pkl") - ) - dialog_csv_export = ( - paths.exported_data_dir / "dream_violin_export_test.csv" - ) - dialog_pkl_export = ( - paths.exported_data_dir / "dream_violin_export_test.pkl" - ) - dialog_metadata_export = ( - paths.exported_data_dir / "dream_violin_export_test.metadata.json" - ) - dialog_report_export = ( - paths.exported_data_dir / "dream_violin_export_test.report.txt" + entries_by_name = { + entry.name: entry for entry in window.prefit_tab.parameter_entries() + } + assert entries_by_name["scale"].value == pytest.approx( + expected_values["scale"] ) - - assert model_exports - assert violin_csv_exports - assert violin_pkl_exports - assert dialog_csv_export.is_file() - assert dialog_pkl_export.is_file() - assert dialog_metadata_export.is_file() - assert dialog_report_export.is_file() - assert "q,experimental_intensity,model_intensity" in model_exports[ - -1 - ].read_text(encoding="utf-8") - model_metadata = json.loads( - model_exports[-1] - .with_name(f"{model_exports[-1].stem}.metadata.json") - .read_text(encoding="utf-8") + assert entries_by_name["offset"].value == pytest.approx( + expected_values["offset"] ) - model_report = model_exports[-1].with_name( - f"{model_exports[-1].stem}.report.txt" + assert entries_by_name["w0"].value == pytest.approx(expected_values["w0"]) + assert not entries_by_name["scale"].vary + assert window.tabs.currentWidget() is window.prefit_tab + assert update_calls == [] + assert "Recycled DREAM output into Prefit." in ( + window.prefit_tab.output_box.toPlainText() ) - assert model_report.is_file() - assert model_metadata["export_kind"] == "dream_model_fit" - assert model_metadata["model_fit"]["fit_metrics"]["rmse"] >= 0.0 - assert "Model fit metrics:" in model_report.read_text(encoding="utf-8") - assert "w0 (A)" in dialog_csv_export.read_text(encoding="utf-8") - violin_payload = pickle.loads(dialog_pkl_export.read_bytes()) - violin_metadata = json.loads( - dialog_metadata_export.read_text(encoding="utf-8") + assert "Prefit preview refresh: deferred" in ( + window.prefit_tab.output_box.toPlainText() ) - assert violin_payload["violin_plot"]["display_names"] == ["w0 (A)"] - assert violin_payload["screening_metrics"]["posterior_filter_mode"] - assert violin_metadata["export_kind"] == "dream_violin" - assert violin_metadata["screening_metrics"]["posterior_filter_mode"] - assert violin_payload["plot_payload"]["ylabel"] == "Parameter value" - assert np.asarray(violin_payload["plot_payload"]["samples"]).shape[1] == 1 - assert "Posterior violin data:" in dialog_report_export.read_text( - encoding="utf-8" + assert "Recycled the current DREAM best fit into the Prefit tab." in ( + window.dream_tab.output_box.toPlainText() ) @@ -4418,6 +6518,26 @@ def test_recognized_cluster_table_is_resizable_scrollable_and_saves_colors( assert settings.component_trace_colors == {"A_no_motif": "#123456"} +def test_model_and_build_layout_keeps_cluster_table_below_template_fields( + qapp, +): + del qapp + tab = ProjectSetupTab() + + model_layout = tab.model_group.layout() + assert model_layout is not None + assert model_layout.itemAt(0).widget() is tab._model_build_header_widget + assert model_layout.itemAt(1).layout() is tab._model_build_lower_layout + assert ( + tab._model_build_lower_layout.itemAt(0).widget() + is tab._model_build_button_widget + ) + assert ( + tab._model_build_lower_layout.itemAt(1).widget() + is tab._recognized_clusters_group + ) + + def test_open_project_uses_existing_project_field(qapp, tmp_path): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) @@ -4449,49 +6569,349 @@ def test_loaded_project_prefers_original_experimental_reference_in_ui( ), ) - project_dir = tmp_path / "project_with_copy" - settings = manager.create_project(project_dir) - paths = build_project_paths(project_dir) - copied_path = paths.experimental_data_dir / source_path.name - copied_path.write_text( - source_path.read_text(encoding="utf-8"), encoding="utf-8" + project_dir = tmp_path / "project_with_copy" + settings = manager.create_project(project_dir) + paths = build_project_paths(project_dir) + copied_path = paths.experimental_data_dir / source_path.name + copied_path.write_text( + source_path.read_text(encoding="utf-8"), encoding="utf-8" + ) + settings.experimental_data_path = str(source_path) + settings.copied_experimental_data_file = str(copied_path) + manager.save_project(settings) + + window = SAXSMainWindow(initial_project_dir=project_dir) + + assert window.project_setup_tab.experimental_data_edit.text() == str( + source_path + ) + + +def test_prefit_recommended_scale_button_updates_scale_bounds(qapp, tmp_path): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + + scale_row = window.prefit_tab.find_parameter_row("scale") + assert scale_row >= 0 + window.prefit_tab.parameter_table.item(scale_row, 3).setText("1e-6") + window.prefit_tab.parameter_table.item(scale_row, 5).setText("1e-7") + window.prefit_tab.parameter_table.item(scale_row, 6).setText("1e-5") + + window.apply_recommended_scale_settings() + + scale_entry = next( + entry + for entry in window.prefit_tab.parameter_entries() + if entry.name == "scale" + ) + offset_entry = next( + entry + for entry in window.prefit_tab.parameter_entries() + if entry.name == "offset" + ) + assert scale_entry.vary + assert scale_entry.value == pytest.approx(5e-4) + assert scale_entry.minimum == pytest.approx(5e-5) + assert scale_entry.maximum == pytest.approx(5e-3) + assert offset_entry.value == pytest.approx(0.05) + assert "Applied autoscale settings." in ( + window.prefit_tab.output_box.toPlainText() + ) + + +def test_prefit_tab_reorders_controls_and_parameter_actions(qapp): + del qapp + tab = PrefitTab() + + assert tab._parameter_action_layout.itemAt(0).widget() is ( + tab.recommended_scale_button + ) + assert tab._parameter_action_layout.itemAt(1).widget() is ( + tab.update_button + ) + assert tab._parameter_action_layout.itemAt(2).widget() is ( + tab.auto_update_checkbox + ) + assert tab._parameter_action_layout.itemAt(3).widget() is ( + tab.scrollable_parameter_checkbox + ) + assert tab._prefit_control_button_grid.itemAtPosition(0, 0).widget() is ( + tab._run_button_cell + ) + assert tab._prefit_control_button_grid.itemAtPosition(0, 1).widget() is ( + tab.autosave_checkbox + ) + assert tab._prefit_control_button_grid.itemAtPosition(1, 0).widget() is ( + tab.save_button + ) + assert tab._prefit_control_button_grid.itemAtPosition(1, 1).widget() is ( + tab.reset_button + ) + assert tab._prefit_control_button_grid.itemAtPosition(2, 0).widget() is ( + tab.set_best_button + ) + assert tab._prefit_control_button_grid.itemAtPosition(2, 1).widget() is ( + tab.reset_best_button + ) + + +def test_solute_volume_fraction_help_text_and_labels_omit_fullrmc(qapp): + del qapp + widget = SoluteVolumeFractionWidget() + + assert "fullrmc" not in SOLUTE_VOLUME_FRACTION_HELP_TEXT.lower() + label_text = "\n".join( + label.text().lower() for label in widget.findChildren(QLabel) + ) + assert "fullrmc" not in label_text + + +def test_prefit_tab_auto_updates_model_on_value_change_when_enabled(qapp): + del qapp + tab = PrefitTab() + events: list[str] = [] + tab.update_model_requested.connect(lambda: events.append("update")) + tab.populate_parameter_table( + [ + PrefitParameterEntry( + structure="A", + motif="motif", + name="scale", + value=1.0, + vary=True, + minimum=0.1, + maximum=10.0, + category="fit", + ) + ] + ) + + row = tab.find_parameter_row("scale") + assert row >= 0 + + tab.parameter_table.item(row, 3).setText("2.5") + assert events == [] + + tab.auto_update_checkbox.setChecked(True) + tab.parameter_table.item(row, 3).setText("3.5") + assert events == ["update"] + + tab.set_parameter_row("scale", value=4.5) + assert events == ["update"] + + +def test_prefit_tab_supports_linked_parameter_expressions(qapp): + del qapp + tab = PrefitTab() + events: list[str] = [] + tab.update_model_requested.connect(lambda: events.append("update")) + tab.populate_parameter_table( + [ + PrefitParameterEntry( + structure="A", + motif="motif", + name="scale", + value=2.5, + vary=True, + minimum=0.1, + maximum=10.0, + category="fit", + ), + PrefitParameterEntry( + structure="A", + motif="motif", + name="offset", + value=0.0, + vary=True, + minimum=-10.0, + maximum=10.0, + category="fit", + ), + ] + ) + + tab.auto_update_checkbox.setChecked(True) + offset_row = tab.find_parameter_row("offset") + assert offset_row >= 0 + + tab.parameter_table.item(offset_row, 3).setText("*scale") + + entries = {entry.name: entry for entry in tab.parameter_entries()} + offset_entry = entries["offset"] + vary_item = tab.parameter_table.item(offset_row, 4) + + assert events == ["update"] + assert offset_entry.initial_value_expression == "*scale" + assert offset_entry.value_expression is None + assert offset_entry.value == pytest.approx(2.5) + assert offset_entry.vary is True + assert vary_item.checkState() == Qt.CheckState.Checked + assert vary_item.flags() & Qt.ItemFlag.ItemIsUserCheckable + + +def test_prefit_tab_supports_dependent_parameter_expressions_with_vary_off( + qapp, +): + del qapp + tab = PrefitTab() + tab.populate_parameter_table( + [ + PrefitParameterEntry( + structure="A", + motif="motif", + name="scale", + value=2.5, + vary=True, + minimum=0.1, + maximum=10.0, + category="fit", + ), + PrefitParameterEntry( + structure="A", + motif="motif", + name="offset", + value=0.0, + vary=False, + minimum=-10.0, + maximum=10.0, + category="fit", + ), + ] + ) + + offset_row = tab.find_parameter_row("offset") + assert offset_row >= 0 + + tab.parameter_table.item(offset_row, 3).setText("*scale") + + entries = {entry.name: entry for entry in tab.parameter_entries()} + offset_entry = entries["offset"] + vary_item = tab.parameter_table.item(offset_row, 4) + + assert offset_entry.value_expression == "*scale" + assert offset_entry.initial_value_expression is None + assert offset_entry.value == pytest.approx(2.5) + assert offset_entry.vary is False + assert vary_item.checkState() == Qt.CheckState.Unchecked + assert vary_item.flags() & Qt.ItemFlag.ItemIsUserCheckable + + +def test_prefit_tab_scrollable_parameter_supports_expression_seed_ranges(qapp): + del qapp + tab = PrefitTab() + tab.show() + tab.populate_parameter_table( + [ + PrefitParameterEntry( + structure="A", + motif="motif", + name="scale", + value=2.5, + vary=True, + minimum=0.1, + maximum=10.0, + category="fit", + ), + PrefitParameterEntry( + structure="A", + motif="motif", + name="offset", + value=0.0, + vary=True, + minimum=-10.0, + maximum=10.0, + category="fit", + ), + ] + ) + + tab.auto_update_checkbox.setChecked(True) + tab.scrollable_parameter_checkbox.setChecked(True) + offset_row = tab.find_parameter_row("offset") + assert offset_row >= 0 + tab.parameter_table.setCurrentCell(offset_row, 3) + tab.parameter_table.item(offset_row, 3).setText("*scale") + + assert tab.parameter_scroll_panel.isVisible() + assert tab.parameter_scroll_bar.isEnabled() + assert "Initial expression seed" in tab.parameter_scroll_mode_label.text() + assert ( + "no numeric range" + not in tab.parameter_scroll_name_label.text().lower() ) - settings.experimental_data_path = str(source_path) - settings.copied_experimental_data_file = str(copied_path) - manager.save_project(settings) - window = SAXSMainWindow(initial_project_dir=project_dir) - assert window.project_setup_tab.experimental_data_edit.text() == str( - source_path +def test_prefit_metrics_note_non_positive_model_points(qapp): + del qapp + evaluation = PrefitEvaluation( + q_values=np.asarray([0.1, 0.2, 0.3], dtype=float), + experimental_intensities=np.asarray([1.0, 0.8, 0.6], dtype=float), + model_intensities=np.asarray([1.1, -0.1, 0.5], dtype=float), + residuals=np.asarray([0.1, -0.9, -0.1], dtype=float), ) + metric_lines = PrefitTab._prefit_metric_lines(evaluation) -def test_prefit_recommended_scale_button_updates_scale_bounds(qapp, tmp_path): - del qapp - project_dir, _paths = _build_minimal_saxs_project(tmp_path) - window = SAXSMainWindow(initial_project_dir=project_dir) + assert "Model <= 0 at 1 q-points" in metric_lines - scale_row = window.prefit_tab.find_parameter_row("scale") + +def test_prefit_tab_scrollable_parameter_uses_bounds_and_updates_value(qapp): + tab = PrefitTab() + tab.show() + qapp.processEvents() + events: list[str] = [] + tab.update_model_requested.connect(lambda: events.append("update")) + tab.populate_parameter_table( + [ + PrefitParameterEntry( + structure="A", + motif="motif", + name="scale", + value=1e-4, + vary=True, + minimum=1e-5, + maximum=5e-3, + category="fit", + ), + PrefitParameterEntry( + structure="A", + motif="motif", + name="offset", + value=0.0, + vary=True, + minimum=-20.0, + maximum=30.0, + category="fit", + ), + ] + ) + + assert not tab.scrollable_parameter_checkbox.isEnabled() + tab.auto_update_checkbox.setChecked(True) + assert tab.scrollable_parameter_checkbox.isEnabled() + tab.scrollable_parameter_checkbox.setChecked(True) + + scale_row = tab.find_parameter_row("scale") + offset_row = tab.find_parameter_row("offset") assert scale_row >= 0 - window.prefit_tab.parameter_table.item(scale_row, 3).setText("1e-6") - window.prefit_tab.parameter_table.item(scale_row, 5).setText("1e-7") - window.prefit_tab.parameter_table.item(scale_row, 6).setText("1e-5") + assert offset_row >= 0 - window.apply_recommended_scale_settings() + tab.parameter_table.setCurrentCell(scale_row, 3) + qapp.processEvents() + assert tab.parameter_scroll_panel.isVisible() + assert tab.parameter_scroll_mode_label.text() == "Log scroll" - scale_entry = next( - entry - for entry in window.prefit_tab.parameter_entries() - if entry.name == "scale" - ) - assert scale_entry.vary - assert scale_entry.value == pytest.approx(5e-4) - assert scale_entry.minimum == pytest.approx(5e-5) - assert scale_entry.maximum == pytest.approx(5e-3) - assert "Applied autoscale settings." in ( - window.prefit_tab.output_box.toPlainText() - ) + tab.parameter_scroll_bar.setValue(tab.PARAMETER_SCROLL_RESOLUTION) + qapp.processEvents() + assert events + assert float( + tab.parameter_table.item(scale_row, 3).text() + ) == pytest.approx(5e-3) + + tab.parameter_table.setCurrentCell(offset_row, 3) + qapp.processEvents() + assert tab.parameter_scroll_mode_label.text() == "Linear scroll" def test_run_prefit_keeps_manual_weight_value_outside_previous_bounds( @@ -4575,6 +6995,45 @@ def test_best_prefit_preset_saves_resets_and_reloads(qapp, tmp_path): assert reloaded_entries["offset"].value == pytest.approx(0.125) +def test_individual_prefit_parameter_reset_button_restores_template_default( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + + window.prefit_tab.set_parameter_row( + "scale", + value=2e-3, + minimum=1e-6, + maximum=9e-3, + vary=True, + ) + window.prefit_tab.set_parameter_row("offset", value=0.125) + + scale_row = window.prefit_tab.find_parameter_row("scale") + assert scale_row >= 0 + reset_col = _table_column_index(window.prefit_tab.parameter_table, "Reset") + reset_button = window.prefit_tab.parameter_table.cellWidget( + scale_row, + reset_col, + ) + assert reset_button is not None + + reset_button.click() + + entries = { + entry.name: entry for entry in window.prefit_tab.parameter_entries() + } + assert entries["scale"].value == pytest.approx(5e-4) + assert entries["scale"].minimum == pytest.approx(1e-5) + assert entries["scale"].maximum == pytest.approx(5e-3) + assert entries["scale"].vary is False + assert entries["offset"].value == pytest.approx(0.125) + window.close() + + def test_restore_prefit_state_recovers_saved_parameters_and_run_config( qapp, tmp_path ): @@ -5230,6 +7689,88 @@ def test_save_project_state_reloads_prefit_and_dream_for_reduced_q_range( window.close() +def test_save_project_state_preserves_live_prefit_parameters(qapp, tmp_path): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + + window.prefit_tab.set_parameter_row( + "scale", + value=7e-4, + minimum=1e-5, + maximum=8e-3, + vary=True, + ) + window.prefit_tab.set_parameter_row("offset", value=0.125) + window.set_best_prefit_parameters() + + window.prefit_tab.set_parameter_row("scale", value=2e-3) + window.prefit_tab.set_parameter_row("offset", value=0.333) + window.project_setup_tab.qmin_edit.setText("0.12") + window.project_setup_tab.qmax_edit.setText("0.19") + + window.save_project_state() + + entries = { + entry.name: entry for entry in window.prefit_tab.parameter_entries() + } + assert entries["scale"].value == pytest.approx(2e-3) + assert entries["offset"].value == pytest.approx(0.333) + assert window.prefit_workflow is not None + assert np.allclose( + window.prefit_workflow.evaluate().q_values, + [0.12142857142857144, 0.15714285714285714, 0.19285714285714284], + ) + window.close() + + +def test_save_project_state_updates_workflows_in_place_without_resetting_tabs( + qapp, tmp_path +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + + assert window.prefit_workflow is not None + assert window.dream_workflow is not None + original_prefit_workflow = window.prefit_workflow + original_dream_workflow = window.dream_workflow + + window.prefit_tab.set_parameter_row( + "scale", + value=1.7e-3, + minimum=1e-5, + maximum=8e-3, + vary=True, + ) + window.prefit_tab.set_parameter_row("offset", value=0.222) + window.dream_tab.iterations_spin.setValue(4321) + window.dream_tab.nseedchains_spin.setValue(33) + window.project_setup_tab.qmin_edit.setText("0.12") + window.project_setup_tab.qmax_edit.setText("0.19") + + window.save_project_state() + + assert window.prefit_workflow is original_prefit_workflow + assert window.dream_workflow is original_dream_workflow + entries = { + entry.name: entry for entry in window.prefit_tab.parameter_entries() + } + assert entries["scale"].value == pytest.approx(1.7e-3) + assert entries["offset"].value == pytest.approx(0.222) + assert window.dream_tab.settings_payload().niterations == 4321 + assert window.dream_tab.settings_payload().nseedchains == 33 + assert np.allclose( + window.prefit_workflow.evaluate().q_values, + [0.12142857142857144, 0.15714285714285714, 0.19285714285714284], + ) + assert np.allclose( + window.dream_workflow.prefit_workflow.evaluate().q_values, + [0.12142857142857144, 0.15714285714285714, 0.19285714285714284], + ) + window.close() + + def test_save_project_state_warns_when_q_range_expands_beyond_components( qapp, tmp_path, monkeypatch ): @@ -5261,6 +7802,37 @@ def test_save_project_state_warns_when_q_range_expands_beyond_components( window.close() +def test_build_components_warns_when_q_range_is_still_default( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + warnings: list[tuple[str, str]] = [] + start_calls: list[str] = [] + + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.QMessageBox.warning", + lambda _parent, title, message, *args, **kwargs: warnings.append( + (title, message) + ) + or QMessageBox.StandardButton.No, + ) + monkeypatch.setattr( + window, + "_start_project_task", + lambda task_name, *args, **kwargs: start_calls.append(task_name), + ) + + window.build_project_components() + + assert warnings + assert warnings[0][0] == "Build SAXS components with default q-range?" + assert "full experimental-data default" in warnings[0][1] + assert not start_calls + window.close() + + def test_save_project_state_ignores_tiny_q_range_edge_mismatch( qapp, tmp_path, monkeypatch ): @@ -5818,16 +8390,34 @@ def test_dream_plot_trace_toggles_control_visible_series(qapp, tmp_path): assert window.dream_tab.show_experimental_trace_checkbox.isEnabled() assert window.dream_tab.show_model_trace_checkbox.isEnabled() assert window.dream_tab.show_solvent_trace_checkbox.isEnabled() + assert window.dream_tab.show_structure_factor_trace_checkbox.isEnabled() top_axis = window.dream_tab.model_figure.axes[0] - labels = [line.get_label() for line in top_axis.get_lines()] + labels = [ + line.get_label() + for axis in window.dream_tab.model_figure.axes + for line in axis.get_lines() + ] assert "Solvent contribution" not in labels + assert "Structure factor S(q)" not in labels window.dream_tab.show_solvent_trace_checkbox.setChecked(True) top_axis = window.dream_tab.model_figure.axes[0] - labels = [line.get_label() for line in top_axis.get_lines()] + labels = [ + line.get_label() + for axis in window.dream_tab.model_figure.axes + for line in axis.get_lines() + ] assert "Solvent contribution" in labels + window.dream_tab.show_structure_factor_trace_checkbox.setChecked(True) + labels = [ + line.get_label() + for axis in window.dream_tab.model_figure.axes + for line in axis.get_lines() + ] + assert "Structure factor S(q)" in labels + window.dream_tab.show_experimental_trace_checkbox.setChecked(False) top_axis = window.dream_tab.model_figure.axes[0] collection_labels = [ @@ -5869,21 +8459,131 @@ def test_prefit_plot_trace_toggles_control_visible_series(qapp, tmp_path): assert window.prefit_tab.show_experimental_trace_checkbox.isEnabled() assert window.prefit_tab.show_model_trace_checkbox.isEnabled() assert window.prefit_tab.show_solvent_trace_checkbox.isEnabled() + assert window.prefit_tab.show_structure_factor_trace_checkbox.isEnabled() top_axis = window.prefit_tab.figure.axes[0] - labels = [line.get_label() for line in top_axis.get_lines()] + labels = [ + line.get_label() + for axis in window.prefit_tab.figure.axes + for line in axis.get_lines() + ] assert "Solvent contribution" not in labels + assert "Structure factor S(q)" not in labels window.prefit_tab.show_solvent_trace_checkbox.setChecked(True) top_axis = window.prefit_tab.figure.axes[0] - labels = [line.get_label() for line in top_axis.get_lines()] + labels = [ + line.get_label() + for axis in window.prefit_tab.figure.axes + for line in axis.get_lines() + ] assert "Solvent contribution" in labels + window.prefit_tab.show_structure_factor_trace_checkbox.setChecked(True) + labels = [ + line.get_label() + for axis in window.prefit_tab.figure.axes + for line in axis.get_lines() + ] + assert "Structure factor S(q)" in labels + window.prefit_tab.show_experimental_trace_checkbox.setChecked(False) top_axis = window.prefit_tab.figure.axes[0] line_labels = [line.get_label() for line in top_axis.get_lines()] assert "Experimental" not in line_labels + +def test_prefit_field_interaction_warns_before_components_are_built( + qapp, tmp_path, monkeypatch +): + project_dir = tmp_path / "prefit_warning_project" + manager = SAXSProjectManager() + settings = manager.create_project(project_dir) + settings.selected_model_template = ( + "template_pd_likelihood_monosq_decoupled" + ) + manager.save_project(settings) + + window = SAXSMainWindow(initial_project_dir=project_dir) + window.show() + qapp.processEvents() + + warnings: list[tuple[str, str]] = [] + monkeypatch.setattr( + "saxshell.saxs.ui.main_window.QMessageBox.warning", + lambda _parent, title, message, *args, **kwargs: warnings.append( + (title, message) + ) + or QMessageBox.StandardButton.Ok, + ) + + QTest.mouseClick( + window.prefit_tab.parameter_table.viewport(), + Qt.MouseButton.LeftButton, + ) + qapp.processEvents() + QTest.mouseClick( + window.prefit_tab.parameter_table.viewport(), + Qt.MouseButton.LeftButton, + ) + qapp.processEvents() + + assert warnings + assert warnings[0][0] == "Build SAXS components first" + assert "Project Setup tab" in warnings[0][1] + assert len(warnings) == 1 + window.close() + + +def test_model_only_mode_disables_fit_controls_and_dream(qapp, tmp_path): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + dream_index = window.tabs.indexOf(window.dream_tab) + + assert window.tabs.isTabEnabled(dream_index) + assert window.project_setup_tab.experimental_file_button.isEnabled() + assert window.project_setup_tab.experimental_data_edit.isEnabled() + assert window.prefit_tab.run_button.isEnabled() + + window.project_setup_tab.model_only_mode_checkbox.setChecked(True) + + assert window.current_settings is not None + assert window.current_settings.model_only_mode is True + assert not window.project_setup_tab.experimental_file_button.isEnabled() + assert not window.project_setup_tab.experimental_data_edit.isEnabled() + assert not window.project_setup_tab.solvent_file_button.isEnabled() + assert not window.project_setup_tab.solvent_data_edit.isEnabled() + assert ( + not window.project_setup_tab.experimental_trace_visible_checkbox.isEnabled() + ) + assert ( + not window.project_setup_tab.solvent_trace_visible_checkbox.isEnabled() + ) + assert not window.prefit_tab.method_combo.isEnabled() + assert not window.prefit_tab.nfev_spin.isEnabled() + assert not window.prefit_tab.run_button.isEnabled() + assert not window.prefit_tab.recommended_scale_button.isEnabled() + assert window.prefit_tab.update_button.isEnabled() + assert not window.tabs.isTabEnabled(dream_index) + + assert len(window.prefit_tab.figure.axes) == 1 + top_axis = window.prefit_tab.figure.axes[0] + line_labels = [line.get_label() for line in top_axis.get_lines()] + assert "Experimental" not in line_labels + + saved_settings = SAXSProjectManager().load_project(project_dir) + assert saved_settings.model_only_mode is True + + window.project_setup_tab.model_only_mode_checkbox.setChecked(False) + + assert window.current_settings is not None + assert window.current_settings.model_only_mode is False + assert window.project_setup_tab.experimental_file_button.isEnabled() + assert window.project_setup_tab.experimental_data_edit.isEnabled() + assert window.prefit_tab.run_button.isEnabled() + assert window.tabs.isTabEnabled(dream_index) + window.prefit_tab.show_model_trace_checkbox.setChecked(False) top_axis = window.prefit_tab.figure.axes[0] line_labels = [line.get_label() for line in top_axis.get_lines()] @@ -5915,7 +8615,7 @@ def test_save_prefit_plot_data_exports_csv_with_metadata( assert "# fit_metrics:" in contents assert ( "q,experimental_intensity,model_intensity,residual," - "solvent_intensity,solvent_contribution" in contents + "solvent_intensity,solvent_contribution,structure_factor" in contents ) @@ -5943,7 +8643,7 @@ def test_save_prefit_plot_data_exports_npy_with_metadata_sidecar( assert metadata_path.is_file() matrix = np.load(export_path) metadata = json.loads(metadata_path.read_text(encoding="utf-8")) - assert matrix.shape[1] == 6 + assert matrix.shape[1] == 7 assert metadata["columns"] == [ "q", "experimental_intensity", @@ -5951,6 +8651,7 @@ def test_save_prefit_plot_data_exports_npy_with_metadata_sidecar( "residual", "solvent_intensity", "solvent_contribution", + "structure_factor", ] assert "fit_conditions" in metadata assert "fit_metrics" in metadata @@ -5965,6 +8666,14 @@ def test_experimental_status_note_is_wrapped_and_multiline(qapp): assert "\n" in tab.data_status_label.text() +def test_project_setup_template_controls_have_expanded_minimum_width(qapp): + del qapp + tab = ProjectSetupTab() + + assert tab.template_combo.minimumWidth() >= 420 + assert tab.active_template_edit.minimumWidth() >= 420 + + def test_project_setup_uses_scrollable_resizable_side_panes(qapp): del qapp tab = ProjectSetupTab() @@ -6095,6 +8804,59 @@ def test_use_experimental_grid_crops_to_nearest_available_q_values(tmp_path): assert np.allclose(q_grid, [0.08, 0.12, 0.18]) +def test_model_only_q_grid_uses_configured_range_without_experimental_data( + tmp_path, +): + manager = SAXSProjectManager() + settings = ProjectSettings( + project_name="demo_project", + project_dir=str(tmp_path / "demo_project"), + model_only_mode=True, + use_experimental_grid=False, + q_min=0.05, + q_max=0.30, + q_points=8, + ) + + q_grid = manager._build_q_grid(settings, None) + + assert np.allclose(q_grid, np.linspace(0.05, 0.30, 8)) + + +def test_project_settings_roundtrip_preserves_powerpoint_export_settings( + tmp_path, +): + manager = SAXSProjectManager() + settings = ProjectSettings( + project_name="demo_project", + project_dir=str(tmp_path / "demo_project"), + powerpoint_export_settings=PowerPointExportSettings( + font_family="Courier New", + component_color_map="plasma", + prior_histogram_color_map="cividis", + solvent_sort_histogram_color_map="magma", + generate_manifest=False, + export_figure_assets=False, + ), + ) + + manager.save_project(settings) + loaded = manager.load_project(settings.project_dir) + + assert loaded.powerpoint_export_settings.font_family == "Courier New" + assert loaded.powerpoint_export_settings.component_color_map == "plasma" + assert ( + loaded.powerpoint_export_settings.prior_histogram_color_map + == "cividis" + ) + assert ( + loaded.powerpoint_export_settings.solvent_sort_histogram_color_map + == "magma" + ) + assert not loaded.powerpoint_export_settings.generate_manifest + assert not loaded.powerpoint_export_settings.export_figure_assets + + def test_scan_cluster_inventory_reports_progress_and_rows(tmp_path): cluster_dir = tmp_path / "clusters" structure_dir = cluster_dir / "Pb2I4" From f84cfb6c96315fa5a7e119f51a8f89b526bd3833 Mon Sep 17 00:00:00 2001 From: kewh5868 Date: Mon, 30 Mar 2026 11:31:42 -0600 Subject: [PATCH 3/3] Refresh docs and tool integrations --- docs/api/overview.md | 4 + docs/development/repo-structure.md | 6 + docs/getting-started/installation.md | 1 + docs/index.md | 9 +- docs/javascripts/mathjax.js | 15 ++ docs/user-guide/cluster-dynamics.md | 175 +++++++++++++ docs/user-guide/cluster-extraction.md | 29 ++- docs/user-guide/gui-overview.md | 10 + docs/user-guide/preloaded-saxs-models.md | 242 ++++++++++++++++++ docs/user-guide/pydream-workflow.md | 10 +- docs/user-guide/saxs-prefit.md | 255 +++++++++++++++++-- docs/user-guide/template-system.md | 15 ++ mkdocs.yml | 8 + pyproject.toml | 3 + requirements/pip.txt | 1 + src/saxshell/bondanalysis/ui/main_window.py | 8 + src/saxshell/cluster/ui/definitions_panel.py | 152 +++++++++++ src/saxshell/cluster/ui/main_window.py | 8 + src/saxshell/fullrmc/__init__.py | 26 +- src/saxshell/fullrmc/ui/main_window.py | 8 + src/saxshell/mdtrajectory/cli.py | 3 + src/saxshell/mdtrajectory/ui/main_window.py | 11 + src/saxshell/mdtrajectory/workflow.py | 61 ++++- src/saxshell/xyz2pdb/cli.py | 3 + src/saxshell/xyz2pdb/ui/main_window.py | 13 + 25 files changed, 1046 insertions(+), 30 deletions(-) create mode 100644 docs/javascripts/mathjax.js create mode 100644 docs/user-guide/cluster-dynamics.md create mode 100644 docs/user-guide/preloaded-saxs-models.md diff --git a/docs/api/overview.md b/docs/api/overview.md index 8c7cdbe..7bde5e1 100644 --- a/docs/api/overview.md +++ b/docs/api/overview.md @@ -14,6 +14,10 @@ classes that are most likely to be imported directly. - `saxshell.cluster.workflow.ClusterWorkflow` +### Cluster dynamics + +- `saxshell.clusterdynamics.workflow.ClusterDynamicsWorkflow` + ### Bond analysis - `saxshell.bondanalysis.workflow.BondAnalysisWorkflow` diff --git a/docs/development/repo-structure.md b/docs/development/repo-structure.md index 788e8be..24d651b 100644 --- a/docs/development/repo-structure.md +++ b/docs/development/repo-structure.md @@ -8,6 +8,7 @@ This page is a short orientation guide for contributors. src/saxshell/ bondanalysis/ cluster/ + clusterdynamics/ fullrmc/ mdtrajectory/ saxs/ @@ -32,6 +33,11 @@ Residue-aware XYZ-to-PDB conversion, reference-library helpers, and UI code. Cluster extraction workflows and UI. +### `src/saxshell/clusterdynamics` + +Time-binned cluster-distribution analysis, lifetime/rate summaries, dataset +save/load helpers, and the matching UI. + ### `src/saxshell/bondanalysis` Bond-pair and angle-analysis workflows and UI. diff --git a/docs/getting-started/installation.md b/docs/getting-started/installation.md index ed21497..e802b6c 100644 --- a/docs/getting-started/installation.md +++ b/docs/getting-started/installation.md @@ -48,6 +48,7 @@ The current package exposes these top-level tools: - `saxshell` - `mdtrajectory` - `clusters` +- `clusterdynamics` - `bondanalysis` - `xyz2pdb` - `saxs` diff --git a/docs/index.md b/docs/index.md index 1baed8f..0e5790d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -34,10 +34,11 @@ The current repo supports an end-to-end path that usually looks like this: 1. Inspect and split trajectories with `mdtrajectory`. 2. Optionally convert XYZ frames to residue-aware PDB files with `xyz2pdb`. 3. Extract clusters with `clusters`. -4. Measure bond and angle distributions with `bondanalysis`. -5. Build a SAXS project with `saxs`. -6. Refine the project in **SAXS Prefit** and, if needed, run **pyDREAM**. -7. Use the resulting distributions and selected structures in downstream tools +4. Analyze time-dependent cluster populations with `clusterdynamics`. +5. Measure bond and angle distributions with `bondanalysis`. +6. Build a SAXS project with `saxs`. +7. Refine the project in **SAXS Prefit** and, if needed, run **pyDREAM**. +8. Use the resulting distributions and selected structures in downstream tools such as `fullrmc`. ## Documentation map diff --git a/docs/javascripts/mathjax.js b/docs/javascripts/mathjax.js new file mode 100644 index 0000000..f320319 --- /dev/null +++ b/docs/javascripts/mathjax.js @@ -0,0 +1,15 @@ +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true, + }, + options: { + skipHtmlTags: ["script", "noscript", "style", "textarea", "pre"], + }, +}; + +document$.subscribe(() => { + MathJax.typesetPromise?.(); +}); diff --git a/docs/user-guide/cluster-dynamics.md b/docs/user-guide/cluster-dynamics.md new file mode 100644 index 0000000..f58fcd3 --- /dev/null +++ b/docs/user-guide/cluster-dynamics.md @@ -0,0 +1,175 @@ +# Cluster Dynamics + +`clusterdynamics` is the time-resolved companion to `clusters`. It analyzes the +same extracted XYZ or PDB frame folders from `mdtrajectory`, reuses the same +cluster-definition and pair-cutoff logic, and converts the per-frame cluster +counts into: + +- a time-binned cluster-distribution heatmap +- an optional lower subplot from a CP2K `.ener` file +- a sortable lifetime table by stoichiometry label +- reloadable JSON/CSV datasets for later plotting + +## Inputs + +The application expects: + +- an extracted frame folder produced by `mdtrajectory` +- optional CP2K `.ener` data for the lower overlay subplot +- the same atom-type definitions, pair cutoffs, shell options, and PBC/search + settings you would use in `clusters` + +The left pane includes all of the cluster-rule controls from the cluster +extraction UI, plus the time-axis controls and dataset save/load actions. + +## Time Axis Rules + +The time axis is resolved in this order: + +1. `mdtrajectory_export.json` if the frame folder came from a recent + `mdtrajectory` export. This is the most reliable source because it stores + the original frame indices and frame times written during export. +2. `frame_.xyz` or `frame_.pdb` filenames multiplied by the + user-specified frame timestep. +3. A sequential fallback that starts from the folder/start-time field. + +The frame timestep field defaults to `0.5 fs`. + +The heatmap binning now uses an integer `frames / colormap timestep` control. +The UI shows the derived `colormap timestep used (fs)` field so the effective +heatmap timestep is always an exact multiple of the frame timestep. + +`mdtrajectory` cutoff exports now use folder names such as `splitxyz_f847fs`. +The `f847fs` part records the cutoff time in femtoseconds and is auto-filled +into the folder/start-time field when it is detected. + +### Important example + +If the folder is `splitxyz_f847fs`, the frame timestep is `0.5 fs`, and the +first file is `frame_1866.xyz`, the resolved frame time is: + +```text +1866 x 0.5 fs = 933 fs +``` + +In that case the preview warns that the folder cutoff/start time (`847 fs`) and +the first resolved extracted-frame time (`933 fs`) are different. The heatmap +and lifetime calculations follow the resolved frame times, not the folder label. + +## Typical UI Workflow + +1. Load the extracted XYZ/PDB frames folder. +2. Optionally load a CP2K `.ener` file. +3. Confirm the active SAXSShell project if you want saved datasets to default + into `exported_results/data/clusterdynamics`. +4. Enter the cluster definitions, pair cutoffs, shell options, and PBC/search + settings. +5. Confirm the frame timestep, frames per colormap timestep, derived + colormap timestep, and analysis start/stop window. +6. Run the analysis. +7. Adjust the heatmap display mode, time units, colormap, quantile limits, and + optional overlay interactively. +8. Use the **Saved Results** panel to save the current result or reopen a + previously saved dataset without rerunning the frame analysis. + +If the tool is launched from the main SAXS UI, it inherits the active project +directory automatically. + +## Saved Outputs + +The save action writes a JSON dataset plus companion CSV files beside it: + +- `*_cluster_distribution.csv` +- `*_lifetime.csv` +- `*_energy.csv` when an energy overlay is present + +The JSON file is the reloadable artifact. It stores the plotted matrices, +summary tables, time-axis metadata, and optional energy data needed to reopen +the analysis. In the UI, these controls live in the **Saved Results** panel so +they are separate from the **Run Analysis** controls. + +## Heatmap Data + +Each heatmap row is a stoichiometry label, and the full y-axis spans the labels +observed across all time bins in the current analysis window. + +The display mode can be switched interactively between: + +- raw counts per bin +- fraction of all clusters in the bin +- mean count per sampled frame in the bin + +The color scaling uses quantile limits instead of fixed min/max values so large +outliers do not flatten the rest of the heatmap. + +## Association, Dissociation, and Lifetimes + +The lifetime table is computed from the per-label count series over consecutive +sampled frames. + +- Association events: every positive increase in the count for a label between + two adjacent sampled frames. +- Dissociation events: every negative decrease in that count between adjacent + sampled frames. +- Association rate: `association_events / observation_window_ps` +- Dissociation rate: `dissociation_events / observation_window_ps` + +This is a count-based kinetic summary. It does not claim atom-by-atom identity +tracking across frames. Instead, it reports how the occupancy of each +stoichiometry label changes between samples. + +### Lifetime definition + +- Completed lifetime: a cluster instance that appears after the first sampled + frame and disappears before the end of the observation window. +- Window-truncated lifetime: a cluster instance that was already present in the + first sampled frame or was still present in the last sampled frame. + +Mean and standard deviation lifetimes are computed from completed lifetimes +only. The window-truncated count is reported separately so you can see how much +of the series is clipped by the analysis boundaries. + +## Lifetime Table Columns + +- `Label`: stoichiometry label, such as `Pb2I` +- `Size`: total number of atoms represented by the stoichiometry label +- `Mean lifetime (fs)`: average duration of completed lifetimes for that label +- `Std lifetime (fs)`: standard deviation of completed lifetimes +- `Completed`: number of completed lifetimes used in the mean/std calculation +- `Window-truncated`: number of lifetimes clipped by the start or end of the + sampled window +- `Assoc. rate (1/ps)`: positive count changes per picosecond over the selected + analysis window +- `Dissoc. rate (1/ps)`: negative count changes per picosecond over the + selected analysis window +- `Occupancy (%)`: fraction of sampled frames in which at least one cluster of + that label was present +- `Mean count/frame`: average number of clusters with that label per sampled + frame + +Sort the `Lifetime` tab by the `Size` column if you want the older “lifetime by +size” view without a separate tab. + +## Cluster-Distribution CSV Columns + +The saved `*_cluster_distribution.csv` file stores one row per +`label x time-bin` combination with: + +- the label and cluster size +- bin index, start, stop, and center in femtoseconds +- raw count in the bin +- fraction in the bin +- mean count per sampled frame in the bin +- number of sampled frames in that bin +- total clusters in that bin + +## When Warnings Appear + +The preview and summary can include warnings when: + +- the folder/start-time tag such as `_f847fs` was not found +- `mdtrajectory_export.json` is missing or incomplete +- the folder/start-time tag and the resolved frame times disagree + +These warnings are informational. The analysis still runs, but the UI makes the +time basis explicit so the heatmap and lifetime interpretation stay transparent. diff --git a/docs/user-guide/cluster-extraction.md b/docs/user-guide/cluster-extraction.md index 10127cb..dd68204 100644 --- a/docs/user-guide/cluster-extraction.md +++ b/docs/user-guide/cluster-extraction.md @@ -8,8 +8,10 @@ In this repository, that bridge spans more than one tool. 1. Use `mdtrajectory` to inspect a trajectory and export frames. 2. Optionally use `xyz2pdb` if you need molecule-aware PDB frames. 3. Use `clusters` to extract stoichiometry-sorted cluster files. -4. Use `bondanalysis` to measure bond or angle distributions on those clusters. -5. Feed the resulting cluster folder into the SAXS project. +4. Use `clusterdynamics` to build time-dependent cluster-distribution heatmaps + and lifetime tables from the extracted frames. +5. Use `bondanalysis` to measure bond or angle distributions on those clusters. +6. Feed the resulting cluster folder into the SAXS project. ## `mdtrajectory` @@ -19,6 +21,8 @@ This tool is responsible for: - optionally reading CP2K `.ener` files - suggesting a cutoff - exporting selected frames into a sibling folder +- writing `mdtrajectory_export.json` metadata beside the exported frames so + downstream tools can recover the original frame indices and times Example: @@ -28,6 +32,10 @@ mdtrajectory suggest-cutoff traj.xyz --energy-file traj.ener --temp-target-k 300 mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 ``` +When a cutoff is applied, the default folder name now uses the form +`splitxyz_f847fs` or `splitpdb_f847fs`, where the `f847fs` portion records the +cutoff time in femtoseconds. + ## `xyz2pdb` Use this only when residue identity matters downstream. @@ -54,6 +62,23 @@ The cluster workflow supports both UI and CLI usage. Its CLI exposes separate The CLI help text explicitly calls out faster neighbor search modes such as `kdtree` and `vectorized`. +## `clusterdynamics` + +This application consumes the extracted XYZ or PDB frames from `mdtrajectory` +and applies the same cluster definitions and pair-cutoff rules used by +`clusters`, but bins the results over time instead of writing one +stoichiometry-folder export. + +Key outputs: + +- time-binned cluster-distribution heatmaps +- optional CP2K `.ener` overlays aligned to the same time axis +- a sortable lifetime table by stoichiometry label +- saved JSON/CSV datasets that can be reopened later for plotting + +See [Cluster Dynamics](cluster-dynamics.md) for the full workflow, timing +rules, and the definitions of the lifetime/rate columns. + ## `bondanalysis` Bond analysis is downstream of cluster extraction. Use it to derive bond-pair diff --git a/docs/user-guide/gui-overview.md b/docs/user-guide/gui-overview.md index 32730f2..7f4d302 100644 --- a/docs/user-guide/gui-overview.md +++ b/docs/user-guide/gui-overview.md @@ -19,6 +19,12 @@ PDB conversion before clustering. Use this when you need to build cluster-network exports from extracted frames. +### `clusterdynamics` + +Use this when you need time-binned cluster-distribution heatmaps, optional +energy overlays, and lifetime / association / dissociation summaries from an +extracted XYZ or PDB frame folder. + ### `bondanalysis` Use this when you need bond-pair and angle-distribution measurements on the @@ -73,3 +79,7 @@ Several newer SAXS UI surfaces follow the same patterns: TODO: add screenshots once the docs site has a stable asset pipeline and the UI labels settle after the current SAXS workflow changes. + +??? note "Artwork Attribution" +The SAXSShell application icon used in the SAXS UI is based on artwork +generated with ChatGPT (OpenAI). diff --git a/docs/user-guide/preloaded-saxs-models.md b/docs/user-guide/preloaded-saxs-models.md new file mode 100644 index 0000000..edc6c96 --- /dev/null +++ b/docs/user-guide/preloaded-saxs-models.md @@ -0,0 +1,242 @@ +# Pre-loaded SAXS Models + +This page documents the bundled SAXS templates that ship with SAXSShell. The +equations below describe the **implemented forward models in the repository**. +In a few places, SAXSShell combines MD-derived component mixtures with +literature structure-factor building blocks, so the exact code path is an +implementation of the cited ideas rather than a verbatim reproduction of a +single paper. + +## Template Catalog + +| Template file | GUI name | Status | Model family | +| -------------------------------------------- | ------------------------------------------------------ | ---------- | -------------------------------------------- | +| `template_pydream_monosq_normalized.py` | `pyDREAM MonoSQ Normalized` | current | MonoSQ hard-sphere | +| `template_pydream_poly_lma_hs.py` | `pyDREAM Poly LMA Hard-Sphere` | current | sphere-only Poly LMA hard-sphere | +| `template_pydream_poly_lma_hs_mix_approx.py` | `pyDREAM Poly LMA Hard-Sphere/Ellipsoid Mix (Approx.)` | current | mixed-shape approximate Poly LMA hard-sphere | +| `template_likelihood_monosq.py` | `MonoSQ Basic (archived)` | archived | MonoSQ hard-sphere | +| `template_pd_likelihood_monosq.py` | `MonoSQ PD (archived)` | archived | MonoSQ hard-sphere | +| `template_pd_likelihood_monosq_decoupled.py` | `MonoSQ Decoupled (archived)` | archived | MonoSQ hard-sphere | +| `template_pydream_poly_lma_hs_legacy.py` | `pyDREAM Poly LMA Hard-Sphere (deprecated)` | deprecated | mixed-shape approximate Poly LMA hard-sphere | + +## Shared Notation + +Across the bundled templates: + +- \(q\) is the scattering vector magnitude. +- \(I_i(q)\) is the MD-derived SAXS profile for component \(i\). +- \(I\_{\mathrm{solv}}(q)\) is the solvent scattering trace. +- \(w_i\) is the raw weight assigned to component \(i\). +- \(S\_{\mathrm{HS}}(q; R, \phi)\) is the hard-sphere Percus-Yevick structure + factor evaluated at effective radius \(R\) and packing term \(\phi\). +- `scale` and `offset` are the global multiplicative and additive terms exposed + in the Prefit parameter table. + +## MonoSQ Hard-Sphere Family + +Applies to: + +- `template_pydream_monosq_normalized.py` +- `template_likelihood_monosq.py` +- `template_pd_likelihood_monosq.py` +- `template_pd_likelihood_monosq_decoupled.py` + +These templates treat the MD-derived component profiles as a weighted solute +mixture modulated by a **single** monodisperse hard-sphere structure factor. + +\[ +I*{\mathrm{mix}}(q) = \sum*{i=0}^{N-1} w_i I_i(q) +\] + +\[ +I*{\mathrm{model}}(q) = +\mathrm{scale}\, I*{\mathrm{mix}}(q)\, +S*{\mathrm{HS}}(q; R*{\mathrm{eff}}, \phi\_{\mathrm{vol}}) + +- w*{\mathrm{solv}} I*{\mathrm{solv}}(q) +- \mathrm{offset} + \] + +### Variables + +| Symbol / parameter | Meaning in SAXSShell | +| ------------------------------------- | ------------------------------------------------------------------- | +| \(w_i\) | generated component weight for cluster profile \(i\) | +| \(w\_{\mathrm{solv}}\) / `solv_w` | bounded solvent contribution weight | +| \(R\_{\mathrm{eff}}\) / `eff_r` | effective hard-sphere radius used in `calc_monodisperse_sq(...)` | +| \(\phi\_{\mathrm{vol}}\) / `vol_frac` | effective hard-sphere volume fraction inside the Percus-Yevick term | +| `scale` | solute intensity scale factor | +| `offset` | constant additive background | + +### Likelihood conventions + +The current `pyDREAM MonoSQ Normalized` template uses a point-normalized +Gaussian log-likelihood with a fixed noise scale of \(10^{-4}\): + +\[ +\log \mathcal{L}_{\mathrm{norm}} = +\frac{1}{N_q} +\sum_{k=1}^{N*q} +\log \mathcal{N} +\left( +I*{\exp}(q*k)\ \middle|\ I*{\mathrm{model}}(q_k), 10^{-4} +\right) +\] + +The archived `MonoSQ Basic` template uses the same forward model but omits the +\(1/N_q\) normalization. The archived `MonoSQ Decoupled` template keeps the same +equation and simply factors the forward model into an intermediate helper +function before evaluating the likelihood. + +### Literature + +- J. K. Percus and G. J. Yevick, _Analysis of Classical Statistical Mechanics by Means of Collective Coordinates_, + Phys. Rev. **110**, 1-13 (1958). +- M. S. Wertheim, _Exact Solution of the Percus-Yevick Integral Equation for Hard Spheres_, + Phys. Rev. Lett. **10**, 321-323 (1963). +- J. S. Pedersen, _Analysis of small-angle scattering data from colloids and polymer solutions: modeling and least-squares fitting_, + Adv. Colloid Interface Sci. **70**, 171-210 (1997). + +## Poly LMA Hard-Sphere + +Applies to: + +- `template_pydream_poly_lma_hs.py` + +This template uses a **discrete local-monodisperse-approximation-style** +cluster sum: each cluster profile keeps its own effective interaction radius, +but the cluster abundances are normalized internally before evaluating the +solute mixture. + +\[ +x_i = \frac{w_i}{\sum_j w_j}, +\qquad +\sum_i x_i = 1 +\] + +\[ +I*{\mathrm{model}}(q) = +\mathrm{scale}\,\phi*{\mathrm{solute}} +\sum*{i=0}^{N-1} +x_i I_i(q) S*{\mathrm{HS}}(q; R*i^{\mathrm{eff}}, \phi*{\mathrm{int}}) + +- s*{\mathrm{solv}} (1-\phi*{\mathrm{solute}}) I\_{\mathrm{solv}}(q) +- \mathrm{offset} + \] + +### Variables + +| Symbol / parameter | Meaning in SAXSShell | +| ------------------------------------------ | ------------------------------------------------------------------------------------------------ | +| \(w_i\) | raw cluster-abundance coefficient generated from the project component rows | +| \(x_i\) | normalized abundance used internally by the model | +| \(R_i^{\mathrm{eff}}\) | per-cluster effective interaction radius | +| `r_eff_wN` | generated Prefit/DREAM radius parameter for cluster `wN` when sphere mode is active | +| \(\phi\_{\mathrm{solute}}\) / `phi_solute` | SAXS-effective solute interaction ratio scaling the cluster contribution | +| \(\phi\_{\mathrm{int}}\) / `phi_int` | interaction packing fraction used only inside the hard-sphere structure factor | +| \(s\_{\mathrm{solv}}\) / `solvent_scale` | bounded attenuation solvent-scaling term, used together with the `phi_solute` solvent complement | +| `scale` | solute intensity scale factor | +| `offset` | constant additive background | +| \(\sigma = e^{\log \sigma}\) / `log_sigma` | Gaussian noise scale for the DREAM likelihood | + +In the current implementation, \(R_i^{\mathrm{eff}}\) is taken from the +generated parameter `r_eff_wN` when that row exists. Otherwise, the template +falls back to the cluster-geometry metadata value supplied by Prefit. + +### Likelihood convention + +\[ +\log \mathcal{L}_{\mathrm{norm}} = +\frac{1}{N_q} +\sum_{k=1}^{N*q} +\log \mathcal{N} +\left( +I*{\exp}(q*k)\ \middle|\ I*{\mathrm{model}}(q_k), e^{\log \sigma} +\right) +\] + +### Literature + +- J. S. Pedersen, _Determination of size distributions from small-angle scattering data for systems with effective hard-sphere interactions_, + J. Appl. Cryst. **27**, 595-608 (1994). +- J. S. Pedersen, _Analysis of small-angle scattering data from colloids and polymer solutions: modeling and least-squares fitting_, + Adv. Colloid Interface Sci. **70**, 171-210 (1997). +- J. K. Percus and G. J. Yevick, _Analysis of Classical Statistical Mechanics by Means of Collective Coordinates_, + Phys. Rev. **110**, 1-13 (1958). +- M. S. Wertheim, _Exact Solution of the Percus-Yevick Integral Equation for Hard Spheres_, + Phys. Rev. Lett. **10**, 321-323 (1963). + +## Poly LMA Hard-Sphere/Ellipsoid Mix (Approx.) + +Applies to: + +- `template_pydream_poly_lma_hs_mix_approx.py` +- `template_pydream_poly_lma_hs_legacy.py` + +This template keeps the same cluster-summed hard-sphere equation as the +sphere-only Poly LMA model, but it allows Prefit geometry rows to be toggled +between sphere and ellipsoid approximations. + +\[ +I*{\mathrm{model}}(q) = +\mathrm{scale}\,\phi*{\mathrm{solute}} +\sum*{i=0}^{N-1} +x_i I_i(q) S*{\mathrm{HS}}(q; R*i^{\mathrm{eff}}, \phi*{\mathrm{int}}) + +- s*{\mathrm{solv}} (1-\phi*{\mathrm{solute}}) I\_{\mathrm{solv}}(q) +- \mathrm{offset} + \] + +The difference is how the effective interaction radius is resolved: + +\[ +R*i^{\mathrm{eff}} = +\begin{cases} +r*{\mathrm{eff},i}, & \text{if the component is treated as a sphere} \\ +\left(a_i b_i c_i\right)^{1/3}, & \text{if the component is treated as an ellipsoid} +\end{cases} +\] + +Here \(a_i\), \(b_i\), and \(c_i\) correspond to the generated semiaxis +parameters `a_eff_wN`, `b_eff_wN`, and `c_eff_wN`. + +!!! warning "Approximation Scope" +This is a SAXSShell approximation, not an exact hard-ellipsoid +Percus-Yevick closure. Ellipsoid geometry is reduced to an +equivalent-volume sphere before the hard-sphere structure factor is +evaluated. + +### Variables + +The weight, solvent, `scale`, `offset`, `phi_solute`, `phi_int`, and +`log_sigma` terms are the same as in the sphere-only Poly LMA model. The extra +geometry-dependent parameters are: + +| Symbol / parameter | Meaning in SAXSShell | +| ---------------------------------- | ---------------------------------------------------------------------------------------- | +| `r_eff_wN` | sphere radius parameter when the mapped component uses the sphere approximation | +| `a_eff_wN`, `b_eff_wN`, `c_eff_wN` | ellipsoid semiaxis parameters when the mapped component uses the ellipsoid approximation | +| \(R_i^{\mathrm{eff}}\) | effective radius actually passed into the hard-sphere structure factor | + +### Literature + +- S. Hansen, _Monte Carlo estimation of the structure factor for hard bodies in small-angle scattering_, + J. Appl. Cryst. **45**, 381-388 (2012). +- S. Hansen, _Approximation of the structure factor for nonspherical hard bodies using polydisperse spheres_, + J. Appl. Cryst. **46**, 1008-1016 (2013). +- J. S. Pedersen, _Determination of size distributions from small-angle scattering data for systems with effective hard-sphere interactions_, + J. Appl. Cryst. **27**, 595-608 (1994). + +## Archived Template Notes + +The archived templates are still loadable for older projects, but they map onto +the current model families as follows: + +- `template_likelihood_monosq.py`: same MonoSQ forward model, legacy + unnormalized Gaussian log-likelihood. +- `template_pd_likelihood_monosq.py`: same MonoSQ forward model, normalized + Gaussian log-likelihood. +- `template_pd_likelihood_monosq_decoupled.py`: same MonoSQ forward model, + normalized likelihood, explicit `model_monosq(...)` helper. +- `template_pydream_poly_lma_hs_legacy.py`: compatibility wrapper around the + current mixed-shape approximate Poly LMA model. diff --git a/docs/user-guide/pydream-workflow.md b/docs/user-guide/pydream-workflow.md index 8d78958..bce8743 100644 --- a/docs/user-guide/pydream-workflow.md +++ b/docs/user-guide/pydream-workflow.md @@ -186,20 +186,20 @@ loc' = c - scale' / 2 So the preset changes spread, not the intended center of the prior. -### Apply To: All Structures vs Selected Structures +### Apply To: All Parameters vs Selected Parameters For the single-mode presets above, the **Apply to** control determines whether SAXSShell adjusts: -- all structure groups in the table -- or only the currently selected structure rows +- all parameter rows in the table +- or only the currently selected parameter rows In the current implementation, a "structure group" means: - all rows sharing the same `(structure, motif)` pair, if those fields are set - otherwise only the specific row itself -That means if you apply `Strict` to one selected cluster structure, SAXSShell +That means if you apply `Strict` to one selected cluster parameter row, SAXSShell updates the whole structure/motif group together so its associated weight and related rows stay synchronized. @@ -210,7 +210,7 @@ The two mixed presets are: - **Strict Small / Lenient Large** - **Lenient Small / Strict Large** -These always apply across **all** structures because SAXSShell must rank the +These always apply across **all** parameters because SAXSShell must rank the structures against each other before deciding which ones are "small" or "large". diff --git a/docs/user-guide/saxs-prefit.md b/docs/user-guide/saxs-prefit.md index 8bb4000..6299d87 100644 --- a/docs/user-guide/saxs-prefit.md +++ b/docs/user-guide/saxs-prefit.md @@ -20,8 +20,20 @@ From the current UI implementation, Prefit supports: - exporting plot data - saving and restoring Prefit state - saving a **Best Prefit** parameter preset -- using a solution-based solute volume-fraction estimator when the active - template exposes a solute or solvent fraction parameter +- using embedded solution-scattering estimators for solute volume fraction, + solvent attenuation scaling, and fluorescence-background screening + +The parameter table also supports lightweight Artemis-style expression modes in +the **Value** column: + +- if **Value** is a math expression and **Vary** is enabled, Prefit treats the + expression as a `guess`-style initial-value seed. The expression is resolved + into the current numeric starting value, but the parameter can still refine + inside **Min** / **Max**. +- if **Value** is a math expression and **Vary** is disabled, Prefit treats the + expression as a `def`-style dependent parameter. In that mode the parameter + follows the expression during evaluation and fitting, and its own **Min** / + **Max** are ignored. ## Cluster geometry metadata @@ -43,26 +55,215 @@ Recent updates in the codebase also include: - positive-radius validation with explicit error reporting - template-aware restrictions on which shape approximations are allowed -## Solute volume-fraction estimator +## Solution scattering estimators + +Prefit now includes an embedded **Solution Scattering Estimators** section, and +the same calculations are also available from the **Tools** menu as: + +- **Open Volume Fraction Estimate** +- **Open Attenuation Estimate** +- **Open Fluorescence Estimate** + +All three calculators share the same composition inputs: + +- solution density +- solute and solvent stoichiometries +- component molar masses +- component densities when the selected input mode requires them +- beam energy +- capillary size and geometry +- beam footprint and beam profile + +The current implementation assumes a centered beam footprint and a uniform beam +profile. Flat-plate geometry uses a constant path length, while cylindrical +geometry averages across the illuminated chord lengths of the capillary. + +### Software and data sources + +The solution-composition bookkeeping is handled by SAXSShell's internal +solution-property helpers. Attenuation and fluorescence quantities are then +estimated from empirical formulas using `xraydb`, which exposes Elam-style +atomic edge, line, and fluorescence-yield data together with NIST-style mass +attenuation calculations. -If the active template includes a solute or solvent fraction parameter such as -`phi_solute`, Prefit can show an embedded **Solute Volume Fraction Estimator** -between the main controls and the cluster-geometry table. +In practical terms, SAXSShell combines: -The current implementation uses the measured solution volume together with the -input solute density: +- solution masses, densities, and stoichiometries from the Prefit widget +- linear attenuation coefficients derived from empirical formulas +- edge-resolved fluorescence yields and line families from `xraydb` +- beam-path averages defined by the selected capillary geometry -- `c_solute = m_solute / V_solution` -- `vbar_solute ~= 1 / rho_solute` -- `phi_solute ~= c_solute * vbar_solute` +### Physical solute-associated volume fraction + +For the physical solute-associated volume-fraction estimate, SAXSShell uses the +measured solution volume together with the solute density: + +$$ +c_{\mathrm{solute}} = \frac{m_{\mathrm{solute}}}{V_{\mathrm{solution}}}, +\qquad +\bar{v}_{\mathrm{solute}} \approx \frac{1}{\rho_{\mathrm{solute}}}, +\qquad +\phi_{\mathrm{phys}} \approx c_{\mathrm{solute}} \bar{v}_{\mathrm{solute}}. +$$ + +The solvent fraction is then reported as: + +$$ +\phi_{\mathrm{solvent,phys}} = 1 - \phi_{\mathrm{phys}}. +$$ This is closer to the concentration-plus-specific-volume logic commonly used in solution SAXS than the older additive-volume estimate -`V_solute / (V_solute + V_solvent)`. - -The tool still reports additive component volumes as a diagnostic, but the -value written back into the Prefit parameter table is now the measured-solution -estimate above. +$V_{\mathrm{solute}} / (V_{\mathrm{solute}} + V_{\mathrm{solvent}})$. + +SAXSShell still prints this physical estimate in the output console for +reference, but it is no longer written directly into the model-facing +`phi_solute` / `phi_solvent` defaults. + +### SAXS-effective interaction contrast ratio + +The model-facing solute fraction now uses an energy-dependent +contrast-weighted interaction ratio. SAXSShell forms an effective forward +scattering-electron density proxy from the component formula, density, and the +real anomalous correction \(f'(E)\): + +$$ +\rho_{\mathrm{eff}}(E) += +\rho_{\mathrm{mass}} +\frac{N_A}{M} +\sum_i n_i \left[ Z_i + f'_i(E) \right]. +$$ + +Using the solute-solvent contrast +\(\Delta \rho*{\mathrm{eff}}(E) = \rho*{\mathrm{eff,solute}}(E) - \rho\_{\mathrm{eff,solvent}}(E)\), +SAXSShell defines a contrast-weight factor + +$$ +C(E) += +\left( +\frac{\Delta \rho_{\mathrm{eff}}(E)} +{\rho_{\mathrm{eff,solvent}}(E)} +\right)^2, +$$ + +an effective solute interaction volume + +$$ +V_{\mathrm{solute}}^{\mathrm{eff}}(E) += +C(E) \, V_{\mathrm{solute,phys}}, +$$ + +and the model-facing SAXS ratio + +$$ +R_{\mathrm{saxs}}(E) += +\frac{V_{\mathrm{solute}}^{\mathrm{eff}}(E)} +{V_{\mathrm{solute}}^{\mathrm{eff}}(E) + V_{\mathrm{solvent,phys}}}. +$$ + +This keeps the physical occupancy estimate visible, but it lets the model +scale the solute and solvent terms using a contrast-sensitive ratio at the +selected beam energy. If the active template exposes `phi_solute` or +`phi_solvent`, Prefit now writes `R_saxs(E)` or its complement into that +parameter and sets `vary = off`. + +### Attenuation and solvent scattering scale + +For attenuation, SAXSShell forms the sample and neat-solvent linear attenuation +coefficients from concentration-weighted mass attenuation coefficients: + +$$ +\mu_{\mathrm{sample}}(E) += c_{\mathrm{solute}} \left(\frac{\mu}{\rho}\right)_{\mathrm{solute}}(E) ++ c_{\mathrm{solvent}} \left(\frac{\mu}{\rho}\right)_{\mathrm{solvent}}(E), +$$ + +$$ +\mu_{\mathrm{neat}}(E) += \rho_{\mathrm{solvent}} +\left(\frac{\mu}{\rho}\right)_{\mathrm{solvent}}(E). +$$ + +For a path length $L$, the transmission model is: + +$$ +T(E, L) = e^{-\mu(E)L}. +$$ + +To estimate how much the neat-solvent reference should be scaled down before it +represents the solvent fraction inside the sample, SAXSShell compares +beam-profile-averaged scattering weights: + +$$ +w_{\mathrm{solv}} += +\frac{ +c_{\mathrm{solvent}} +\left\langle L e^{-\mu_{\mathrm{sample}} L} \right\rangle +}{ +\rho_{\mathrm{solvent}} +\left\langle L e^{-\mu_{\mathrm{neat}} L} \right\rangle +}. +$$ + +Here $\langle \cdots \rangle$ denotes the path-length average across the +illuminated capillary cross section. This produces a solvent contribution scale +factor that answers the practical SAXS question, "how much more scattering +intensity does the neat solvent have than the solvent fraction inside the real +sample?" + +If the active template includes both a model-facing fraction parameter +(`phi_solute` / `phi_solvent`) and a solvent-weight parameter such as +`solvent_scale`, Prefit writes the attenuation factor above into +`solvent_scale` and uses `R_saxs(E)` for the fraction parameter. The solvent +term therefore becomes +\(w*{\mathrm{solv}} (1 - R*{\mathrm{saxs}}) I\_{\mathrm{solv}}(q)\). + +If the template only exposes a single solvent-weight parameter such as +`solv_w`, Prefit writes the combined solvent-background multiplier + +$$ +w_{\mathrm{model}} = \left(1 - R_{\mathrm{saxs}}(E)\right) w_{\mathrm{solv}} +$$ + +into that parameter directly. + +### Fluorescence background proxy + +The fluorescence estimator is intentionally a screening calculation rather than +a full transport simulation. It starts from the sample photoelectric +attenuation at the incident energy $E_0$, partitions that absorption across +accessible edges using edge jump ratios, and then applies fluorescence yields +and line-family branching: + +$$ +Y^{(1)}_{e,\ell} +\propto +\mu^{\mathrm{photo}}_{e,\mathrm{edge}}(E_0)\, +\omega_{e,\mathrm{edge}}\, +p_{e,\ell}\, +\left\langle +\mathcal{I}(L; \mu_{\mathrm{in}}, \mu_{\mathrm{out}}) +\right\rangle. +$$ + +In this expression: + +- $e$ is the emitting element +- $\ell$ is the emitted line family +- $\omega$ is the fluorescence yield +- $p$ is the line branching probability +- $\mathcal{I}$ is the path-integrated incident/escape attenuation term + +SAXSShell then adds a first-order secondary-fluorescence pass by allowing the +primary fluorescent photons to be reabsorbed once and re-emitted before +escaping the sample. This is useful for ranking samples by expected +fluorescence-background severity, but it is not a Monte Carlo X-ray transport +calculation and should not be treated as a detector-absolute prediction. ## Plot controls @@ -169,11 +370,33 @@ exists and is mapped correctly. - If a geometry-aware template refuses to update, check the mapping column and the active radii values before looking elsewhere. +## TODO + +TODO: re-check the consistency between the attenuation estimator output +(`solvent_scale` / `solv_w`) and the solvent contribution used by the model. +At least some current workflows appear to produce model-facing solvent weights +around `0.15` even when the imported solvent trace still looks about `100x` +too large by eye at the same q-position maxima. Revisit whether this comes +from a mismatch between: + +- the attenuation-only estimate and the actual model solvent term +- split-fraction versus single-solvent-weight template conventions +- solvent blank versus sample normalization, transmission, thickness, or exposure +- partially pre-subtracted or otherwise inconsistently reduced solvent traces +- the possible need for a separate empirical solvent-normalization factor in + addition to the attenuation scaling + ## References - [SasView sphere model: absolute-scale interpretation of `scale`, explicit contrast term, and flat `background`.](https://www.sasview.org/docs/user/models/sphere.html) +- [Pedersen JS. _Analysis of small-angle scattering data from colloids and polymer solutions: modeling and least-squares fitting_. Advances in Colloid and Interface Science (1997).]() - [SasView power-law model: example of a flat additive background term.](https://www.sasview.org/docs/user/models/power_law.html) - [Schneidman-Duhovny D, Hammel M, Tainer JA, Sali A. _Accurate SAXS profile computation and its assessment by contrast variation experiments_. Biophysical Journal (2013).](https://pubmed.ncbi.nlm.nih.gov/23972848/) - [Henriques J, Arleth L, Lindorff-Larsen K, Skepö M. _On the Calculation of SAXS Profiles of Folded and Intrinsically Disordered Proteins from Computer Simulations_. Journal of Molecular Biology (2018).](https://pubmed.ncbi.nlm.nih.gov/29548755/) - [Edwards-Gayle CJC, Khunti N, Hamley IW, Inoue K, Cowieson N, Rambo RP. _Design of a multipurpose sample cell holder for the Diamond Light Source high-throughput SAXS beamline B21_. Journal of Synchrotron Radiation (2021).](https://pmc.ncbi.nlm.nih.gov/articles/PMC7842227/) - [Hajizadeh N, Franke D, Jeffries CM, Svergun DI. _Consensus Bayesian assessment of protein molecular mass from solution X-ray scattering data_. Scientific Reports (2018).](https://www.nature.com/articles/s41598-018-25355-2) +- [Hubbell JH. _Photon Cross Sections, Attenuation Coefficients, and Energy Absorption Coefficients from 10 keV to 100 GeV_. NBS NSRDS 29.](https://doi.org/10.6028/NBS.NSRDS.29) +- [XrayDB Python reference: attenuation, edge, and fluorescence-yield APIs used by SAXSShell.](https://scikit-beam.github.io/XrayDB/python.html) +- [Elam WT, Ravel BD, Sieber JR. _A new atomic database for X-ray spectroscopic calculations_. Radiation Physics and Chemistry (2002).]() +- [Roter B, et al. Discussion of edge jump ratios and fluorescence forward modeling.](https://pmc.ncbi.nlm.nih.gov/articles/PMC12871215/) +- [Trevorah RM, et al. Discussion of self-absorption and fluorescence re-absorption corrections.](https://pmc.ncbi.nlm.nih.gov/articles/PMC6608621/) diff --git a/docs/user-guide/template-system.md b/docs/user-guide/template-system.md index b477506..18f15ce 100644 --- a/docs/user-guide/template-system.md +++ b/docs/user-guide/template-system.md @@ -134,6 +134,9 @@ The repository currently includes bundled templates such as: - poly-LMA hard-sphere workflows - approximate mixed sphere/ellipsoid workflows +For the bundled model equations, variable definitions, and literature links, +see [Pre-loaded SAXS Models](preloaded-saxs-models.md). + Some older templates now live in a `_deprecated` subfolder. They are hidden by default in template dropdowns, but older projects can still load them. @@ -327,6 +330,18 @@ This is especially relevant for structure-derived scattering workflows, where published SAXS studies show that uncertainties in hydration-layer contrast can materially change the calculated profile. +For the current poly-LMA solvent-subtraction workflows, the repository also +distinguishes between: + +- a physical bulk-density solute-associated volume fraction reported for reference +- a SAXS-effective interaction ratio used for the model-facing `phi_solute` / + `phi_solvent` default + +That distinction matters because the solvent-background subtraction is carried +by both the solute/solvent split and the explicit solvent term. The model-facing +split therefore follows the contrast-weighted SAXS interaction estimate, while +the attenuation term stays in `solvent_scale`. + Likewise, `offset` is included because a flat residual background is common in real SAXS data reduction. In practice this can represent imperfect background subtraction or residual background from the sample environment, including diff --git a/mkdocs.yml b/mkdocs.yml index 65cd696..6bf6c72 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -41,6 +41,8 @@ markdown_extensions: - def_list - footnotes - tables + - pymdownx.arithmatex: + generic: true - toc: permalink: true - pymdownx.details @@ -51,6 +53,10 @@ markdown_extensions: - pymdownx.tabbed: alternate_style: true +extra_javascript: + - javascripts/mathjax.js + - https://unpkg.com/mathjax@3/es5/tex-mml-chtml.js + nav: - Home: index.md - Getting Started: @@ -61,7 +67,9 @@ nav: - GUI Overview: user-guide/gui-overview.md - Project Configuration: user-guide/project-configuration.md - Cluster Extraction: user-guide/cluster-extraction.md + - Cluster Dynamics: user-guide/cluster-dynamics.md - SAXS Prefit: user-guide/saxs-prefit.md + - Pre-loaded SAXS Models: user-guide/preloaded-saxs-models.md - LMFit Workflow: user-guide/lmfit-workflow.md - pyDREAM Workflow: user-guide/pydream-workflow.md - Results and Export: user-guide/results-and-export.md diff --git a/pyproject.toml b/pyproject.toml index 25dbf94..8e082f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,8 @@ Issues = "https://github.com/kewh5868/SAXSShell/issues/" [project.scripts] bondanalysis = "saxshell.bondanalysis.cli:main" +clusterdynamics = "saxshell.clusterdynamics.cli:main" +clusterdynamicsml = "saxshell.clusterdynamicsml.cli:main" clusters = "saxshell.cluster.cli:main" mdtrajectory = "saxshell.mdtrajectory.cli:main" saxshell = "saxshell.saxshell:main" @@ -63,6 +65,7 @@ dependencies = { file = ["requirements/pip.txt"] } [tool.setuptools.package-data] "saxshell.xyz2pdb" = ["reference_library/*.pdb", "reference_library/*.txt"] "saxshell.fullrmc" = ["_solution_property_presets/*.json"] +"saxshell.saxs" = ["_beam_geometry_presets/*.json", "_ui_assets/*.svg"] [tool.codespell] exclude-file = ".codespell/ignore_lines.txt" diff --git a/requirements/pip.txt b/requirements/pip.txt index 0a64651..41aa96c 100644 --- a/requirements/pip.txt +++ b/requirements/pip.txt @@ -3,5 +3,6 @@ lmfit>=1.3.4 numpy>=2.4.2 pydream PySide6>=6.6.0 +python-pptx>=0.6.23 scipy>=1.16.2 xraydb>=4.5.8 diff --git a/src/saxshell/bondanalysis/ui/main_window.py b/src/saxshell/bondanalysis/ui/main_window.py index c95b267..f9a24c1 100644 --- a/src/saxshell/bondanalysis/ui/main_window.py +++ b/src/saxshell/bondanalysis/ui/main_window.py @@ -55,6 +55,11 @@ load_result_index, ) from saxshell.bondanalysis.ui.plot_window import BondAnalysisPlotWindow +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) _OPEN_WINDOWS: list["BondAnalysisMainWindow"] = [] @@ -115,6 +120,7 @@ def __init__( def _build_ui(self) -> None: self.setWindowTitle("SAXSShell (bondanalysis)") + self.setWindowIcon(load_saxshell_icon()) self.resize(1320, 840) central = QWidget() @@ -1250,7 +1256,9 @@ def launch_bondanalysis_ui( app = QApplication.instance() owns_app = app is None if app is None: + prepare_saxshell_application_identity() app = QApplication(sys.argv) + configure_saxshell_application(app) window = BondAnalysisMainWindow(initial_clusters_dir=clusters_dir) _OPEN_WINDOWS.append(window) diff --git a/src/saxshell/cluster/ui/definitions_panel.py b/src/saxshell/cluster/ui/definitions_panel.py index 1869664..54881d4 100644 --- a/src/saxshell/cluster/ui/definitions_panel.py +++ b/src/saxshell/cluster/ui/definitions_panel.py @@ -518,18 +518,130 @@ def default_cutoff(self) -> float | None: value = self.default_cutoff_spin.value() return None if value <= 0.0 else value + def load_atom_type_definitions( + self, + definitions: AtomTypeDefinitions, + *, + emit_signal: bool = True, + ) -> None: + self.atom_table.setRowCount(0) + ordered_atom_types = ["node", "linker", "shell"] + seen_atom_types: set[str] = set() + for atom_type in ordered_atom_types + sorted(definitions): + if atom_type in seen_atom_types: + continue + seen_atom_types.add(atom_type) + for element, residue in definitions.get(atom_type, []): + row = self.atom_table.rowCount() + self.atom_table.insertRow(row) + type_combo = QComboBox() + type_combo.addItems(["node", "linker", "shell"]) + if atom_type not in {"node", "linker", "shell"}: + type_combo.addItem(atom_type) + type_combo.setCurrentText(atom_type) + type_combo.setToolTip("Cluster role assigned to this element.") + type_combo.currentTextChanged.connect( + lambda _text: self.settings_changed.emit() + ) + self.atom_table.setCellWidget(row, 0, type_combo) + self.atom_table.setItem( + row, + 1, + QTableWidgetItem("" if element is None else str(element)), + ) + self.atom_table.setItem( + row, + 2, + QTableWidgetItem("" if residue is None else str(residue)), + ) + self._sync_pair_element_choices() + if emit_signal: + self.settings_changed.emit() + + def load_pair_cutoff_definitions( + self, + definitions: PairCutoffDefinitions, + *, + emit_signal: bool = True, + ) -> None: + self.pair_table.setRowCount(0) + self._sync_pair_element_choices() + for atom1, atom2 in sorted(definitions): + row = self.pair_table.rowCount() + self.pair_table.insertRow(row) + atom1_combo = self._make_pair_combo() + atom2_combo = self._make_pair_combo() + self.pair_table.setCellWidget(row, 0, atom1_combo) + self.pair_table.setCellWidget(row, 1, atom2_combo) + self._sync_pair_element_choices() + atom1_combo.setCurrentText(atom1) + atom2_combo.setCurrentText(atom2) + shell_cutoffs = definitions[(atom1, atom2)] + for level, column in enumerate((2, 3, 4)): + cutoff = shell_cutoffs.get(level) + self.pair_table.setItem( + row, + column, + QTableWidgetItem("" if cutoff is None else str(cutoff)), + ) + if emit_signal: + self.settings_changed.emit() + + def set_default_cutoff( + self, + value: float | None, + *, + emit_signal: bool = True, + ) -> None: + self.default_cutoff_spin.blockSignals(True) + self.default_cutoff_spin.setValue( + 0.0 if value is None else float(value) + ) + self.default_cutoff_spin.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + def use_pbc(self) -> bool: return self.use_pbc_box.isChecked() + def set_use_pbc(self, value: bool, *, emit_signal: bool = True) -> None: + self.use_pbc_box.blockSignals(True) + self.use_pbc_box.setChecked(bool(value)) + self.use_pbc_box.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + def search_mode(self) -> str: data = self.search_mode_combo.currentData() if data is None: return SEARCH_MODE_KDTREE return str(data) + def set_search_mode(self, value: str, *, emit_signal: bool = True) -> None: + for index in range(self.search_mode_combo.count()): + if self.search_mode_combo.itemData(index) == value: + self.search_mode_combo.blockSignals(True) + self.search_mode_combo.setCurrentIndex(index) + self.search_mode_combo.blockSignals(False) + break + if emit_signal: + self.settings_changed.emit() + def save_state_frequency(self) -> int: return int(self.save_state_frequency_spin.value()) + def set_save_state_frequency( + self, + value: int, + *, + emit_signal: bool = True, + ) -> None: + self.save_state_frequency_spin.blockSignals(True) + self.save_state_frequency_spin.setValue(int(value)) + self.save_state_frequency_spin.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + def include_shell_levels(self) -> tuple[int, ...]: levels = [0] if self.shell1_box.isChecked(): @@ -543,12 +655,52 @@ def shell_growth_levels(self) -> tuple[int, ...]: level for level in self.include_shell_levels() if level > 0 ) + def set_shell_growth_levels( + self, + levels: tuple[int, ...] | list[int], + *, + emit_signal: bool = True, + ) -> None: + normalized = {int(level) for level in levels} + self.shell1_box.blockSignals(True) + self.shell2_box.blockSignals(True) + self.shell1_box.setChecked(1 in normalized) + self.shell2_box.setChecked(2 in normalized) + self.shell1_box.blockSignals(False) + self.shell2_box.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + def shared_shells(self) -> bool: return self.shared_shells_box.isChecked() + def set_shared_shells( + self, + value: bool, + *, + emit_signal: bool = True, + ) -> None: + self.shared_shells_box.blockSignals(True) + self.shared_shells_box.setChecked(bool(value)) + self.shared_shells_box.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + def include_shell_atoms_in_stoichiometry(self) -> bool: return self.include_shell_stoichiometry_box.isChecked() + def set_include_shell_atoms_in_stoichiometry( + self, + value: bool, + *, + emit_signal: bool = True, + ) -> None: + self.include_shell_stoichiometry_box.blockSignals(True) + self.include_shell_stoichiometry_box.setChecked(bool(value)) + self.include_shell_stoichiometry_box.blockSignals(False) + if emit_signal: + self.settings_changed.emit() + def rule_counts(self) -> tuple[int, int]: atom_rules = sum( len(criteria) for criteria in self.atom_type_definitions().values() diff --git a/src/saxshell/cluster/ui/main_window.py b/src/saxshell/cluster/ui/main_window.py index f7da746..e5754c2 100644 --- a/src/saxshell/cluster/ui/main_window.py +++ b/src/saxshell/cluster/ui/main_window.py @@ -34,6 +34,11 @@ from saxshell.cluster.ui.definitions_panel import ClusterDefinitionsPanel from saxshell.cluster.ui.export_panel import ClusterExportPanel from saxshell.cluster.ui.trajectory_panel import ClusterTrajectoryPanel +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) from saxshell.structure import AtomTypeDefinitions @@ -378,6 +383,7 @@ def __init__(self, initial_frames_dir: Path | None = None) -> None: def _build_ui(self) -> None: self.setWindowTitle("SAXSShell (cluster)") + self.setWindowIcon(load_saxshell_icon()) self.resize(1360, 860) central = QWidget() @@ -1153,7 +1159,9 @@ def launch_cluster_ui( app = QApplication.instance() owns_app = app is None if app is None: + prepare_saxshell_application_identity() app = QApplication(sys.argv) + configure_saxshell_application(app) window = ClusterMainWindow( initial_frames_dir=Path(frames_dir) if frames_dir is not None else None diff --git a/src/saxshell/fullrmc/__init__.py b/src/saxshell/fullrmc/__init__.py index a400971..1610bc1 100644 --- a/src/saxshell/fullrmc/__init__.py +++ b/src/saxshell/fullrmc/__init__.py @@ -1,5 +1,7 @@ """Fullrmc setup scaffolding and launch helpers.""" +from typing import TYPE_CHECKING + from .constraint_generation import ( ConstraintGenerationEntry, ConstraintGenerationMetadata, @@ -84,8 +86,10 @@ load_solvent_handling_metadata, save_solvent_handling_metadata, ) -from .ui.main_window import RMCSetupMainWindow, launch_rmcsetup_ui -from .ui.representative_preview_window import RepresentativePreviewWindow + +if TYPE_CHECKING: + from .ui.main_window import RMCSetupMainWindow, launch_rmcsetup_ui + from .ui.representative_preview_window import RepresentativePreviewWindow __all__ = [ "ClusterSourceValidationResult", @@ -158,3 +162,21 @@ "solution_property_presets_path", "validate_cluster_source", ] + + +def __getattr__(name: str): + if name in {"RMCSetupMainWindow", "launch_rmcsetup_ui"}: + from .ui.main_window import RMCSetupMainWindow, launch_rmcsetup_ui + + exports = { + "RMCSetupMainWindow": RMCSetupMainWindow, + "launch_rmcsetup_ui": launch_rmcsetup_ui, + } + return exports[name] + if name == "RepresentativePreviewWindow": + from .ui.representative_preview_window import ( + RepresentativePreviewWindow, + ) + + return RepresentativePreviewWindow + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/saxshell/fullrmc/ui/main_window.py b/src/saxshell/fullrmc/ui/main_window.py index 544bb41..b7d5d72 100644 --- a/src/saxshell/fullrmc/ui/main_window.py +++ b/src/saxshell/fullrmc/ui/main_window.py @@ -120,6 +120,11 @@ format_stoich_for_axis, sort_stoich_labels, ) +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) _OPEN_WINDOWS: list["RMCSetupMainWindow"] = [] @@ -377,6 +382,7 @@ def __init__( ) -> None: super().__init__() self.setWindowTitle("SAXSShell (rmcsetup)") + self.setWindowIcon(load_saxshell_icon()) self.resize(1080, 860) self.project_manager = SAXSProjectManager() @@ -4742,7 +4748,9 @@ def launch_rmcsetup_ui( app = QApplication.instance() owns_app = app is None if app is None: + prepare_saxshell_application_identity() app = QApplication(sys.argv) + configure_saxshell_application(app) window = RMCSetupMainWindow(initial_project_dir=project_dir) _OPEN_WINDOWS.append(window) diff --git a/src/saxshell/mdtrajectory/cli.py b/src/saxshell/mdtrajectory/cli.py index 7f668bb..2ac9b2a 100644 --- a/src/saxshell/mdtrajectory/cli.py +++ b/src/saxshell/mdtrajectory/cli.py @@ -185,11 +185,14 @@ def _add_cutoff_resolution_arguments( def _handle_ui(_: argparse.Namespace) -> int: from PySide6.QtWidgets import QApplication + from saxshell.saxs.ui.branding import prepare_saxshell_application_identity + from .ui.main_window import launch_mdtrajectory_app app = QApplication.instance() created_app = app is None if app is None: + prepare_saxshell_application_identity() app = QApplication(sys.argv) launch_mdtrajectory_app() if created_app: diff --git a/src/saxshell/mdtrajectory/ui/main_window.py b/src/saxshell/mdtrajectory/ui/main_window.py index e8de7aa..c060c7a 100644 --- a/src/saxshell/mdtrajectory/ui/main_window.py +++ b/src/saxshell/mdtrajectory/ui/main_window.py @@ -23,6 +23,11 @@ from saxshell.mdtrajectory.ui.state import MDTrajectoryAppState from saxshell.mdtrajectory.ui.trajectory_panel import TrajectoryPanel from saxshell.mdtrajectory.workflow import suggest_output_dir +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) @dataclass(slots=True) @@ -108,6 +113,7 @@ def __init__(self) -> None: def _build_ui(self) -> None: self.setWindowTitle("SAXSShell (mdtrajectory)") + self.setWindowIcon(load_saxshell_icon()) self.resize(1280, 780) central = QWidget() @@ -640,6 +646,11 @@ def _show_error(self, message: str) -> None: def launch_mdtrajectory_app() -> MDTrajectoryMainWindow: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication([]) + configure_saxshell_application(app) window = MDTrajectoryMainWindow() window.show() return window diff --git a/src/saxshell/mdtrajectory/workflow.py b/src/saxshell/mdtrajectory/workflow.py index 372e160..820af9e 100644 --- a/src/saxshell/mdtrajectory/workflow.py +++ b/src/saxshell/mdtrajectory/workflow.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from dataclasses import dataclass from pathlib import Path from re import sub @@ -14,6 +15,8 @@ TrajectoryManager, ) +EXPORT_METADATA_FILENAME = "mdtrajectory_export.json" + def suggest_output_dir( trajectory_file: str | Path, @@ -66,7 +69,7 @@ def _base_output_dir_name( return base_name cutoff_text = format_cutoff_for_dir(cutoff_fs) - return f"{base_name}_t{cutoff_text}fs" + return f"{base_name}_f{cutoff_text}fs" @dataclass(slots=True) @@ -103,12 +106,16 @@ class MDTrajectoryExportResult: output_dir: Path written_files: list[Path] selection: MDTrajectorySelectionResult + metadata_file: Path | None = None def to_dict(self) -> dict[str, object]: return { "output_dir": str(self.output_dir), "written_files": [str(path) for path in self.written_files], "written_count": len(self.written_files), + "metadata_file": ( + None if self.metadata_file is None else str(self.metadata_file) + ), "selection": self.selection.to_dict(), } @@ -259,6 +266,13 @@ def export_frames( if selection.preview.selected_frames == 0: raise ValueError("No frames match the current selection settings.") + selected_frames = self.manager.get_selected_frames( + start=start, + stop=stop, + stride=stride, + min_time_fs=selection.applied_cutoff_fs, + post_cutoff_stride=post_cutoff_stride, + ) written_files = self.manager.export_frames( output_dir=selection.output_dir, start=start, @@ -267,8 +281,53 @@ def export_frames( min_time_fs=selection.applied_cutoff_fs, post_cutoff_stride=post_cutoff_stride, ) + metadata_file = self._write_export_metadata( + selection=selection, + written_files=written_files, + selected_frames=selected_frames, + ) return MDTrajectoryExportResult( output_dir=selection.output_dir, written_files=written_files, selection=selection, + metadata_file=metadata_file, + ) + + def _write_export_metadata( + self, + *, + selection: MDTrajectorySelectionResult, + written_files: list[Path], + selected_frames, + ) -> Path: + metadata_path = selection.output_dir / EXPORT_METADATA_FILENAME + payload = { + "version": 1, + "trajectory_file": str(self.trajectory_file), + "topology_file": ( + None if self.topology_file is None else str(self.topology_file) + ), + "energy_file": ( + None if self.energy_file is None else str(self.energy_file) + ), + "selection": selection.to_dict(), + "written_frames": [ + { + "filename": path.name, + "frame_index": int(frame.frame_index), + "time_fs": ( + None if frame.time_fs is None else float(frame.time_fs) + ), + } + for frame, path in zip( + selected_frames, + written_files, + strict=False, + ) + ], + } + metadata_path.write_text( + json.dumps(payload, indent=2) + "\n", + encoding="utf-8", ) + return metadata_path diff --git a/src/saxshell/xyz2pdb/cli.py b/src/saxshell/xyz2pdb/cli.py index 3322bac..4489a15 100644 --- a/src/saxshell/xyz2pdb/cli.py +++ b/src/saxshell/xyz2pdb/cli.py @@ -184,11 +184,14 @@ def _add_common_conversion_arguments( def _handle_ui(args: argparse.Namespace) -> int: from PySide6.QtWidgets import QApplication + from saxshell.saxs.ui.branding import prepare_saxshell_application_identity + from .ui.main_window import launch_xyz2pdb_ui app = QApplication.instance() created_app = app is None if app is None: + prepare_saxshell_application_identity() app = QApplication(sys.argv) launch_xyz2pdb_ui( input_path=getattr(args, "input_path", None), diff --git a/src/saxshell/xyz2pdb/ui/main_window.py b/src/saxshell/xyz2pdb/ui/main_window.py index 30c95c8..3b96125 100644 --- a/src/saxshell/xyz2pdb/ui/main_window.py +++ b/src/saxshell/xyz2pdb/ui/main_window.py @@ -1,10 +1,12 @@ from __future__ import annotations +import sys from dataclasses import dataclass from pathlib import Path from PySide6.QtCore import QObject, Qt, QThread, Signal, Slot from PySide6.QtWidgets import ( + QApplication, QMainWindow, QMessageBox, QScrollArea, @@ -13,6 +15,11 @@ QWidget, ) +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) from saxshell.xyz2pdb import ( XYZToPDBExportResult, XYZToPDBInspectionResult, @@ -107,6 +114,7 @@ def __init__( def _build_ui(self) -> None: self.setWindowTitle("SAXSShell (xyz2pdb)") + self.setWindowIcon(load_saxshell_icon()) self.resize(1360, 820) central = QWidget() @@ -438,6 +446,11 @@ def launch_xyz2pdb_ui( reference_library_dir: str | Path | None = None, ) -> XYZToPDBMainWindow: """Create, show, and return the xyz2pdb main window.""" + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication(sys.argv) + configure_saxshell_application(app) window = XYZToPDBMainWindow( input_path=None if input_path is None else Path(input_path), config_file=None if config_file is None else Path(config_file),