diff --git a/.codespellignore b/.codespellignore index d07210d..2357d38 100644 --- a/.codespellignore +++ b/.codespellignore @@ -1 +1,2 @@ Te +Nd diff --git a/README.md b/README.md index 5789376..e7e12c1 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,13 @@ repository conda environment file. ```bash git clone https://github.com/kewh5868/SAXSShell.git cd SAXSShell +``` + +### macOS, Linux, and WSL + +Create the Python 3.12 conda environment from the default environment file: + +```bash conda env create -f requirements/saxshell-py312.yml ``` @@ -55,6 +62,44 @@ Launch the main SAXSShell application from the repository root: PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ``` +### Native Windows + +On native Windows, use the Windows-specific environment file: + +```cmd +conda env create -f requirements\saxshell-py312-win.yml +``` + +If the `saxshell-py312` environment already exists, update it from the same +file: + +```cmd +conda env update -n saxshell-py312 -f requirements\saxshell-py312-win.yml --prune +``` + +From Anaconda Prompt, activate the environment, set `PYTHONPATH`, and launch +the SAXS UI: + +```cmd +conda activate saxshell-py312 +set PYTHONPATH=src +python -m saxshell.saxs +``` + +You can also launch without activating the environment: + +```cmd +set PYTHONPATH=src +conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs +``` + +From Windows PowerShell, set `PYTHONPATH` with PowerShell syntax: + +```powershell +$env:PYTHONPATH = "src" +conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs +``` + ## First Project Start by preparing the simulation data that the SAXS project will consume: diff --git a/README.rst b/README.rst index 356596b..79cfe61 100644 --- a/README.rst +++ b/README.rst @@ -60,13 +60,19 @@ Installation ------------ SAXSShell is not pip-installable yet. 
The current user-facing path is to clone -the repository and create the conda environment from the checked-in -``requirements/saxshell-py312.yml`` file. +the repository and create the conda environment from the checked-in environment +file for your platform. From a terminal, run :: git clone https://github.com/kewh5868/SAXSShell.git cd SAXSShell + +macOS, Linux, and WSL +~~~~~~~~~~~~~~~~~~~~~ + +Create the Python 3.12 conda environment from the default environment file :: + conda env create -f requirements/saxshell-py312.yml If the environment already exists, update it with :: @@ -77,6 +83,35 @@ Launch the main SAXSShell application from the repository root with :: PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs +Native Windows +~~~~~~~~~~~~~~ + +On native Windows, create the environment from the Windows-specific environment +file :: + + conda env create -f requirements\saxshell-py312-win.yml + +If the environment already exists, update it with :: + + conda env update -n saxshell-py312 -f requirements\saxshell-py312-win.yml --prune + +From Anaconda Prompt, activate the environment, set ``PYTHONPATH``, and launch +the SAXS UI :: + + conda activate saxshell-py312 + set PYTHONPATH=src + python -m saxshell.saxs + +You can also launch without activating the environment :: + + set PYTHONPATH=src + conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs + +From Windows PowerShell, set ``PYTHONPATH`` with PowerShell syntax :: + + $env:PYTHONPATH = "src" + conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs + You can also verify that the source checkout imports inside the conda environment with :: diff --git a/docs/getting-started/project-setup.md b/docs/getting-started/project-setup.md index dd0318e..5999404 100644 --- a/docs/getting-started/project-setup.md +++ b/docs/getting-started/project-setup.md @@ -108,14 +108,21 @@ The **Component build mode** dropdown controls what happens when you click 
Representative structures are optional project-backed files that compatible Debye, Born, FFT, and RMCSetup workflows can use instead of average cluster folders. Use **Tools > Structure Analysis > Open Representative Structures** for -the full interactive analysis UI, or use **Tools > (beta) > Open Representative -CLI Setup (Beta)** to save `representative_structure_cli_run.json` and run the -same backend from the source checkout: +the full interactive analysis UI, or use **Tools > CLI Setup > Open +Representative CLI Setup (Beta)** to save +`representative_structure_cli_run.json` and run the same backend from the +source checkout: ```bash PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.representativefinder run /path/to/project ``` +The same **Tools > CLI Setup** menu can prepare project-local run files for +XYZ-to-PDB conversion, cluster extraction, cluster dynamics, and cluster +dynamics ML. Those run files let you run long jobs from a terminal or batch +several prepared project folders while keeping outputs linked back to their +projects. 
+ ## Debye-Waller factors **Compute Debye-Waller Factors (beta)** is an optional linked step in Project diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index 7b1eaf6..15f37ff 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -18,8 +18,8 @@ that downstream tools should use: ```bash PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory inspect traj.xyz --energy-file traj.ener -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory suggest-cutoff traj.xyz --energy-file traj.ener --temp-target-k 300 --window 3 -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory suggest-cutoff traj.xyz --energy-file traj.ener --temp-target-k 300 --window 2 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 2 ``` If residue identity matters for your downstream analysis, convert the exported diff --git a/docs/tutorials/example-workflow.md b/docs/tutorials/example-workflow.md index bcc1d09..a0bbbb1 100644 --- a/docs/tutorials/example-workflow.md +++ b/docs/tutorials/example-workflow.md @@ -20,7 +20,7 @@ PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshel Use either a manual cutoff or the suggested one: ```bash -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz --energy-file traj.ener 
--use-suggested-cutoff --temp-target-k 300 --window 2 ``` ## Step 3: convert to PDB only if needed diff --git a/docs/tutorials/md-to-saxs-pipeline.md b/docs/tutorials/md-to-saxs-pipeline.md index 961d454..f9fa086 100644 --- a/docs/tutorials/md-to-saxs-pipeline.md +++ b/docs/tutorials/md-to-saxs-pipeline.md @@ -19,8 +19,8 @@ commands from the repository root. ```bash PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory inspect traj.xyz --energy-file traj.ener -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory suggest-cutoff traj.xyz --energy-file traj.ener --temp-target-k 300 --window 3 -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory suggest-cutoff traj.xyz --energy-file traj.ener --temp-target-k 300 --window 2 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 2 ``` ## Convert to residue-aware PDB, if needed diff --git a/docs/user-guide/cluster-dynamics-ml.md b/docs/user-guide/cluster-dynamics-ml.md index 4b1dc64..b111118 100644 --- a/docs/user-guide/cluster-dynamics-ml.md +++ b/docs/user-guide/cluster-dynamics-ml.md @@ -62,6 +62,42 @@ plus a structure library for the observed smaller clusters. `SAXS` tabs. 10. Save the dataset, CSV exports, or a detailed PowerPoint report if needed. +## Project-Backed CLI Setup + +For background prediction runs, use **Tools > CLI Setup > Open Cluster Dynamics +ML CLI Setup (Beta)**. 
The setup window writes +`cluster_dynamics_ml_cli_run.json` in the active project folder and shows the +terminal commands to run later: + +```bash +clusterdynamicsml run /path/to/project +saxshell clusterdynamicsml run /path/to/project +``` + +The command reads the project-local run file, runs the same prediction backend +used by the UI, and saves the reloadable JSON dataset plus companion CSV files, +predicted structures, SAXS profiles, histogram exports, and the inherited +cluster-dynamics heatmap/lifetime exports under +`exported_results/data/clusterdynamicsml` by default. It refreshes the +project's registered frames, clusters, energy, and experimental-data references +when a `saxs_project.json` file is present. + +You can launch the setup window directly: + +```bash +clusterdynamicsml setup-ui /path/to/project +``` + +To process several prepared projects from one terminal session, use: + +```bash +clusterdynamicsml batch-run /path/to/project_a /path/to/project_b --workers 2 --keep-going +``` + +Each project keeps its own run file and output dataset. The CLI path is +separate from the interactive UI, so the existing tabs, cached-history browser, +and plotting controls remain UI-only. + ## Training data assembled by the workflow The workflow first runs the standard cluster-dynamics analysis and then joins diff --git a/docs/user-guide/cluster-dynamics.md b/docs/user-guide/cluster-dynamics.md index 819c2ee..b8357c4 100644 --- a/docs/user-guide/cluster-dynamics.md +++ b/docs/user-guide/cluster-dynamics.md @@ -78,6 +78,40 @@ times, not the folder label. If the tool is launched from the main SAXS UI, it inherits the active project directory automatically. +## Project-Backed CLI Setup + +For longer cluster-dynamics runs, use **Tools > CLI Setup > Open Cluster +Dynamics CLI Setup (Beta)**. 
The setup window writes +`cluster_dynamics_cli_run.json` in the active project folder and shows the +terminal commands to run later: + +```bash +clusterdynamics run /path/to/project +saxshell clusterdynamics run /path/to/project +``` + +The command reads the project-local run file, runs the same backend used by the +UI, and saves the reloadable JSON dataset plus the heatmap CSV, lifetime CSV, +and optional energy CSV under `exported_results/data/clusterdynamics` by +default. It also refreshes the project's registered frames and energy-file +references when a `saxs_project.json` file is present. + +You can also launch the setup window directly: + +```bash +clusterdynamics setup-ui /path/to/project +``` + +To process several prepared projects from one terminal session, use: + +```bash +clusterdynamics batch-run /path/to/project_a /path/to/project_b --workers 2 --keep-going +``` + +Each project keeps its own run file and output dataset, so these runs can also +be started in separate terminals or background jobs without touching the +interactive plotting workflow. 
+ ## Saved Outputs The save action writes a JSON dataset plus companion CSV files beside it: diff --git a/docs/user-guide/cluster-extraction.md b/docs/user-guide/cluster-extraction.md index e7ea150..abc10c9 100644 --- a/docs/user-guide/cluster-extraction.md +++ b/docs/user-guide/cluster-extraction.md @@ -34,8 +34,8 @@ Example: ```bash PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory inspect traj.xyz --energy-file traj.ener -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory suggest-cutoff traj.xyz --energy-file traj.ener --temp-target-k 300 --window 3 -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory suggest-cutoff traj.xyz --energy-file traj.ener --temp-target-k 300 --window 2 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 2 ``` When a cutoff is applied, the default folder name now uses the form @@ -76,6 +76,36 @@ The cluster workflow supports both UI and CLI usage. Its CLI exposes separate The CLI help text explicitly calls out faster neighbor search modes such as `kdtree` and `vectorized`. +### Project-backed cluster run files + +For repeatable project runs, launch the setup window and save a run file in the +SAXSShell project folder: + +From the main SAXSShell window, use **Tools > CLI Setup > Open Cluster +Extraction CLI Setup (Beta)**. 
The same setup window can also be launched from +a terminal: + +```bash +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster setup-ui /path/to/saxs_project +``` + +The setup window records the project folder, extracted frames folder, output +clusters folder, atom rules, pair cutoffs, PBC/box settings, shell options, and +neighbor-search settings in `cluster_extraction_cli_run.json`. Paths inside the +project are stored relative to the project folder so the project can move as a +unit. + +After saving, run the extraction from the terminal: + +```bash +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster run /path/to/saxs_project +``` + +Use `--run-file custom_run.json` to run a different JSON file. Relative +`--run-file` paths are resolved against the project folder. A completed run +updates the project `clusters_dir` reference to the output folder while leaving +the existing frames and PDB-frames references unchanged. + ## `clusterdynamics` This application consumes the extracted XYZ or PDB frames from `mdtrajectory` diff --git a/docs/user-guide/gui-overview.md b/docs/user-guide/gui-overview.md index 6db4b48..d1d80b9 100644 --- a/docs/user-guide/gui-overview.md +++ b/docs/user-guide/gui-overview.md @@ -193,6 +193,33 @@ settings outside the main computed-distribution flow. Use this section for smaller estimate windows such as volume-fraction, number density, attenuation, and fluorescence calculators. +### CLI Setup + +Use this section when you want the GUI to prepare a project-local run file, but +you want the heavier work to run later from a terminal. + +- `XYZ -> PDB CLI Setup` saves `xyz2pdb_cli_run.json` in the project folder so + `xyz2pdb run /path/to/project` can convert XYZ frames to PDB frames and + register the PDB output folder with the project. 
+- `Cluster Extraction CLI Setup` saves `cluster_extraction_cli_run.json` in the + project folder so `clusters run /path/to/project` can export clusters and + register the clusters output folder with the project. +- `Cluster Dynamics CLI Setup` saves `cluster_dynamics_cli_run.json` in the + project folder so `clusterdynamics run /path/to/project` can generate the + time-binned heatmap dataset, cluster lifetime table, and association / + dissociation rate exports from the terminal. +- `Cluster Dynamics ML CLI Setup` saves `cluster_dynamics_ml_cli_run.json` in + the project folder so `clusterdynamicsml run /path/to/project` can run the + prediction workflow, write predicted structures and SAXS/profile exports, + and keep the outputs linked to the project folder. +- `Representative CLI Setup` saves `representative_structure_cli_run.json` in + the project folder so `representativefinder run /path/to/project` can execute + the representative-selection backend without the plotting and viewer UI. + +Cluster Dynamics and Cluster Dynamics ML also provide `batch-run --workers N` +subcommands for processing multiple prepared project folders from one terminal +session. + ### (beta) Use this section for early-access workflows that are exposed from the main @@ -201,10 +228,6 @@ Use this section for early-access workflows that are exposed from the main - `Debye-Waller Analysis` estimates intra-molecular and inter-molecular Debye-Waller coefficients from sorted PDB cluster folders and saves them in the active project when requested from Project Setup or the Tools menu. -- `Representative CLI Setup` saves - `representative_structure_cli_run.json` in the project folder so - the `representativefinder` source module can execute the same representative - selection without the plotting and viewer UI. !!! 
warning "Debye-Waller status" The linked **Compute Debye-Waller Factors (beta)** workflow is currently diff --git a/docs/user-guide/preloaded-saxs-models.md b/docs/user-guide/preloaded-saxs-models.md index 3dbc5dd..993e716 100644 --- a/docs/user-guide/preloaded-saxs-models.md +++ b/docs/user-guide/preloaded-saxs-models.md @@ -9,16 +9,17 @@ single paper. ## Template Catalog -| Template file | GUI name | Status | Model family | -| ------------------------------------------------------ | ------------------------------------------------------ | ---------- | --------------------------------------------- | -| `template_pydream_monosq_normalized.py` | `pyDREAM MonoSQ Normalized` | current | MonoSQ hard-sphere | -| `template_pydream_monosq_normalized_scaled_solvent.py` | `pyDREAM MonoSQ Normalized (Scaled Solvent Weight)` | current | MonoSQ hard-sphere with scale-coupled solvent | -| `template_pydream_poly_lma_hs.py` | `pyDREAM Poly LMA Hard-Sphere` | current | sphere-only Poly LMA hard-sphere | -| `template_pydream_poly_lma_hs_mix_approx.py` | `pyDREAM Poly LMA Hard-Sphere/Ellipsoid Mix (Approx.)` | current | mixed-shape approximate Poly LMA hard-sphere | -| `template_likelihood_monosq.py` | `MonoSQ Basic (archived)` | archived | MonoSQ hard-sphere | -| `template_pd_likelihood_monosq.py` | `MonoSQ PD (archived)` | archived | MonoSQ hard-sphere | -| `template_pd_likelihood_monosq_decoupled.py` | `MonoSQ Decoupled (archived)` | archived | MonoSQ hard-sphere | -| `template_pydream_poly_lma_hs_legacy.py` | `pyDREAM Poly LMA Hard-Sphere (deprecated)` | deprecated | mixed-shape approximate Poly LMA hard-sphere | +| Template file | GUI name | Status | Model family | +| -------------------------------------------------------------- | ----------------------------------------------------------- | ---------- | --------------------------------------------- | +| `template_pydream_monosq_normalized.py` | `pyDREAM MonoSQ Normalized` | current | MonoSQ hard-sphere | +| 
`template_pydream_monosq_normalized_scaled_solvent.py` | `pyDREAM MonoSQ Normalized (Scaled Solvent Weight)` | current | MonoSQ hard-sphere with scale-coupled solvent | +| `template_pydream_charged_monosq_normalized_scaled_solvent.py` | `pyDREAM Charged MonoSQ Normalized (Scaled Solvent Weight)` | current | MonoSQ charged hard-sphere RMSA | +| `template_pydream_poly_lma_hs.py` | `pyDREAM Poly LMA Hard-Sphere` | current | sphere-only Poly LMA hard-sphere | +| `template_pydream_poly_lma_hs_mix_approx.py` | `pyDREAM Poly LMA Hard-Sphere/Ellipsoid Mix (Approx.)` | current | mixed-shape approximate Poly LMA hard-sphere | +| `template_likelihood_monosq.py` | `MonoSQ Basic (archived)` | archived | MonoSQ hard-sphere | +| `template_pd_likelihood_monosq.py` | `MonoSQ PD (archived)` | archived | MonoSQ hard-sphere | +| `template_pd_likelihood_monosq_decoupled.py` | `MonoSQ Decoupled (archived)` | archived | MonoSQ hard-sphere | +| `template_pydream_poly_lma_hs_legacy.py` | `pyDREAM Poly LMA Hard-Sphere (deprecated)` | deprecated | mixed-shape approximate Poly LMA hard-sphere | ## Shared Notation @@ -30,6 +31,8 @@ Across the bundled templates: - \(w_i\) is the raw weight assigned to component \(i\). - \(S\_{\mathrm{HS}}(q; R, \phi)\) is the hard-sphere Percus-Yevick structure factor evaluated at effective radius \(R\) and packing term \(\phi\). +- \(S\_{\mathrm{RMSA}}(q)\) is the Hayter-Penfold rescaled mean spherical + approximation charged-sphere structure factor. - `scale` and `offset` are the global multiplicative and additive terms exposed in the Prefit parameter table. @@ -124,16 +127,131 @@ Because the solvent branch is scale-coupled, Prefit's scale recommendation also treats the solvent term as part of the scaled model instead of subtracting it as an already-scaled background contribution. 
+### Charged Scaled Solvent MonoSQ + +The `pyDREAM Charged MonoSQ Normalized (Scaled Solvent Weight)` template keeps +the scaled-solvent MonoSQ organization, but replaces the neutral +Percus-Yevick hard-sphere term with the Hayter-Penfold RMSA structure factor for +screened Coulomb repulsion between charged spheres. + +The cluster-trace form-factor mixture is still + +$$ +I_{\mathrm{mix}}(q) = \sum_i w_i I_i(q). +$$ + +The charged solute branch is + +$$ +I_{\mathrm{solute}}(q) = +I_{\mathrm{mix}}(q) +S_{\mathrm{RMSA}} +\left(q; R_{\mathrm{eff}}, \phi, Z, T, c_{\mathrm{salt}}, \epsilon_r\right), +$$ + +and the full model follows the same scale-coupled solvent convention: + +$$ +I_{\mathrm{model}}(q) = +\mathrm{scale} +\left[ +I_{\mathrm{solute}}(q) ++ w_{\mathrm{solv}} I_{\mathrm{solv}}(q) +\right] ++ \mathrm{offset}. +$$ + +Here \(Z\) is the charged-sphere charge in elementary-charge units, +\(T\) is the absolute temperature, \(c\_{\mathrm{salt}}\) is the molar +concentration of added 1:1 electrolyte, and \(\epsilon_r\) is the solvent +relative dielectric constant. + +The implementation follows the SasView `hayter_msa` parameterization. The +template first converts the fitted parameters into SI-derived screening terms: + +$$ +\beta = \frac{1}{k_B T}, +\qquad +\epsilon = \epsilon_r \epsilon_0, +\qquad +\sigma = 2R_{\mathrm{eff}}. +$$ + +For monovalent counterions and added 1:1 salt, the ionic-strength term used by +the RMSA kernel is + +$$ +I_{\mathrm{ion}} = +\frac{e^2}{2} +\left( +\frac{Z\phi}{V_p} ++ 2 N_A 10^3 c_{\mathrm{salt}} +\right), +$$ + +where \(V\_p = 4\pi R\_{\mathrm{eff}}^3 / 3\) after converting \(R\_{\mathrm{eff}}\) +to meters. The Debye-Huckel screening parameter is + +$$ +\kappa = +\sqrt{\frac{2\beta I_{\mathrm{ion}}}{\epsilon}}. +$$ + +The dimensionless contact-potential parameter passed into the +Hayter-Penfold coefficient calculation is + +$$ +\Gamma = +\frac{ +\beta (Ze)^2 +}{ +\pi \epsilon \sigma (2 + \kappa\sigma)^2 +}. 
+$$ + +The Hayter-Penfold rescaling solves for a rescaled volume fraction +\(\phi_s\), rescaled screening parameter \(\kappa_s\), and MSA coefficients +\(A, B, C, F, U, V\) that satisfy the Gillan contact condition. SAXSShell +then evaluates the same final structure-factor form used by SasView: + +$$ +S_{\mathrm{RMSA}}(q) = +\frac{1}{1 - 24\phi_s\,\mathcal{A}(q\sigma / s)}, +$$ + +where \(s = (\phi / \phi_s)^{1/3}\) and +\(\mathcal{A}\) is the Hayter-Penfold Fourier-space coefficient expression. +The template includes the small-\(q\) Taylor branch used by SasView to avoid +rounding error near \(q\sigma / s = 0\). + +This is a charged-particle model. `charge` is constrained to be positive and +bounded above by 200 e, matching the SasView stability guidance. For neutral +systems use one of the hard-sphere MonoSQ templates instead. + +Like the scaled-solvent hard-sphere template, this charged template declares +calculator targets in its metadata: + +- `vol_frac` receives the physical solute-associated volume fraction computed + from the solution composition. +- `solv_w` receives the combined solvent-background multiplier from attenuation + and SAXS-effective solvent contrast. +- The solvent contribution is marked as globally scaled, so Prefit's autoscale + calculation treats the solvent branch as part of the model curve. 
+ ### Variables -| Symbol / parameter | Meaning in SAXSShell | -| ------------------------------------- | -------------------------------------------------------------------------------------------------------------------------- | -| \(w_i\) | generated component weight for cluster profile \(i\) | -| \(w\_{\mathrm{solv}}\) / `solv_w` | bounded solvent contribution weight | -| \(R\_{\mathrm{eff}}\) / `eff_r` | effective hard-sphere radius used in `calc_monodisperse_sq(...)`; scaled-solvent MonoSQ defaults to 3 A | -| \(\phi\_{\mathrm{vol}}\) / `vol_frac` | effective hard-sphere volume fraction inside the Percus-Yevick term | -| `scale` | global intensity scale; original MonoSQ applies it only to solute, scaled-solvent MonoSQ applies it to solute plus solvent | -| `offset` | constant additive background | +| Symbol / parameter | Meaning in SAXSShell | +| --------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------- | +| \(w_i\) | generated component weight for cluster profile \(i\) | +| \(w\_{\mathrm{solv}}\) / `solv_w` | bounded solvent contribution weight | +| \(R\_{\mathrm{eff}}\) / `eff_r` | effective hard-sphere radius used in `calc_monodisperse_sq(...)`; scaled-solvent MonoSQ defaults to 3 A | +| \(\phi\_{\mathrm{vol}}\) / `vol_frac` | effective hard-sphere volume fraction inside the Percus-Yevick term | +| `scale` | global intensity scale; original MonoSQ applies it only to solute, scaled-solvent MonoSQ applies it to solute plus solvent | +| `offset` | constant additive background | +| \(Z\) / `charge` | charged-sphere charge in elementary-charge units for the charged RMSA template | +| \(T\) / `temperature` | absolute temperature in kelvin for the charged RMSA Debye length calculation | +| \(c\_{\mathrm{salt}}\) / `concentration_salt` | added 1:1 electrolyte concentration in mol/L for the charged RMSA template | +| \(\epsilon_r\) / `dielectconst` | 
solvent relative dielectric constant for the charged RMSA template | ### Likelihood conventions @@ -163,6 +281,10 @@ function before evaluating the likelihood. Phys. Rev. Lett. **10**, 321-323 (1963). - J. S. Pedersen, _Analysis of small-angle scattering data from colloids and polymer solutions: modeling and least-squares fitting_, Adv. Colloid Interface Sci. **70**, 171-210 (1997). +- J. B. Hayter and J. Penfold, Molecular Physics **42**, 109-118 (1981). +- J. P. Hansen and J. B. Hayter, Molecular Physics **46**, 651-656 (1982). +- SasView `hayter_msa` charged-sphere RMSA model documentation: + https://www.sasview.org/docs/user/models/hayter_msa.html ## Poly LMA Hard-Sphere diff --git a/docs/user-guide/representative-structure-cli.md b/docs/user-guide/representative-structure-cli.md index eb516a4..d4b2149 100644 --- a/docs/user-guide/representative-structure-cli.md +++ b/docs/user-guide/representative-structure-cli.md @@ -9,7 +9,7 @@ viewer updates, and Qt progress refreshes during the actual analysis. ## Workflow 1. Open the main SAXSShell application from the source checkout. -2. Open **Tools > (beta) > Open Representative CLI Setup (Beta)**. +2. Open **Tools > CLI Setup > Open Representative CLI Setup (Beta)**. 3. Select the project folder and representative input folder. 4. Load or enter the bond-pair and angle-triplet definitions. 5. Save the run file. diff --git a/docs/user-guide/xyz2pdb-conversion.md b/docs/user-guide/xyz2pdb-conversion.md index 8048633..502eae9 100644 --- a/docs/user-guide/xyz2pdb-conversion.md +++ b/docs/user-guide/xyz2pdb-conversion.md @@ -440,6 +440,41 @@ The standalone `xyz2pdb` command still exposes older CLI subcommands such as scripting. The Qt interface documented here is the newer native mapping UI and does not require the legacy JSON input file. +## Project-backed CLI runs + +Use the setup window to save a project-local run file for repeatable +conversions: + +From the main SAXSShell window, use **Tools > CLI Setup > Open XYZ -> PDB CLI +Setup (Beta)**. 
The same setup window can also be launched from a terminal: + +```bash +xyz2pdb setup-ui /path/to/saxshell_project --input-path /path/to/xyz_frames +``` + +The setup window records the XYZ input, output PDB folder, reference library, +free-atom definitions, reference-molecule mappings, hydrogen handling, optional +PBC JSON, assertion mode, and selected estimate solution. Paths inside the +project folder are saved relative to the project, so the run file remains +portable with the project folder. + +After saving, run the conversion from a terminal, passing the project folder: + +```bash +xyz2pdb run /path/to/saxshell_project +``` + +By default this reads `xyz2pdb_cli_run.json` in the project folder. To use a +different run file: + +```bash +xyz2pdb run /path/to/saxshell_project --run-file /path/to/run.json +``` + +The project-backed run uses the same native headless mapping workflow as the +GUI export, writes the converted PDB frames, and updates the SAXSShell project +`PDB structure folder` to point at the output directory. 
+ ## Related pages - [GUI Overview](gui-overview.md) diff --git a/src/saxshell/cluster/__init__.py b/src/saxshell/cluster/__init__.py index 454887b..5078989 100644 --- a/src/saxshell/cluster/__init__.py +++ b/src/saxshell/cluster/__init__.py @@ -37,6 +37,20 @@ ordered_cluster_extraction_preset_names, save_custom_cluster_extraction_preset, ) +from .run_config import ( + ClusterRunConfig, + ClusterRunExecutionSummary, + build_cluster_run_config, + default_cluster_run_file_path, + load_cluster_run_config, + path_text_for_run_config, + preview_cluster_run_config, + resolve_run_config_path, + run_cluster_run_config, + save_cluster_run_config, + suggest_run_config_output_dir, + workflow_from_cluster_run_config, +) from .workflow import ( ClusterExportResult, ClusterSelectionResult, @@ -90,4 +104,16 @@ "load_cluster_extraction_presets", "ordered_cluster_extraction_preset_names", "save_custom_cluster_extraction_preset", + "ClusterRunConfig", + "ClusterRunExecutionSummary", + "build_cluster_run_config", + "default_cluster_run_file_path", + "load_cluster_run_config", + "path_text_for_run_config", + "preview_cluster_run_config", + "resolve_run_config_path", + "run_cluster_run_config", + "save_cluster_run_config", + "suggest_run_config_output_dir", + "workflow_from_cluster_run_config", ] diff --git a/src/saxshell/cluster/_cluster_extraction_presets/user_cluster_extraction_presets.json b/src/saxshell/cluster/_cluster_extraction_presets/user_cluster_extraction_presets.json index 5e8ea8e..04b560f 100644 --- a/src/saxshell/cluster/_cluster_extraction_presets/user_cluster_extraction_presets.json +++ b/src/saxshell/cluster/_cluster_extraction_presets/user_cluster_extraction_presets.json @@ -92,6 +92,194 @@ "shared_shells": true, "include_shell_atoms_in_stoichiometry": false } + }, + "MAPbI3 - DMSO (Full Solvent)": { + "atom_type_definitions": { + "node": [ + { + "element": "Pb", + "residue": "PBI" + } + ], + "linker": [ + { + "element": "I", + "residue": "PBI" + } + ], + "shell": [ 
+ { + "element": "O", + "residue": "DMS" + } + ] + }, + "pair_cutoff_definitions": [ + { + "atom1": "Pb", + "atom2": "I", + "shell_cutoffs": { + "0": 3.36 + } + }, + { + "atom1": "Pb", + "atom2": "O", + "shell_cutoffs": { + "0": 3.36 + } + } + ], + "options": { + "use_pbc": false, + "search_mode": "kdtree", + "save_state_frequency": 1000, + "shell_growth_levels": [], + "shared_shells": true, + "smart_solvation_shells": true, + "include_shell_atoms_in_stoichiometry": false + } + }, + "PbI2 - DMSO (Full Solvent)": { + "atom_type_definitions": { + "node": [ + { + "element": "Pb", + "residue": "PBI" + } + ], + "linker": [ + { + "element": "I", + "residue": "PBI" + } + ], + "shell": [ + { + "element": "O", + "residue": "DMS" + } + ] + }, + "pair_cutoff_definitions": [ + { + "atom1": "Pb", + "atom2": "I", + "shell_cutoffs": { + "0": 3.36 + } + }, + { + "atom1": "Pb", + "atom2": "O", + "shell_cutoffs": { + "0": 3.36 + } + } + ], + "options": { + "use_pbc": false, + "search_mode": "kdtree", + "save_state_frequency": 1000, + "shell_growth_levels": [], + "shared_shells": true, + "smart_solvation_shells": true, + "include_shell_atoms_in_stoichiometry": false + } + }, + "PbI2 - DMF (Full Solvent)": { + "atom_type_definitions": { + "node": [ + { + "element": "Pb", + "residue": "PBI" + } + ], + "linker": [ + { + "element": "I", + "residue": "PBI" + } + ], + "shell": [ + { + "element": "O", + "residue": "DMF" + } + ] + }, + "pair_cutoff_definitions": [ + { + "atom1": "Pb", + "atom2": "I", + "shell_cutoffs": { + "0": 3.32 + } + }, + { + "atom1": "Pb", + "atom2": "O", + "shell_cutoffs": { + "0": 3.32 + } + } + ], + "options": { + "use_pbc": false, + "search_mode": "kdtree", + "save_state_frequency": 1000, + "shell_growth_levels": [], + "shared_shells": true, + "smart_solvation_shells": true, + "include_shell_atoms_in_stoichiometry": false + } + }, + "MAPbI3 - DMF (Full Solvent)": { + "atom_type_definitions": { + "node": [ + { + "element": "Pb", + "residue": "PBI" + } + ], + 
"linker": [ + { + "element": "I", + "residue": "PBI" + } + ], + "shell": [ + { + "element": "O", + "residue": "DMF" + } + ] + }, + "pair_cutoff_definitions": [ + { + "atom1": "Pb", + "atom2": "I", + "shell_cutoffs": { + "0": 3.45 + } + }, + { + "atom1": "Pb", + "atom2": "O", + "shell_cutoffs": { + "0": 3.45 + } + } + ], + "options": { + "use_pbc": false, + "search_mode": "kdtree", + "save_state_frequency": 1000, + "shell_growth_levels": [], + "shared_shells": true, + "smart_solvation_shells": true, + "include_shell_atoms_in_stoichiometry": false + } } } } diff --git a/src/saxshell/cluster/cli.py b/src/saxshell/cluster/cli.py index b853af9..67f3c0a 100644 --- a/src/saxshell/cluster/cli.py +++ b/src/saxshell/cluster/cli.py @@ -6,6 +6,11 @@ from saxshell.version import __version__ from .clusternetwork import DEFAULT_SAVE_STATE_FREQUENCY, SEARCH_MODE_CHOICES +from .run_config import ( + default_cluster_run_file_path, + load_cluster_run_config, + run_cluster_run_config, +) from .workflow import ( ClusterExportResult, ClusterSelectionResult, @@ -46,6 +51,24 @@ def build_parser() -> argparse.ArgumentParser: subparsers = parser.add_subparsers(dest="command") + setup_ui_parser = subparsers.add_parser( + "setup-ui", + help="Launch the beta project-backed run-file setup interface.", + ) + setup_ui_parser.add_argument( + "project_dir", + nargs="?", + type=Path, + help="Optional SAXSShell project folder.", + ) + setup_ui_parser.add_argument( + "--frames-dir", + type=Path, + default=None, + help="Optional extracted PDB or XYZ frames folder to prefill.", + ) + setup_ui_parser.set_defaults(handler=_handle_setup_ui) + ui_parser = subparsers.add_parser("ui", help="Launch the Qt UI.") ui_parser.add_argument( "frames_dir", @@ -88,6 +111,26 @@ def build_parser() -> argparse.ArgumentParser: ) export_parser.set_defaults(handler=_handle_export) + run_parser = subparsers.add_parser( + "run", + help="Run cluster extraction from a project-backed run file.", + ) + run_parser.add_argument( + 
"project_dir", + type=Path, + help="SAXSShell project folder containing the run file.", + ) + run_parser.add_argument( + "--run-file", + type=Path, + default=None, + help=( + "Run file path. Defaults to cluster_extraction_cli_run.json " + "in the project folder." + ), + ) + run_parser.set_defaults(handler=_handle_run) + return parser @@ -222,6 +265,22 @@ def _handle_ui(args: argparse.Namespace) -> int: return launch_cluster_ui(getattr(args, "frames_dir", None)) +def _handle_setup_ui(args: argparse.Namespace) -> int: + from PySide6.QtWidgets import QApplication + + from .ui.run_file_window import launch_cluster_run_file_ui + + owns_app = QApplication.instance() is None + launch_cluster_run_file_ui( + initial_project_dir=getattr(args, "project_dir", None), + initial_frames_dir=getattr(args, "frames_dir", None), + ) + app = QApplication.instance() + if owns_app and app is not None: + return app.exec() + return 0 + + def _build_workflow(args: argparse.Namespace) -> ClusterWorkflow: return ClusterWorkflow( frames_dir=args.frames_dir, @@ -349,6 +408,36 @@ def _handle_export(args: argparse.Namespace) -> int: return 0 +def _handle_run(args: argparse.Namespace) -> int: + project_dir = Path(args.project_dir).expanduser().resolve() + run_file = _resolve_run_file(project_dir, args.run_file) + config = load_cluster_run_config(run_file) + summary = run_cluster_run_config( + project_dir, + config, + run_file_path=run_file, + log_callback=print, + ) + print("") + print("Cluster extraction CLI run complete") + print(f"Frames folder: {summary.frames_dir}") + print(f"Output folder: {summary.output_dir}") + print(f"Frames analyzed: {summary.result.analyzed_frames}") + print(f"Clusters found: {summary.result.total_clusters}") + print(f"Files written: {summary.written_count}") + print(f"Project file: {summary.project_file}") + return 0 + + +def _resolve_run_file(project_dir: Path, run_file: Path | None) -> Path: + if run_file is None: + return default_cluster_run_file_path(project_dir) 
+ path = Path(run_file).expanduser() + if not path.is_absolute(): + path = project_dir / path + return path.resolve() + + def _format_selection_result(selection: ClusterSelectionResult) -> str: stoichiometry_bins_text = ( "solute + shell atoms" diff --git a/src/saxshell/cluster/clusternetwork.py b/src/saxshell/cluster/clusternetwork.py index 88a29da..e9cf130 100644 --- a/src/saxshell/cluster/clusternetwork.py +++ b/src/saxshell/cluster/clusternetwork.py @@ -2957,6 +2957,31 @@ def _export_cluster_pdb_files_with_smart_shells( atom_elements: dict[int, str] = {} active_runs = self._rebuild_smart_shell_run_states(frame_entries) + def finalize_smart_shell_runs( + runs: Sequence[_SmartShellRunState], + ) -> None: + for run in runs: + self._apply_smart_shell_union_to_run( + run, + frame_entries, + elements=atom_elements, + ) + + def ensure_atom_elements_loaded() -> None: + if atom_elements: + return + for frame_path in frame_paths: + entry = frame_entries.get(frame_path.name) + if entry is None or not _frame_entry_is_processed(entry): + continue + network = self._build_network(frame_path) + if not isinstance(network, ClusterNetwork): + raise ValueError( + "Smart Solvation Shell mode requires PDB frames." + ) + atom_elements.update(network.elements) + return + def checkpoint_metadata(*, force: bool = False) -> None: nonlocal frames_since_checkpoint, last_checkpoint_time if not force: @@ -2985,6 +3010,13 @@ def checkpoint_metadata(*, force: bool = False) -> None: "Smart Solvation Shell mode requires PDB frames." 
) atom_elements.update(network.elements) + if active_runs and any( + run.last_frame_index != frame_index - 1 + for run in active_runs.values() + ): + finalize_smart_shell_runs(tuple(active_runs.values())) + active_runs = {} + clusters = network.find_clusters( shell_levels=shell_levels, shared_shells=shared_shells, @@ -3012,6 +3044,7 @@ def checkpoint_metadata(*, force: bool = False) -> None: } current_runs: dict[tuple[int, ...], _SmartShellRunState] = {} + continued_run_keys: set[tuple[int, ...]] = set() for cluster in frame_result.clusters: solute_atom_ids = cluster.solute_atom_ids prior_run = active_runs.get(solute_atom_ids) @@ -3020,6 +3053,7 @@ def checkpoint_metadata(*, force: bool = False) -> None: and prior_run.last_frame_index == frame_index - 1 ): run = prior_run + continued_run_keys.add(solute_atom_ids) else: run = _SmartShellRunState( solute_atom_ids=solute_atom_ids, @@ -3038,12 +3072,12 @@ def checkpoint_metadata(*, force: bool = False) -> None: run.shell_levels[atom_id] = shell_level current_runs[solute_atom_ids] = run - for run in current_runs.values(): - self._apply_smart_shell_union_to_run( - run, - frame_entries, - elements=atom_elements, - ) + closed_runs = [ + run + for key, run in active_runs.items() + if key not in continued_run_keys + ] + finalize_smart_shell_runs(closed_runs) active_runs = current_runs newly_processed_frames += 1 @@ -3058,6 +3092,11 @@ def checkpoint_metadata(*, force: bool = False) -> None: frame_path.stem, ) + if active_runs: + ensure_atom_elements_loaded() + finalize_smart_shell_runs(tuple(active_runs.values())) + active_runs = {} + metadata["state"] = "sorting" checkpoint_metadata(force=True) if phase_callback is not None: @@ -3207,7 +3246,11 @@ def _rebuild_smart_shell_run_states( ) run.last_frame_index = frame_index run.frame_refs.append((frame_name, cluster.cluster_id)) - run.shell_levels = _cluster_shell_level_payload(cluster) + for atom_id, shell_level in _cluster_shell_level_payload( + cluster + ).items(): + if 
@dataclass(slots=True)
class ClusterRunConfig:
    """Settings for one cluster-extraction run, as stored in a run file.

    ``to_dict`` produces the JSON-ready payload written to disk and
    ``from_dict`` rebuilds a config from a parsed payload. ``frames_dir``
    and ``output_dir`` are kept as text exactly as given; this class does
    not resolve them against any directory.
    """

    frames_dir: str
    output_dir: str | None
    atom_type_definitions: AtomTypeDefinitions = field(
        default_factory=example_atom_type_definitions
    )
    pair_cutoff_definitions: PairCutoffDefinitions = field(
        default_factory=example_pair_cutoff_definitions
    )
    box_dimensions: tuple[float, float, float] | None = None
    use_pbc: bool = False
    default_cutoff: float | None = None
    shell_levels: tuple[int, ...] = ()
    include_shell_levels: tuple[int, ...] = (0,)
    shared_shells: bool = False
    smart_solvation_shells: bool = True
    include_shell_atoms_in_stoichiometry: bool = False
    search_mode: str = SEARCH_MODE_KDTREE
    save_state_frequency: int = DEFAULT_SAVE_STATE_FREQUENCY
    # Local-time ISO-8601 timestamp (second precision) taken at construction.
    created_at: str = field(
        default_factory=lambda: datetime.now().isoformat(timespec="seconds")
    )

    def to_dict(self) -> dict[str, object]:
        """Return a JSON-serializable mapping describing this config.

        Definitions are passed through the module's serializers, and the
        search mode / save-state frequency are normalized on the way out.
        """
        return {
            "version": RUN_CONFIG_VERSION,
            "created_at": self.created_at,
            "frames_dir": self.frames_dir,
            "output_dir": self.output_dir,
            "atom_type_definitions": _serialize_atom_type_definitions(
                self.atom_type_definitions
            ),
            "pair_cutoff_definitions": _serialize_pair_cutoff_definitions(
                self.pair_cutoff_definitions
            ),
            "box_dimensions": self.box_dimensions,
            "use_pbc": bool(self.use_pbc),
            "default_cutoff": self.default_cutoff,
            "shell_levels": [int(level) for level in self.shell_levels],
            "include_shell_levels": [
                int(level) for level in self.include_shell_levels
            ],
            "shared_shells": bool(self.shared_shells),
            "smart_solvation_shells": bool(self.smart_solvation_shells),
            "include_shell_atoms_in_stoichiometry": bool(
                self.include_shell_atoms_in_stoichiometry
            ),
            "search_mode": normalize_search_mode(self.search_mode),
            "save_state_frequency": normalize_save_state_frequency(
                self.save_state_frequency
            ),
        }

    @classmethod
    def from_dict(cls, payload: dict[str, object]) -> "ClusterRunConfig":
        """Build a config from a parsed run-file payload.

        Missing optional keys fall back to the same defaults as the
        dataclass fields. An empty ``include_shell_levels`` becomes ``(0,)``.

        Raises
        ------
        ValueError
            If ``frames_dir`` is absent, ``null``, or blank. The coercion
            helpers used for the remaining fields may also raise.
        """
        # ``str(None)`` is the non-empty text "None", so a JSON ``null`` must
        # be filtered out *before* stringifying; ``_optional_text`` does that
        # and also strips surrounding whitespace.
        frames_dir = _optional_text(payload.get("frames_dir"))
        if frames_dir is None:
            raise ValueError("Cluster run file is missing frames_dir.")
        return cls(
            frames_dir=frames_dir,
            output_dir=_optional_text(payload.get("output_dir")),
            atom_type_definitions=_coerce_atom_type_definitions(
                payload.get("atom_type_definitions")
            ),
            pair_cutoff_definitions=_coerce_pair_cutoff_definitions(
                payload.get("pair_cutoff_definitions")
            ),
            box_dimensions=_coerce_box_dimensions(
                payload.get("box_dimensions")
            ),
            use_pbc=bool(payload.get("use_pbc", False)),
            default_cutoff=_optional_float(payload.get("default_cutoff")),
            shell_levels=_coerce_int_tuple(payload.get("shell_levels")),
            include_shell_levels=(
                _coerce_int_tuple(payload.get("include_shell_levels")) or (0,)
            ),
            shared_shells=bool(payload.get("shared_shells", False)),
            smart_solvation_shells=bool(
                payload.get("smart_solvation_shells", True)
            ),
            include_shell_atoms_in_stoichiometry=bool(
                payload.get("include_shell_atoms_in_stoichiometry", False)
            ),
            search_mode=normalize_search_mode(
                str(payload.get("search_mode", SEARCH_MODE_KDTREE))
            ),
            save_state_frequency=normalize_save_state_frequency(
                _optional_int(
                    payload.get("save_state_frequency"),
                    DEFAULT_SAVE_STATE_FREQUENCY,
                )
            ),
            # Same null-safety as ``frames_dir``: a null or blank timestamp
            # is replaced by "now" instead of being stored as "None".
            created_at=(
                _optional_text(payload.get("created_at"))
                or datetime.now().isoformat(timespec="seconds")
            ),
        )
def resolve_run_config_path(
    path_text: str | None,
    *,
    project_dir: str | Path,
) -> Path | None:
    """Turn run-file path text into an absolute, resolved path.

    ``None`` or blank text yields ``None``. Relative text is interpreted
    against the resolved ``project_dir``; absolute text is used as given.
    A leading ``~`` is expanded in both the text and ``project_dir``.
    """
    cleaned = str(path_text or "").strip()
    if cleaned == "":
        return None

    candidate = Path(cleaned).expanduser()
    if candidate.is_absolute():
        return candidate.resolve()

    base_dir = Path(project_dir).expanduser().resolve()
    return (base_dir / candidate).resolve()
= (0,), + shared_shells: bool = False, + smart_solvation_shells: bool = True, + include_shell_atoms_in_stoichiometry: bool = False, + search_mode: str = SEARCH_MODE_KDTREE, + save_state_frequency: int = DEFAULT_SAVE_STATE_FREQUENCY, +) -> ClusterRunConfig: + return ClusterRunConfig( + frames_dir=path_text_for_run_config( + frames_dir, + project_dir=project_dir, + ) + or "", + output_dir=path_text_for_run_config( + output_dir, + project_dir=project_dir, + ), + atom_type_definitions=normalize_atom_type_definitions( + atom_type_definitions + ), + pair_cutoff_definitions=normalize_pair_cutoffs( + pair_cutoff_definitions + ), + box_dimensions=box_dimensions, + use_pbc=bool(use_pbc), + default_cutoff=default_cutoff, + shell_levels=tuple(sorted({int(level) for level in shell_levels})), + include_shell_levels=tuple( + sorted({int(level) for level in include_shell_levels}) + ) + or (0,), + shared_shells=bool(shared_shells), + smart_solvation_shells=bool(smart_solvation_shells), + include_shell_atoms_in_stoichiometry=bool( + include_shell_atoms_in_stoichiometry + ), + search_mode=normalize_search_mode(search_mode), + save_state_frequency=normalize_save_state_frequency( + save_state_frequency + ), + ) + + +def suggest_run_config_output_dir( + *, + frames_dir: str | Path, +) -> Path: + return suggest_output_dir(frames_dir) + + +def workflow_from_cluster_run_config( + *, + project_dir: str | Path, + config: ClusterRunConfig, +) -> ClusterWorkflow: + resolved_project_dir = Path(project_dir).expanduser().resolve() + frames_dir = resolve_run_config_path( + config.frames_dir, + project_dir=resolved_project_dir, + ) + if frames_dir is None: + raise ValueError("Cluster run file is missing frames_dir.") + return ClusterWorkflow( + frames_dir=frames_dir, + atom_type_definitions=config.atom_type_definitions, + pair_cutoff_definitions=config.pair_cutoff_definitions, + box_dimensions=config.box_dimensions, + use_pbc=config.use_pbc, + default_cutoff=config.default_cutoff, + 
def run_cluster_run_config(
    project_dir: str | Path,
    config: ClusterRunConfig,
    *,
    run_file_path: str | Path | None = None,
    log_callback: ClusterRunLogCallback | None = None,
) -> ClusterRunExecutionSummary:
    """Export clusters for ``config`` and record the result in the project.

    Parameters
    ----------
    project_dir:
        Project folder; relative paths in ``config`` are resolved against it.
    config:
        Run settings to execute.
    run_file_path:
        Optional path of the run file the config came from. It is only
        echoed back (expanded and resolved) in the returned summary.
    log_callback:
        Optional sink for progress messages.

    Returns
    -------
    ClusterRunExecutionSummary
        Resolved locations, the export result, and the project file path
        returned by the project registration step.

    Raises
    ------
    ValueError
        If ``config.frames_dir`` resolves to nothing.
    """
    resolved_project_dir = Path(project_dir).expanduser().resolve()
    frames_dir = resolve_run_config_path(
        config.frames_dir,
        project_dir=resolved_project_dir,
    )
    if frames_dir is None:
        raise ValueError("Cluster run file is missing frames_dir.")
    output_dir = resolve_run_config_path(
        config.output_dir,
        project_dir=resolved_project_dir,
    )
    workflow = workflow_from_cluster_run_config(
        project_dir=resolved_project_dir,
        config=config,
    )
    _emit_log(log_callback, f"Frames folder: {frames_dir}")
    # When no output folder is configured, log the suggested one; the
    # workflow itself still receives ``None`` below.
    _emit_log(
        log_callback,
        "Output folder: "
        + str(
            output_dir
            if output_dir is not None
            else suggest_output_dir(frames_dir)
        ),
    )
    result = workflow.export_clusters(output_dir=output_dir)
    project_file = _register_project_clusters_dir(
        resolved_project_dir,
        result.output_dir,
    )
    _emit_log(log_callback, f"Project clusters folder: {result.output_dir}")
    return ClusterRunExecutionSummary(
        project_dir=resolved_project_dir,
        # Expand "~" before resolving, as every other path in this module
        # does; without it "~/run.json" resolves under the working directory.
        run_file_path=(
            None
            if run_file_path is None
            else Path(run_file_path).expanduser().resolve()
        ),
        frames_dir=frames_dir,
        output_dir=result.output_dir,
        result=result,
        project_file=project_file,
    )
def _coerce_pair_cutoff_definitions(value: object) -> PairCutoffDefinitions:
    """Parse the run-file ``pair_cutoff_definitions`` list.

    A non-list value yields the example definitions. Entries that are not
    dicts, lack either atom name, or have no ``shell_cutoffs`` dict are
    skipped. The collected mapping is passed through
    ``normalize_pair_cutoffs`` before being returned.
    """
    if not isinstance(value, list):
        return example_pair_cutoff_definitions()
    definitions: PairCutoffDefinitions = {}
    for entry in value:
        if not isinstance(entry, dict):
            continue
        # ``or ""`` keeps a JSON null from being stringified to "None" and
        # then accepted as an atom name.
        atom1 = str(entry.get("atom1") or "").strip().title()
        atom2 = str(entry.get("atom2") or "").strip().title()
        cutoffs = entry.get("shell_cutoffs")
        if not atom1 or not atom2 or not isinstance(cutoffs, dict):
            continue
        parsed: dict[int, float] = {
            int(level): float(cutoff) for level, cutoff in cutoffs.items()
        }
        if parsed:
            definitions[(atom1, atom2)] = parsed
    return normalize_pair_cutoffs(definitions)


def _coerce_box_dimensions(
    value: object,
) -> tuple[float, float, float] | None:
    """Parse ``box_dimensions`` into a 3-tuple of floats, or ``None``.

    Raises
    ------
    ValueError
        If the value is not a list/tuple, contains a component that cannot
        be converted to ``float``, or does not have exactly three components.
    """
    if value is None:
        return None
    if not isinstance(value, (list, tuple)):
        raise ValueError("box_dimensions must be a list of three numbers.")
    try:
        box = tuple(float(component) for component in value)
    except (TypeError, ValueError) as exc:
        # ``float(None)`` raises TypeError; report every malformed component
        # with the same ValueError the other validation branches use.
        raise ValueError(
            "box_dimensions must be a list of three numbers."
        ) from exc
    if len(box) != 3:
        raise ValueError("box_dimensions must contain exactly three numbers.")
    return box


def _coerce_int_tuple(value: object) -> tuple[int, ...]:
    """Return the sorted, de-duplicated ints in ``value``.

    Anything other than a list or tuple yields an empty tuple.
    """
    if not isinstance(value, (list, tuple)):
        return ()
    return tuple(sorted({int(entry) for entry in value}))


def _optional_text(value: object) -> str | None:
    """Return ``value`` as stripped text, or ``None`` when null or blank."""
    if value is None:
        return None
    text = str(value).strip()
    return text or None


def _optional_float(value: object) -> float | None:
    """Parse an optional positive float.

    ``None``, blank text, and values that are not greater than zero all
    yield ``None``. Unparseable text raises ``ValueError`` from ``float``.
    """
    if value is None:
        return None
    text = str(value).strip()
    if not text:
        return None
    result = float(text)
    return result if result > 0.0 else None


def _optional_int(value: object, default: int) -> int:
    """Parse an optional integer, falling back to ``default``.

    ``None`` and blank text yield ``int(default)``. Floats with an integral
    value (for example a JSON ``1000.0``) are accepted; any other
    non-integer input raises ``ValueError``.
    """
    if value is None:
        return int(default)
    # ``int("1000.0")`` raises, so integral floats are converted directly.
    # Non-integral and non-finite floats fall through and raise below.
    if isinstance(value, float) and value.is_integer():
        return int(value)
    text = str(value).strip()
    if not text:
        return int(default)
    return int(text)


def _emit_log(callback: ClusterRunLogCallback | None, message: str) -> None:
    """Send the stripped ``message`` to ``callback`` when one is provided."""
    if callback is not None:
        callback(str(message).strip())
def _dialog_start_dir(*candidates: str | Path | None) -> str:
    """Pick the directory a file dialog should open in.

    Candidates are tried in order. The first one that names an existing
    file yields that file's parent; the first that names an existing
    directory yields the directory itself. ``None`` and blank strings are
    skipped. When nothing matches, the user's home directory is returned.
    """
    for candidate in candidates:
        if candidate is None:
            continue
        # ``Path("")`` is ``Path(".")``, which *is* a directory. Without this
        # guard an empty text field would select the working directory and
        # stop later candidates from being considered.
        if isinstance(candidate, str) and not candidate.strip():
            continue
        path = Path(candidate).expanduser()
        if path.is_file():
            return str(path.parent)
        if path.is_dir():
            return str(path)
    return str(Path.home())
def _source_kind_for_project_settings(settings: object) -> str:
    """Return ``"pdb"`` when the settings expose a PDB frames dir, else ``"xyz"``.

    Only the ``resolved_pdb_frames_dir`` attribute is consulted; a missing
    attribute is treated the same as ``None``.
    """
    has_pdb_frames = (
        getattr(settings, "resolved_pdb_frames_dir", None) is not None
    )
    return "pdb" if has_pdb_frames else "xyz"


def _frames_dir_for_project_settings(settings: object) -> Path | None:
    """Return the preferred frames directory from project settings.

    ``resolved_pdb_frames_dir`` wins when it is truthy; otherwise
    ``resolved_frames_dir`` is returned (``None`` if that is missing too).
    """
    pdb_frames_dir = getattr(settings, "resolved_pdb_frames_dir", None)
    if pdb_frames_dir:
        return pdb_frames_dir
    return getattr(settings, "resolved_frames_dir", None)
def _queue_item_from_project_defaults(
    project_dir: str | Path,
    *,
    item_id: str | None = None,
) -> ClusterBatchItem:
    """Create a queue item for ``project_dir``, prefilled from its settings.

    The item always carries the resolved project folder and an id (the
    given ``item_id`` or a freshly generated one). When the project can be
    loaded, the frames folder, its source kind, and a suggested output
    folder are filled in as well. If loading fails for any reason the
    bare item is returned instead.
    """
    root = Path(project_dir).expanduser().resolve()
    identifier = item_id or _new_item_id()
    baseline = ClusterBatchItem(item_id=identifier, project_dir=root)

    try:
        project_settings = SAXSProjectManager().load_project(root)
    except Exception:
        # Best effort: a project that cannot be read still gets a queue
        # entry the user can fill in manually.
        return baseline

    detected_frames = _frames_dir_for_project_settings(project_settings)
    source_kind = _source_kind_for_project_settings(project_settings)
    if detected_frames is None:
        suggested_output = None
    else:
        suggested_output = suggest_cluster_output_dir(detected_frames)

    return replace(
        baseline,
        frames_dir=detected_frames,
        frames_source_kind=source_kind,
        output_dir=suggested_output,
    )
output_dir=_optional_path(self.output_dir_edit.text()), + atom_type_definitions=_copy_atom_type_definitions( + self.definitions_panel.atom_type_definitions() + ), + pair_cutoff_definitions=_copy_pair_cutoff_definitions( + self.definitions_panel.pair_cutoff_definitions() + ), + box_dimensions=self.definitions_panel.box_dimensions(), + use_pbc=self.definitions_panel.use_pbc(), + search_mode=self.definitions_panel.search_mode(), + save_state_frequency=self.definitions_panel.save_state_frequency(), + default_cutoff=self.definitions_panel.default_cutoff(), + shell_levels=self.definitions_panel.shell_growth_levels(), + include_shell_levels=self.definitions_panel.include_shell_levels(), + shared_shells=self.definitions_panel.shared_shells(), + smart_solvation_shells=( + self.definitions_panel.smart_solvation_shells() + ), + include_shell_atoms_in_stoichiometry=( + self.definitions_panel.include_shell_atoms_in_stoichiometry() + ), + ) + self._refresh_header() + self._refresh_project_reference() + return self._item + + def job(self) -> ClusterBatchJob: + self.collect_item() + project_dir = _required_project_dir(self.project_dir_edit.text()) + frames_dir = _required_frames_dir(self.frames_dir_edit.text()) + output_dir = _optional_path( + self.output_dir_edit.text() + ) or suggest_cluster_output_dir(frames_dir) + atom_type_definitions = self.definitions_panel.atom_type_definitions() + if not atom_type_definitions: + raise ValueError( + "Add at least one atom-type definition before exporting." + ) + if not ( + atom_type_definitions.get("node") + or atom_type_definitions.get("linker") + ): + raise ValueError("Define at least one node or linker atom type.") + pair_cutoffs = self.definitions_panel.pair_cutoff_definitions() + default_cutoff = self.definitions_panel.default_cutoff() + if not pair_cutoffs and default_cutoff is None: + raise ValueError( + "Add at least one pair-cutoff definition or specify a " + "default cutoff." 
+ ) + + summary = self._last_summary + if summary is None: + summary = self._inspect_frames(frames_dir) + frame_format = str(summary.get("frame_format", "")) + box_dimensions = self.definitions_panel.box_dimensions() + if self.definitions_panel.use_pbc() and box_dimensions is None: + box_dimensions = _summary_box_dimensions(summary) + if box_dimensions is None: + raise ValueError( + "Periodic boundary conditions are enabled, but no box " + "dimensions are available." + ) + + config = ClusterJobConfig( + frames_dir=frames_dir, + atom_type_definitions=atom_type_definitions, + pair_cutoff_definitions=pair_cutoffs, + box_dimensions=box_dimensions, + use_pbc=self.definitions_panel.use_pbc(), + search_mode=self.definitions_panel.search_mode(), + save_state_frequency=self.definitions_panel.save_state_frequency(), + default_cutoff=default_cutoff, + shell_levels=self.definitions_panel.shell_growth_levels(), + include_shell_levels=self.definitions_panel.include_shell_levels(), + shared_shells=self.definitions_panel.shared_shells(), + smart_solvation_shells=( + frame_format == "pdb" + and self.definitions_panel.smart_solvation_shells() + ), + include_shell_atoms_in_stoichiometry=( + self.definitions_panel.include_shell_atoms_in_stoichiometry() + ), + output_dir=output_dir, + ) + return ClusterBatchJob( + project_dir=project_dir, + frames_dir=frames_dir, + frames_source_kind=self._item.frames_source_kind, + config=config, + ) + + def set_locked(self, locked: bool) -> None: + self.settings_group.setEnabled(not locked) + self.inspect_button.setEnabled(not locked) + self.duplicate_button.setEnabled(not locked) + self.remove_button.setEnabled(not locked) + + def set_status(self, message: str) -> None: + self.status_label.setText(message) + + def set_progress(self, processed: int, total: int) -> None: + self.progress_bar.setRange(0, max(int(total), 1)) + self.progress_bar.setValue(max(int(processed), 0)) + + def set_selected(self, selected: bool) -> None: + self._selected = 
bool(selected) + self.header_frame.setProperty("selected", self._selected) + self.header_frame.setStyleSheet( + "QFrame#ClusterBatchItemHeader {" + + ( + "background-color: #dce8f7; " "border: 1px solid #8fb0d7;" + if self._selected + else "background-color: #f6f8fb; " "border: 1px solid #cfd7e3;" + ) + + "border-radius: 5px;}" + ) + + def analyze_input(self) -> None: + frames_dir = _required_frames_dir(self.frames_dir_edit.text()) + summary = self._inspect_frames(frames_dir) + self._apply_summary(summary) + self.set_progress(0, max(int(summary.get("n_frames", 1)), 1)) + self.set_status("Input inspected") + + def _build_ui(self) -> None: + self.setFrameShape(QFrame.Shape.StyledPanel) + self.setSizePolicy( + QSizePolicy.Policy.Expanding, + QSizePolicy.Policy.Fixed, + ) + root = QVBoxLayout(self) + root.setContentsMargins(8, 8, 8, 8) + root.setSpacing(8) + + self.header_frame = QFrame() + self.header_frame.setObjectName("ClusterBatchItemHeader") + header = QHBoxLayout(self.header_frame) + header.setContentsMargins(8, 6, 8, 6) + header.setSpacing(8) + self.toggle_button = QToolButton() + self.toggle_button.setCheckable(True) + self.toggle_button.toggled.connect(self._set_settings_visible) + header.addWidget(self.toggle_button) + self.title_label = QLabel("New cluster extraction") + self.title_label.setStyleSheet("font-weight: 600;") + header.addWidget(self.title_label, stretch=1) + self.status_label = QLabel("Ready") + self.status_label.setMinimumWidth(180) + header.addWidget(self.status_label) + self.inspect_button = QPushButton("Inspect") + self.inspect_button.clicked.connect(self._inspect_from_button) + header.addWidget(self.inspect_button) + self.duplicate_button = QPushButton("Duplicate") + self.duplicate_button.clicked.connect( + lambda: self.duplicate_requested.emit(self.item_id) + ) + header.addWidget(self.duplicate_button) + self.remove_button = QPushButton("Remove") + self.remove_button.clicked.connect( + lambda: self.remove_requested.emit(self.item_id) + ) 
+ header.addWidget(self.remove_button) + root.addWidget(self.header_frame) + self.set_selected(False) + + self.progress_bar = QProgressBar() + self.progress_bar.setRange(0, 1) + self.progress_bar.setValue(0) + self.progress_bar.setFormat("%v / %m frames") + root.addWidget(self.progress_bar) + + self.settings_group = QGroupBox("Cluster Extraction Settings") + root.addWidget(self.settings_group) + settings_layout = QVBoxLayout(self.settings_group) + + form = QFormLayout() + self.project_dir_edit = QLineEdit() + self.project_dir_edit.editingFinished.connect(self._on_project_changed) + form.addRow( + "Project folder", + self._path_row(self.project_dir_edit, self._choose_project_dir), + ) + self.project_reference_label = QLabel() + self.project_reference_label.setWordWrap(True) + self.project_reference_label.setFrameShape(QFrame.Shape.StyledPanel) + form.addRow("", self.project_reference_label) + + self.frames_dir_edit = QLineEdit() + self.frames_dir_edit.editingFinished.connect(self._on_frames_changed) + form.addRow( + "Frames folder", + self._path_row(self.frames_dir_edit, self._choose_frames_dir), + ) + + self.output_dir_edit = QLineEdit() + self.output_dir_edit.editingFinished.connect(self._on_editor_changed) + form.addRow( + "Output folder", + self._path_row(self.output_dir_edit, self._choose_output_dir), + ) + settings_layout.addLayout(form) + + self.summary_box = QTextEdit() + self.summary_box.setReadOnly(True) + self.summary_box.setMinimumHeight(120) + self.summary_box.setPlainText( + "Inspect the frames folder to detect PDB/XYZ mode and box " + "settings." 
+ ) + settings_layout.addWidget(self.summary_box) + + self.definitions_panel = ClusterDefinitionsPanel() + self.definitions_panel.load_preset( + DEFAULT_CLUSTER_EXTRACTION_PRESET_NAME + ) + self.definitions_panel.settings_changed.connect( + self._on_editor_changed + ) + settings_layout.addWidget(self.definitions_panel) + + def _path_row(self, edit: QLineEdit, slot) -> QWidget: + row_widget = QWidget() + row = QHBoxLayout(row_widget) + row.setContentsMargins(0, 0, 0, 0) + row.addWidget(edit, stretch=1) + button = QPushButton("Browse...") + button.clicked.connect(slot) + row.addWidget(button) + return row_widget + + def _load_item(self, item: ClusterBatchItem) -> None: + self._loading = True + self.project_dir_edit.setText( + "" if item.project_dir is None else str(item.project_dir) + ) + self.frames_dir_edit.setText( + "" if item.frames_dir is None else str(item.frames_dir) + ) + self.output_dir_edit.setText( + "" if item.output_dir is None else str(item.output_dir) + ) + if item.atom_type_definitions or item.pair_cutoff_definitions: + self.definitions_panel.load_atom_type_definitions( + item.atom_type_definitions, + emit_signal=False, + ) + self.definitions_panel.load_pair_cutoff_definitions( + item.pair_cutoff_definitions, + emit_signal=False, + ) + self.definitions_panel.set_use_pbc(item.use_pbc, emit_signal=False) + self.definitions_panel.set_search_mode( + item.search_mode, + emit_signal=False, + ) + self.definitions_panel.set_save_state_frequency( + item.save_state_frequency, + emit_signal=False, + ) + self.definitions_panel.set_default_cutoff( + item.default_cutoff, + emit_signal=False, + ) + self.definitions_panel.set_shell_growth_levels( + item.shell_levels, + emit_signal=False, + ) + self.definitions_panel.set_shared_shells( + item.shared_shells, + emit_signal=False, + ) + self.definitions_panel.set_smart_solvation_shells( + item.smart_solvation_shells, + emit_signal=False, + ) + self.definitions_panel.set_include_shell_atoms_in_stoichiometry( + 
item.include_shell_atoms_in_stoichiometry, + emit_signal=False, + ) + self.definitions_panel.set_box_dimensions( + item.box_dimensions, + emit_signal=False, + ) + self._loading = False + self._refresh_header() + self._refresh_project_reference() + self._analyze_quietly() + + def _set_settings_visible(self, visible: bool) -> None: + self.settings_group.setVisible(bool(visible)) + self.toggle_button.setChecked(bool(visible)) + self.toggle_button.setText("Hide Settings" if visible else "Settings") + parent_item = self._list_item() + if parent_item is not None: + parent_item.setSizeHint(self.sizeHint()) + + def _list_item(self) -> QListWidgetItem | None: + parent = self.parent() + while parent is not None and not isinstance(parent, QListWidget): + parent = parent.parent() + if not isinstance(parent, QListWidget): + return None + for row in range(parent.count()): + list_item = parent.item(row) + if parent.itemWidget(list_item) is self: + return list_item + return None + + def _choose_project_dir(self) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select SAXSShell project folder", + _dialog_start_dir(self.project_dir_edit.text()), + ) + if not selected: + return + self._load_item( + replace( + _queue_item_from_project_defaults( + selected, + item_id=self.item_id, + ), + atom_type_definitions=self.definitions_panel.atom_type_definitions(), + pair_cutoff_definitions=( + self.definitions_panel.pair_cutoff_definitions() + ), + box_dimensions=self.definitions_panel.box_dimensions(), + use_pbc=self.definitions_panel.use_pbc(), + search_mode=self.definitions_panel.search_mode(), + save_state_frequency=self.definitions_panel.save_state_frequency(), + default_cutoff=self.definitions_panel.default_cutoff(), + shell_levels=self.definitions_panel.shell_growth_levels(), + include_shell_levels=self.definitions_panel.include_shell_levels(), + shared_shells=self.definitions_panel.shared_shells(), + smart_solvation_shells=( + 
self.definitions_panel.smart_solvation_shells() + ), + include_shell_atoms_in_stoichiometry=( + self.definitions_panel.include_shell_atoms_in_stoichiometry() + ), + ) + ) + self._on_editor_changed() + + def _choose_frames_dir(self) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select extracted PDB or XYZ frames folder", + _dialog_start_dir( + self.frames_dir_edit.text(), + self.project_dir_edit.text(), + ), + ) + if not selected: + return + self.frames_dir_edit.setText(selected) + self._on_frames_changed() + + def _choose_output_dir(self) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select cluster output folder", + _dialog_start_dir( + self.output_dir_edit.text(), + self.frames_dir_edit.text(), + ), + ) + if not selected: + return + self.output_dir_edit.setText(selected) + self._on_editor_changed() + + def _on_project_changed(self) -> None: + project_dir = _optional_path(self.project_dir_edit.text()) + if project_dir is None: + self._on_editor_changed() + return + try: + item = _queue_item_from_project_defaults( + project_dir, + item_id=self.item_id, + ) + except Exception: + self._on_editor_changed() + return + if item.frames_dir is not None: + self.frames_dir_edit.setText(str(item.frames_dir)) + if item.output_dir is not None: + self.output_dir_edit.setText(str(item.output_dir)) + self._item = replace( + self._item, + project_dir=project_dir, + frames_dir=item.frames_dir, + frames_source_kind=item.frames_source_kind, + output_dir=item.output_dir, + ) + self._analyze_quietly() + self._on_editor_changed() + + def _on_frames_changed(self) -> None: + frames_dir = _optional_path(self.frames_dir_edit.text()) + if frames_dir is not None: + suggested = suggest_cluster_output_dir(frames_dir) + current = _optional_path(self.output_dir_edit.text()) + if current is None or current == self._last_suggested_output_dir: + self.output_dir_edit.setText(str(suggested)) + self._last_suggested_output_dir = suggested + self._analyze_quietly() 
+ self._on_editor_changed() + + def _inspect_from_button(self) -> None: + try: + self.analyze_input() + self._on_editor_changed() + except Exception as exc: + QMessageBox.warning( + self, + "Unable to inspect frames folder", + str(exc), + ) + self.summary_box.setPlainText(str(exc)) + self.set_status("Inspection failed") + self._on_editor_changed() + + def _analyze_quietly(self) -> None: + if not self.frames_dir_edit.text().strip(): + return + try: + self.analyze_input() + except Exception as exc: + self.summary_box.setPlainText(str(exc)) + self.set_status("Inspection failed") + + def _inspect_frames(self, frames_dir: Path) -> dict[str, object]: + analyzer = ExtractedFrameFolderClusterAnalyzer( + frames_dir=frames_dir, + atom_type_definitions={}, + pair_cutoffs_def={}, + ) + return analyzer.inspect() + + def _apply_summary(self, summary: dict[str, object]) -> None: + self._last_summary = summary + frame_format = str(summary.get("frame_format", "") or "") + self.definitions_panel.set_frame_mode(frame_format) + if summary.get("box_dimensions_source_kind") == "source_filename": + box_dimensions = _summary_box_dimensions(summary) + if box_dimensions is not None: + self.definitions_panel.set_box_dimensions( + box_dimensions, + emit_signal=False, + ) + self.summary_box.setPlainText(_summary_text(summary)) + self.set_status( + f"{frame_folder_label(frame_format)} mode, " + f"{int(summary.get('n_frames', 0))} frame(s)" + ) + + def _on_editor_changed(self, *_args) -> None: + if self._loading: + return + try: + self.collect_item() + if self.status_label.text() in {"Inspection failed", "Failed"}: + self.set_status("Ready") + except Exception: + self._refresh_header() + self._refresh_project_reference() + self.settings_changed.emit(self.item_id) + + def _refresh_header(self) -> None: + self.title_label.setText(self._item.display_name()) + + def _refresh_project_reference(self) -> None: + project_dir = _optional_path(self.project_dir_edit.text()) + if project_dir is None: + text 
= "Project reference: choose a SAXSShell project folder." + else: + project_file = build_project_paths(project_dir).project_file + if project_file.is_file(): + text = f"Project reference: {project_file}" + else: + text = f"Project reference: no project file found at {project_file}" + self.project_reference_label.setText(text) + + +class ClusterBatchWorker(QObject): + item_started = Signal(str, int, int) + item_progress = Signal(str, int, int, str) + item_phase_changed = Signal(str, str) + item_finished = Signal(str, object) + item_failed = Signal(str, str) + log = Signal(str) + status = Signal(str) + finished = Signal(object) + failed = Signal(str, str) + + def __init__( + self, + queue_entries: list[tuple[str, ClusterBatchJob]], + ) -> None: + super().__init__() + self.queue_entries = list(queue_entries) + self._cancel_requested = threading.Event() + self._project_manager = SAXSProjectManager() + + def request_cancel(self) -> None: + self._cancel_requested.set() + + @Slot() + def run(self) -> None: + results: list[ClusterBatchResult] = [] + total_items = len(self.queue_entries) + for index, (item_id, job) in enumerate( + self.queue_entries, + start=1, + ): + if self._cancel_requested.is_set(): + self.log.emit("Batch queue stopped before the next project.") + break + self.item_started.emit(item_id, index, total_items) + self.status.emit( + f"Running {index}/{total_items}: {job.project_dir.name}" + ) + self.log.emit(f"Starting {index}/{total_items}: {job.project_dir}") + try: + result = self._run_job(item_id, job) + except Exception as exc: + message = str(exc) + self.item_failed.emit(item_id, message) + self.failed.emit(item_id, message) + return + results.append(result) + self.item_finished.emit(item_id, result) + self.status.emit("Cluster extraction batch queue finished") + self.finished.emit(results) + + def _run_job( + self, + item_id: str, + job: ClusterBatchJob, + ) -> ClusterBatchResult: + worker = ClusterExportWorker(job.config) + results: 
list[ClusterExportResult] = [] + failures: list[str] = [] + worker.progress.connect( + lambda message: self.log.emit( + f"[{job.project_dir.name}] {message}" + ) + ) + worker.phase_changed.connect( + lambda phase: self.item_phase_changed.emit(item_id, phase) + ) + worker.progress_count.connect( + lambda processed, total: self.item_progress.emit( + item_id, + processed, + total, + f"{processed}/{max(total, 1)} frame(s)", + ) + ) + worker.finished.connect(results.append) + worker.failed.connect(failures.append) + worker.run() + if failures: + raise RuntimeError(failures[0]) + if not results: + raise RuntimeError("Cluster extraction did not return a result.") + export_result = results[0] + settings = self._project_manager.load_project(job.project_dir) + settings.clusters_dir = str( + export_result.output_dir.expanduser().resolve() + ) + self._project_manager.save_project(settings) + self.log.emit( + f"[{job.project_dir.name}] Registered clusters folder: " + f"{settings.clusters_dir}" + ) + return ClusterBatchResult( + project_dir=job.project_dir, + frames_dir=job.frames_dir, + frames_source_kind=job.frames_source_kind, + output_dir=export_result.output_dir.expanduser().resolve(), + analyzed_frames=export_result.analyzed_frames, + total_clusters=export_result.total_clusters, + written_count=len(export_result.written_files), + ) + + +class ClusterBatchQueueWindow(QMainWindow): + """Queue cluster extractions for multiple projects.""" + + project_paths_registered = Signal(object) + + def __init__( + self, + initial_project_dir: str | Path | None = None, + *, + initial_frames_dir: str | Path | None = None, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._widgets_by_id: dict[str, ClusterBatchItemWidget] = {} + self._run_thread: QThread | None = None + self._run_worker: ClusterBatchWorker | None = None + self._initial_project_dir = ( + None + if initial_project_dir is None + else Path(initial_project_dir).expanduser().resolve() + ) + 
self._initial_frames_dir = ( + None + if initial_frames_dir is None + else Path(initial_frames_dir).expanduser().resolve() + ) + self._build_ui() + if ( + self._initial_project_dir is not None + or self._initial_frames_dir is not None + ): + self._add_current_project() + + def closeEvent(self, event) -> None: + if self._run_thread is not None and self._run_thread.isRunning(): + self._request_cancel() + self.hide() + while ( + self._run_thread is not None and self._run_thread.isRunning() + ): + QApplication.processEvents() + if self._run_thread is not None: + self._run_thread.wait(50) + event.accept() + return + super().closeEvent(event) + + def add_queue_item( + self, + item: ClusterBatchItem | None = None, + ) -> ClusterBatchItemWidget: + resolved_item = item or ClusterBatchItem(item_id=_new_item_id()) + list_item = QListWidgetItem() + list_item.setData(Qt.ItemDataRole.UserRole, resolved_item.item_id) + self.queue_list.addItem(list_item) + widget = ClusterBatchItemWidget( + resolved_item, + parent=self.queue_list, + ) + widget.settings_changed.connect(self._on_item_settings_changed) + widget.remove_requested.connect(self._remove_item) + widget.duplicate_requested.connect(self._duplicate_item) + self._widgets_by_id[resolved_item.item_id] = widget + list_item.setSizeHint(widget.sizeHint()) + self.queue_list.setItemWidget(list_item, widget) + self.queue_list.setCurrentItem(list_item) + self._refresh_order_labels() + return widget + + def queue_jobs_in_order(self) -> list[tuple[str, ClusterBatchJob]]: + entries: list[tuple[str, ClusterBatchJob]] = [] + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + item_id = str(list_item.data(Qt.ItemDataRole.UserRole)) + widget = self._widgets_by_id[item_id] + entries.append((item_id, widget.job())) + return entries + + def _build_ui(self) -> None: + self.setWindowTitle("SAXSShell Cluster Extraction Batch Queue") + self.setWindowIcon(load_saxshell_icon()) + self.resize(1180, 880) + + central = 
QWidget() + root = QVBoxLayout(central) + root.setContentsMargins(8, 8, 8, 8) + root.setSpacing(8) + + controls = QHBoxLayout() + self.add_current_button = QPushButton("Add Current Project") + self.add_current_button.clicked.connect(self._add_current_project) + controls.addWidget(self.add_current_button) + self.add_project_button = QPushButton("Add Projects...") + self.add_project_button.clicked.connect(self._choose_projects_to_add) + controls.addWidget(self.add_project_button) + controls.addStretch(1) + root.addLayout(controls) + + self.queue_list = QListWidget() + self.queue_list.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + self.queue_list.setDragDropMode( + QAbstractItemView.DragDropMode.InternalMove + ) + self.queue_list.setDefaultDropAction(Qt.DropAction.MoveAction) + self.queue_list.setAlternatingRowColors(True) + self.queue_list.setStyleSheet( + "QListWidget::item:selected { background: transparent; }" + "QListWidget::item:hover { background: transparent; }" + "QListWidget::item { margin: 3px; }" + ) + self.queue_list.model().rowsMoved.connect(self._refresh_order_labels) + self.queue_list.itemSelectionChanged.connect( + self._refresh_item_selection_styles + ) + root.addWidget(self.queue_list, stretch=1) + + run_group = QGroupBox("Execute Queue") + run_layout = QVBoxLayout(run_group) + run_buttons = QHBoxLayout() + self.run_button = QPushButton("Run Complete Queue") + self.run_button.clicked.connect(self._start_queue) + run_buttons.addWidget(self.run_button) + self.cancel_button = QPushButton("Stop Queue") + self.cancel_button.setEnabled(False) + self.cancel_button.clicked.connect(self._request_cancel) + run_buttons.addWidget(self.cancel_button) + run_buttons.addStretch(1) + run_layout.addLayout(run_buttons) + self.queue_status_label = QLabel("Queue idle") + run_layout.addWidget(self.queue_status_label) + self.console = QTextEdit() + self.console.setReadOnly(True) + self.console.setMinimumHeight(160) + 
run_layout.addWidget(self.console) + root.addWidget(run_group) + + self.setCentralWidget(central) + self.statusBar().showMessage("Ready") + + def _add_current_project(self) -> None: + if ( + self._initial_project_dir is None + and self._initial_frames_dir is None + ): + QMessageBox.information( + self, + "No active project", + "The main UI did not provide an active project reference.", + ) + return + item = ( + _queue_item_from_project_defaults(self._initial_project_dir) + if self._initial_project_dir is not None + else ClusterBatchItem(item_id=_new_item_id()) + ) + if self._initial_frames_dir is not None: + item = replace( + item, + frames_dir=self._initial_frames_dir, + output_dir=suggest_cluster_output_dir( + self._initial_frames_dir + ), + ) + self.add_queue_item(item) + + def _choose_projects_to_add(self) -> None: + selected_dirs = _choose_existing_directories( + self, + title="Select SAXSShell project folders", + start_dir=self._initial_project_dir or Path.home(), + ) + if not selected_dirs: + return + for project_dir in selected_dirs: + self.add_queue_item(_queue_item_from_project_defaults(project_dir)) + + def _on_item_settings_changed(self, _item_id: str) -> None: + self._refresh_order_labels() + + def _refresh_order_labels(self, *_args) -> None: + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + item_id = str(list_item.data(Qt.ItemDataRole.UserRole)) + widget = self._widgets_by_id.get(item_id) + if widget is None: + continue + widget.title_label.setText( + f"{row + 1}. 
{widget.item().display_name()}" + ) + list_item.setSizeHint(widget.sizeHint()) + self._refresh_item_selection_styles() + + def _refresh_item_selection_styles(self) -> None: + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + item_id = str(list_item.data(Qt.ItemDataRole.UserRole)) + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_selected(list_item.isSelected()) + + def _remove_item(self, item_id: str) -> None: + if self._run_thread is not None and self._run_thread.isRunning(): + return + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + if str(list_item.data(Qt.ItemDataRole.UserRole)) == item_id: + self.queue_list.takeItem(row) + break + self._widgets_by_id.pop(item_id, None) + self._refresh_order_labels() + + def _duplicate_item(self, item_id: str) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is None: + return + try: + item = widget.collect_item() + except Exception: + item = widget.item() + self.add_queue_item(replace(item, item_id=_new_item_id())) + + def _set_running(self, running: bool) -> None: + self.add_current_button.setEnabled(not running) + self.add_project_button.setEnabled(not running) + self.run_button.setEnabled(not running) + self.cancel_button.setEnabled(running) + self.queue_list.setDragEnabled(not running) + self.queue_list.setAcceptDrops(not running) + for widget in self._widgets_by_id.values(): + widget.set_locked(running) + + def _start_queue(self) -> None: + if self.queue_list.count() == 0: + QMessageBox.information( + self, + "Cluster extraction batch queue", + "Add at least one project before running the queue.", + ) + return + try: + entries = self.queue_jobs_in_order() + except Exception as exc: + QMessageBox.warning( + self, + "Invalid cluster extraction batch settings", + str(exc), + ) + return + + self.console.clear() + self._set_running(True) + self.queue_status_label.setText( + f"Running 0/{len(entries)} queued 
extraction(s)" + ) + for widget in self._widgets_by_id.values(): + widget.set_progress(0, 1) + widget.set_status("Queued") + + self._run_thread = QThread(self) + self._run_worker = ClusterBatchWorker(entries) + self._run_worker.moveToThread(self._run_thread) + self._run_thread.started.connect(self._run_worker.run) + self._run_worker.item_started.connect(self._on_item_started) + self._run_worker.item_progress.connect(self._on_item_progress) + self._run_worker.item_phase_changed.connect( + self._on_item_phase_changed + ) + self._run_worker.item_finished.connect(self._on_item_finished) + self._run_worker.item_failed.connect(self._on_item_failed) + self._run_worker.log.connect(self._append_log) + self._run_worker.status.connect(self._on_status) + self._run_worker.finished.connect(self._on_queue_finished) + self._run_worker.failed.connect(self._on_queue_failed) + self._run_worker.finished.connect(self._run_thread.quit) + self._run_worker.failed.connect(self._run_thread.quit) + self._run_thread.finished.connect(self._cleanup_run_thread) + self._run_thread.finished.connect(self._run_thread.deleteLater) + self._run_thread.start() + + def _request_cancel(self) -> None: + self.cancel_button.setEnabled(False) + self.queue_status_label.setText( + "Stopping queue after the active project finishes" + ) + self._append_log( + "Stop requested; the current project will finish before the " + "queue exits." 
+ ) + if self._run_worker is not None: + self._run_worker.request_cancel() + + def _append_log(self, message: str) -> None: + self.console.append(message) + + def _on_status(self, message: str) -> None: + self.statusBar().showMessage(message) + self.queue_status_label.setText(message) + + def _on_item_started( + self, + item_id: str, + index: int, + total: int, + ) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_status(f"Running {index}/{total}") + widget.set_progress(0, 1) + self.queue_status_label.setText( + f"Running {index}/{total} queued extraction(s)" + ) + + def _on_item_progress( + self, + item_id: str, + processed: int, + total: int, + message: str, + ) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_progress(processed, total) + widget.set_status(message) + + def _on_item_phase_changed(self, item_id: str, phase: str) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_status( + "Sorting clusters" if phase == "sorting" else "Extracting" + ) + + def _on_item_finished( + self, + item_id: str, + result: ClusterBatchResult, + ) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is None: + return + widget.set_progress( + result.analyzed_frames, + max(result.analyzed_frames, 1), + ) + widget.set_status("Complete") + self.project_paths_registered.emit( + { + "project_dir": result.project_dir, + "clusters_dir": result.output_dir, + } + ) + + def _on_item_failed(self, item_id: str, message: str) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_status("Failed") + self._append_log(message) + + def _on_queue_finished(self, results: object) -> None: + self._set_running(False) + result_count = len(results) if isinstance(results, list) else 0 + self.queue_status_label.setText( + f"Queue finished: {result_count} extraction(s) saved" + ) + self.statusBar().showMessage("Cluster extraction batch 
queue finished") + + def _on_queue_failed(self, item_id: str, message: str) -> None: + self._set_running(False) + self.queue_status_label.setText("Queue stopped after a failure") + self.statusBar().showMessage( + "Cluster extraction batch queue failed", + 5000, + ) + QMessageBox.warning( + self, + "Cluster extraction batch queue failed", + f"Queue item {item_id} failed:\n{message}", + ) + + def _cleanup_run_thread(self) -> None: + self._run_thread = None + self._run_worker = None + + +def launch_cluster_batch_queue_ui( + initial_project_dir: str | Path | None = None, + *, + initial_frames_dir: str | Path | None = None, +) -> int: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication([]) + configure_saxshell_application(app) + window = ClusterBatchQueueWindow( + initial_project_dir=initial_project_dir, + initial_frames_dir=initial_frames_dir, + ) + window.show() + return int(app.exec()) + + +__all__ = [ + "ClusterBatchItem", + "ClusterBatchItemWidget", + "ClusterBatchJob", + "ClusterBatchQueueWindow", + "ClusterBatchResult", + "ClusterBatchWorker", + "launch_cluster_batch_queue_ui", +] diff --git a/src/saxshell/cluster/ui/run_file_window.py b/src/saxshell/cluster/ui/run_file_window.py new file mode 100644 index 0000000..95001d7 --- /dev/null +++ b/src/saxshell/cluster/ui/run_file_window.py @@ -0,0 +1,482 @@ +from __future__ import annotations + +import json +import sys +from pathlib import Path + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import ( + QApplication, + QFileDialog, + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QMainWindow, + QMessageBox, + QPlainTextEdit, + QPushButton, + QScrollArea, + QSplitter, + QVBoxLayout, + QWidget, +) + +from saxshell.cluster import ( + DEFAULT_CLUSTER_EXTRACTION_PRESET_NAME, + ClusterWorkflow, + format_box_dimensions, +) +from saxshell.cluster.run_config import ( + build_cluster_run_config, + default_cluster_run_file_path, + 
preview_cluster_run_config, + save_cluster_run_config, + suggest_run_config_output_dir, +) +from saxshell.cluster.ui.definitions_panel import ClusterDefinitionsPanel + + +class ClusterRunFileWindow(QMainWindow): + def __init__( + self, + *, + initial_project_dir: str | Path | None = None, + initial_frames_dir: str | Path | None = None, + ) -> None: + super().__init__() + self._browse_start_dir = Path.home() + self._last_suggested_output_dir: str | None = None + self._last_summary: dict[str, object] | None = None + + project_dir = ( + None + if initial_project_dir is None + else Path(initial_project_dir).expanduser().resolve() + ) + frames_dir = ( + None + if initial_frames_dir is None + else Path(initial_frames_dir).expanduser().resolve() + ) + if project_dir is not None: + self._browse_start_dir = project_dir + if frames_dir is None: + frames_dir = self._project_frames_dir(project_dir) + + self.setWindowTitle("Cluster Extraction CLI Setup (Beta)") + self.setWindowIcon(_load_saxshell_icon()) + self.resize(1100, 780) + self._build_ui() + self.definitions_panel.load_preset( + DEFAULT_CLUSTER_EXTRACTION_PRESET_NAME + ) + + if project_dir is not None: + self.project_dir_edit.setText(str(project_dir)) + self._refresh_run_file_path() + if frames_dir is not None and frames_dir.is_dir(): + self.frames_dir_edit.setText(str(frames_dir)) + self._browse_start_dir = frames_dir + self._inspect_frames() + self._update_preview() + + def _build_ui(self) -> None: + central = QWidget(self) + root = QVBoxLayout(central) + root.setContentsMargins(10, 10, 10, 10) + root.setSpacing(8) + self.setCentralWidget(central) + + splitter = QSplitter(Qt.Orientation.Horizontal, self) + splitter.setChildrenCollapsible(False) + root.addWidget(splitter, stretch=1) + + left_scroll = QScrollArea(self) + left_scroll.setWidgetResizable(True) + left_panel = QWidget() + self.left_layout = QVBoxLayout(left_panel) + self.left_layout.setContentsMargins(10, 10, 10, 10) + self.left_layout.setSpacing(10) + 
def _build_project_group(self) -> QGroupBox:
    """Create the "Project" group: folder picker row plus run-file field.

    Stores the editable folder field as ``self.project_dir_edit`` and the
    read-only run-file field as ``self.run_file_edit``.
    """
    project_group = QGroupBox("Project")
    form_layout = QFormLayout(project_group)

    # Row 1: editable project folder with a browse button beside it.
    folder_edit = QLineEdit()
    folder_edit.editingFinished.connect(self._on_project_dir_changed)
    self.project_dir_edit = folder_edit

    pick_button = QPushButton("Browse...")
    pick_button.clicked.connect(self._browse_project_dir)

    row_layout = QHBoxLayout()
    row_layout.addWidget(folder_edit, stretch=1)
    row_layout.addWidget(pick_button)
    row_container = QWidget()
    row_container.setLayout(row_layout)
    form_layout.addRow("Project folder", row_container)

    # Row 2: read-only display of the run file path.
    run_file_display = QLineEdit()
    run_file_display.setReadOnly(True)
    self.run_file_edit = run_file_display
    form_layout.addRow("Run file", run_file_display)

    return project_group
def _build_inspection_group(self) -> QGroupBox:
    """Create the "Inspection" group holding the read-only summary box.

    The text box is stored as ``self.inspection_box``.
    """
    inspection_group = QGroupBox("Inspection")
    group_layout = QVBoxLayout(inspection_group)

    summary_box = QPlainTextEdit()
    summary_box.setReadOnly(True)
    summary_box.setMinimumHeight(210)
    self.inspection_box = summary_box

    group_layout.addWidget(summary_box)
    return inspection_group
layout.addWidget(self.json_preview_box) + return group + + def _browse_project_dir(self, *_args: object) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select project folder", + str(self._browse_start_dir), + ) + if not selected: + return + self.project_dir_edit.setText(selected) + self._on_project_dir_changed() + + def _browse_frames_dir(self, *_args: object) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select extracted frames folder", + str(self._browse_start_dir), + ) + if not selected: + return + self.frames_dir_edit.setText(selected) + self._browse_start_dir = Path(selected).expanduser().resolve() + self._inspect_frames() + + def _browse_output_dir(self, *_args: object) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select output clusters folder", + self.output_dir_edit.text().strip() or str(self._browse_start_dir), + ) + if selected: + self.output_dir_edit.setText(selected) + self._update_preview() + + def _on_project_dir_changed(self, *_args: object) -> None: + project_dir = self._project_dir() + if project_dir is None: + return + self._browse_start_dir = project_dir + self._refresh_run_file_path() + if not self.frames_dir_edit.text().strip(): + frames_dir = self._project_frames_dir(project_dir) + if frames_dir is not None and frames_dir.is_dir(): + self.frames_dir_edit.setText(str(frames_dir)) + self._inspect_frames() + + def _inspect_frames(self, *_args: object) -> None: + frames_text = self.frames_dir_edit.text().strip() + if not frames_text: + self._last_summary = None + self.inspection_box.setPlainText("No frames folder selected.") + self._update_preview() + return + try: + workflow = ClusterWorkflow( + frames_dir=frames_text, + atom_type_definitions={}, + pair_cutoff_definitions={}, + ) + summary = workflow.inspect() + except Exception as exc: + self._last_summary = None + self.inspection_box.setPlainText(str(exc)) + self.statusBar().showMessage("Frames inspection failed") + 
def _save_run_file(self, *_args: object) -> None:
    """Validate the form and write the run file into the project folder.

    Validation problems and failures while writing the file are both shown
    in a warning dialog instead of escaping from the Qt slot. On success
    the run-file field, the previews and the status bar are refreshed and a
    confirmation dialog is shown.

    Args:
        *_args: Ignored; accepted so the method can be used as a Qt slot.
    """
    try:
        project_dir = self._require_project_dir()
        config = self._current_config(project_dir)
        run_file_path = default_cluster_run_file_path(project_dir)
        # Writing can fail as well (e.g. a read-only folder), so it is
        # guarded together with the validation steps.
        save_cluster_run_config(run_file_path, config)
    except Exception as exc:
        QMessageBox.warning(self, "Cluster CLI Setup", str(exc))
        return

    self.run_file_edit.setText(str(run_file_path))
    # _update_preview() rewrites the JSON preview from the current form
    # state, so it is not filled separately beforehand.
    self._update_preview()
    self.statusBar().showMessage(f"Saved run file: {run_file_path}")
    QMessageBox.information(
        self,
        "Cluster CLI Setup",
        f"Saved cluster extraction CLI run file:\n{run_file_path}",
    )
+ ) + self.json_preview_box.clear() + return + self._refresh_run_file_path() + self.command_box.setPlainText( + f'clusters run "{project_dir}"\n' + f'saxshell cluster run "{project_dir}"' + ) + try: + config = self._current_config(project_dir) + payload = config.to_dict() + try: + payload["selection_preview"] = preview_cluster_run_config( + project_dir=project_dir, + config=config, + ) + except Exception as exc: + payload["selection_preview_error"] = str(exc) + except Exception as exc: + self.json_preview_box.setPlainText(str(exc)) + return + self.json_preview_box.setPlainText(save_preview_text(payload)) + + def _current_config(self, project_dir: Path): + frames_text = self.frames_dir_edit.text().strip() + if not frames_text: + raise ValueError("Choose a frames folder before saving.") + output_text = self.output_dir_edit.text().strip() + return build_cluster_run_config( + project_dir=project_dir, + frames_dir=frames_text, + output_dir=output_text or None, + atom_type_definitions=self.definitions_panel.atom_type_definitions(), + pair_cutoff_definitions=( + self.definitions_panel.pair_cutoff_definitions() + ), + box_dimensions=self.definitions_panel.box_dimensions(), + use_pbc=self.definitions_panel.use_pbc(), + default_cutoff=self.definitions_panel.default_cutoff(), + shell_levels=self.definitions_panel.shell_growth_levels(), + include_shell_levels=self.definitions_panel.include_shell_levels(), + shared_shells=self.definitions_panel.shared_shells(), + smart_solvation_shells=( + self.definitions_panel.smart_solvation_shells() + ), + include_shell_atoms_in_stoichiometry=( + self.definitions_panel.include_shell_atoms_in_stoichiometry() + ), + search_mode=self.definitions_panel.search_mode(), + save_state_frequency=self.definitions_panel.save_state_frequency(), + ) + + def _project_dir(self) -> Path | None: + text = self.project_dir_edit.text().strip() + if not text: + return None + return Path(text).expanduser().resolve() + + def _require_project_dir(self) -> Path: + 
project_dir = self._project_dir() + if project_dir is None: + raise ValueError("Choose a project folder before saving.") + if not project_dir.is_dir(): + raise ValueError(f"Project folder does not exist: {project_dir}") + return project_dir + + @staticmethod + def _project_frames_dir(project_dir: Path) -> Path | None: + try: + payload = json.loads( + (project_dir / "saxs_project.json").read_text(encoding="utf-8") + ) + except Exception: + return None + if not isinstance(payload, dict): + return None + frames_dir = _optional_project_path(payload.get("frames_dir")) + return frames_dir or _optional_project_path( + payload.get("pdb_frames_dir") + ) + + +def _summary_text(summary: dict[str, object]) -> str: + box_dimensions = summary.get("box_dimensions") + if box_dimensions is None: + box_dimensions = summary.get("estimated_box_dimensions") + source_kind = summary.get("box_dimensions_source_kind") + label = ( + "Source box dimensions" + if source_kind == "source_filename" + else "Estimated box dimensions" + ) + lines = [ + f"Frames folder: {summary.get('input_dir')}", + f"Mode: {summary.get('mode_label')}", + f"Frames: {summary.get('n_frames')}", + f"Output format: {summary.get('output_file_extension')}", + f"{label}: {format_box_dimensions(box_dimensions)}", + ] + if summary.get("box_dimensions_source") is not None: + lines.append(f"Box source: {summary.get('box_dimensions_source')}") + return "\n".join(lines) + + +def save_preview_text(payload: dict[str, object]) -> str: + return json.dumps(payload, indent=2) + + +def _optional_project_path(value: object) -> Path | None: + text = str(value or "").strip() + if not text: + return None + return Path(text).expanduser().resolve() + + +def _load_saxshell_icon(): + from saxshell.saxs.ui.branding import load_saxshell_icon + + return load_saxshell_icon() + + +def launch_cluster_run_file_ui( + *, + initial_project_dir: str | Path | None = None, + initial_frames_dir: str | Path | None = None, +) -> ClusterRunFileWindow: + from 
saxshell.saxs.ui.branding import ( + configure_saxshell_application, + prepare_saxshell_application_identity, + ) + + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication(sys.argv) + configure_saxshell_application(app) + window = ClusterRunFileWindow( + initial_project_dir=initial_project_dir, + initial_frames_dir=initial_frames_dir, + ) + window.show() + window.raise_() + return window + + +__all__ = [ + "ClusterRunFileWindow", + "launch_cluster_run_file_ui", +] diff --git a/src/saxshell/clusterdynamics/__init__.py b/src/saxshell/clusterdynamics/__init__.py index 1ca2638..d9c85ed 100644 --- a/src/saxshell/clusterdynamics/__init__.py +++ b/src/saxshell/clusterdynamics/__init__.py @@ -8,6 +8,19 @@ load_cluster_dynamics_dataset, save_cluster_dynamics_dataset, ) +from .run_config import ( + ClusterDynamicsRunConfig, + ClusterDynamicsRunExecutionSummary, + build_clusterdynamics_run_config, + default_clusterdynamics_run_file_path, + load_clusterdynamics_run_config, + preview_clusterdynamics_run_config, + resolve_run_config_path, + run_clusterdynamics_run_config, + save_clusterdynamics_run_config, + suggest_clusterdynamics_output_file, + workflow_from_clusterdynamics_run_config, +) from .workflow import ( ClusterDynamicsResult, ClusterDynamicsSelectionPreview, @@ -24,8 +37,19 @@ "ClusterSizeLifetimeSummary", "LoadedClusterDynamicsDataset", "SavedClusterDynamicsDataset", + "ClusterDynamicsRunConfig", + "ClusterDynamicsRunExecutionSummary", + "build_clusterdynamics_run_config", + "default_clusterdynamics_run_file_path", "export_cluster_dynamics_colormap_csv", "export_cluster_dynamics_lifetime_csv", "load_cluster_dynamics_dataset", + "load_clusterdynamics_run_config", + "preview_clusterdynamics_run_config", + "resolve_run_config_path", + "run_clusterdynamics_run_config", "save_cluster_dynamics_dataset", + "save_clusterdynamics_run_config", + "suggest_clusterdynamics_output_file", + 
"workflow_from_clusterdynamics_run_config", ] diff --git a/src/saxshell/clusterdynamics/cli.py b/src/saxshell/clusterdynamics/cli.py index b60a0a9..13eba05 100644 --- a/src/saxshell/clusterdynamics/cli.py +++ b/src/saxshell/clusterdynamics/cli.py @@ -1,17 +1,149 @@ from __future__ import annotations import argparse +import sys +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from saxshell.version import __version__ -def main(argv: list[str] | None = None) -> int: +from .run_config import ( + default_clusterdynamics_run_file_path, + load_clusterdynamics_run_config, + run_clusterdynamics_run_config, +) + +_COMMANDS = {"setup-ui", "ui", "run", "batch-run"} +_TOP_LEVEL_OPTIONS = {"-h", "--help", "--version"} + + +def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( prog="clusterdynamics", description=( "Analyze time-binned cluster distributions from extracted PDB " - "or XYZ frame folders, or launch the Qt UI. Running without " - "additional arguments launches the UI." + "or XYZ frame folders. Running without a subcommand launches " + "the Qt UI." 
+ ), + ) + parser.add_argument( + "--version", + action="store_true", + help="Show the clusterdynamics version number and exit.", + ) + subparsers = parser.add_subparsers(dest="command") + + setup_ui_parser = subparsers.add_parser( + "setup-ui", + help="Launch the beta project-backed run-file setup interface.", + ) + setup_ui_parser.add_argument( + "project_dir", + nargs="?", + type=Path, + help="Optional SAXSShell project folder.", + ) + setup_ui_parser.add_argument( + "--frames-dir", + type=Path, + default=None, + help="Optional extracted frames folder to prefill.", + ) + setup_ui_parser.add_argument( + "--energy-file", + type=Path, + default=None, + help="Optional CP2K .ener file to prefill.", + ) + setup_ui_parser.set_defaults(handler=_handle_setup_ui) + + ui_parser = subparsers.add_parser("ui", help="Launch the Qt UI.") + ui_parser.add_argument( + "frames_dir", + nargs="?", + type=Path, + help="Optional extracted frames directory to prefill in the UI.", + ) + ui_parser.add_argument( + "--energy-file", + type=Path, + default=None, + help="Optional CP2K .ener file to prefill in the UI.", + ) + ui_parser.add_argument( + "--project-dir", + type=Path, + default=None, + help="Optional SAXSShell project directory to prefill in the UI.", + ) + ui_parser.set_defaults(handler=_handle_ui) + + run_parser = subparsers.add_parser( + "run", + help="Run cluster dynamics from a project-backed run file.", + ) + run_parser.add_argument( + "project_dir", + type=Path, + help="SAXSShell project folder containing the run file.", + ) + run_parser.add_argument( + "--run-file", + type=Path, + default=None, + help=( + "Run file path. Defaults to cluster_dynamics_cli_run.json " + "in the project folder." 
def _handle_setup_ui(args: argparse.Namespace) -> int:
    """Launch the run-file setup window for the ``setup-ui`` subcommand.

    The Qt event loop is only started here when no ``QApplication`` existed
    before the launcher was called; otherwise the caller already owns the
    loop and ``0`` is returned immediately.

    Args:
        args: Parsed arguments; ``project_dir``, ``frames_dir`` and
            ``energy_file`` are read when present.

    Returns:
        The event loop's exit code as ``int``, or ``0`` when no loop was
        started here.
    """
    from PySide6.QtWidgets import QApplication

    from .ui.run_file_window import launch_clusterdynamics_run_file_ui

    owns_app = QApplication.instance() is None
    # Keep whatever the launcher returns referenced while the event loop
    # runs. NOTE(review): the launcher lives in another module; presumably
    # it returns the top-level window, and an unparented Qt window with no
    # Python reference can be destroyed by garbage collection — confirm.
    _window = launch_clusterdynamics_run_file_ui(
        initial_project_dir=getattr(args, "project_dir", None),
        initial_frames_dir=getattr(args, "frames_dir", None),
        initial_energy_file=getattr(args, "energy_file", None),
    )
    app = QApplication.instance()
    if owns_app and app is not None:
        return int(app.exec())
    return 0
def _handle_parallel_batch_run(
    project_dirs: list[Path],
    *,
    workers: int,
    keep_going: bool,
) -> int:
    """Run the default run file of several projects on a thread pool.

    Each project's log lines are collected by the worker and printed
    together once that project finishes, followed by its summary.

    Without ``keep_going`` the first failure cancels every project that has
    not started yet. Projects that are already running cannot be
    interrupted and the executor waits for them anyway, so their summaries
    and failures are still reported instead of being discarded.

    Args:
        project_dirs: Resolved project folders to process.
        workers: Maximum number of concurrent project runs.
        keep_going: Keep scheduling remaining projects after a failure.

    Returns:
        ``1`` when at least one project failed, otherwise ``0``.
    """
    failures: list[tuple[Path, str]] = []
    stop_requested = False
    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_project = {
            executor.submit(
                _run_project_collecting_logs,
                project_dir,
            ): project_dir
            for project_dir in project_dirs
        }
        for future in as_completed(future_to_project):
            if future.cancelled():
                # Cancelled after an earlier failure; nothing to report.
                continue
            project_dir = future_to_project[future]
            try:
                summary, log_lines = future.result()
            except Exception as exc:
                failures.append((project_dir, str(exc)))
                print(f"FAILED {project_dir}: {exc}")
                if not keep_going and not stop_requested:
                    stop_requested = True
                    # cancel() is a no-op for running or finished futures.
                    for pending in future_to_project:
                        pending.cancel()
                continue
            for line in log_lines:
                print(line)
            _print_summary(summary)
    if failures:
        print("")
        print(
            "Cluster dynamics batch completed with "
            f"{len(failures)} failure(s)."
        )
        return 1
    print("")
    print("Cluster dynamics batch complete")
    return 0
datetime +from pathlib import Path +from typing import Callable + +from saxshell.cluster import ( + SEARCH_MODE_KDTREE, + PairCutoffDefinitions, + normalize_pair_cutoffs, + normalize_search_mode, +) +from saxshell.cluster.workflow import ( + example_atom_type_definitions, + example_pair_cutoff_definitions, +) +from saxshell.saxs.project_manager import ( + SAXSProjectManager, + build_project_paths, +) +from saxshell.structure import ( + AtomTypeDefinitions, + normalize_atom_type_definitions, +) + +from .dataset import SavedClusterDynamicsDataset, save_cluster_dynamics_dataset +from .workflow import ClusterDynamicsResult, ClusterDynamicsWorkflow + +DEFAULT_RUN_FILE_NAME = "cluster_dynamics_cli_run.json" +RUN_CONFIG_VERSION = 1 +ClusterDynamicsRunLogCallback = Callable[[str], None] +ClusterDynamicsRunProgressCallback = Callable[[int, int, str], None] + + +@dataclass(slots=True) +class ClusterDynamicsRunConfig: + frames_dir: str + output_file: str | None + energy_file: str | None = None + atom_type_definitions: AtomTypeDefinitions = field( + default_factory=example_atom_type_definitions + ) + pair_cutoff_definitions: PairCutoffDefinitions = field( + default_factory=example_pair_cutoff_definitions + ) + box_dimensions: tuple[float, float, float] | None = None + use_pbc: bool = False + default_cutoff: float | None = None + shell_levels: tuple[int, ...] 
= () + shared_shells: bool = False + include_shell_atoms_in_stoichiometry: bool = False + search_mode: str = SEARCH_MODE_KDTREE + folder_start_time_fs: float | None = None + first_frame_time_fs: float = 0.0 + frame_timestep_fs: float = 0.5 + frames_per_colormap_timestep: int = 1 + analysis_start_fs: float | None = None + analysis_stop_fs: float | None = None + created_at: str = field( + default_factory=lambda: datetime.now().isoformat(timespec="seconds") + ) + + def to_dict(self) -> dict[str, object]: + return { + "version": RUN_CONFIG_VERSION, + "created_at": self.created_at, + "frames_dir": self.frames_dir, + "output_file": self.output_file, + "energy_file": self.energy_file, + "atom_type_definitions": serialize_atom_type_definitions( + self.atom_type_definitions + ), + "pair_cutoff_definitions": serialize_pair_cutoff_definitions( + self.pair_cutoff_definitions + ), + "box_dimensions": self.box_dimensions, + "use_pbc": bool(self.use_pbc), + "default_cutoff": self.default_cutoff, + "shell_levels": [int(level) for level in self.shell_levels], + "shared_shells": bool(self.shared_shells), + "include_shell_atoms_in_stoichiometry": bool( + self.include_shell_atoms_in_stoichiometry + ), + "search_mode": normalize_search_mode(self.search_mode), + "folder_start_time_fs": self.folder_start_time_fs, + "first_frame_time_fs": float(self.first_frame_time_fs), + "frame_timestep_fs": float(self.frame_timestep_fs), + "frames_per_colormap_timestep": int( + self.frames_per_colormap_timestep + ), + "analysis_start_fs": self.analysis_start_fs, + "analysis_stop_fs": self.analysis_stop_fs, + } + + @classmethod + def from_dict( + cls, + payload: dict[str, object], + ) -> "ClusterDynamicsRunConfig": + frames_dir = str(payload.get("frames_dir", "")).strip() + if not frames_dir: + raise ValueError( + "Cluster dynamics run file is missing frames_dir." 
+ ) + return cls( + frames_dir=frames_dir, + output_file=optional_text(payload.get("output_file")), + energy_file=optional_text(payload.get("energy_file")), + atom_type_definitions=coerce_atom_type_definitions( + payload.get("atom_type_definitions") + ), + pair_cutoff_definitions=coerce_pair_cutoff_definitions( + payload.get("pair_cutoff_definitions") + ), + box_dimensions=coerce_box_dimensions( + payload.get("box_dimensions") + ), + use_pbc=bool(payload.get("use_pbc", False)), + default_cutoff=optional_positive_float( + payload.get("default_cutoff") + ), + shell_levels=coerce_int_tuple(payload.get("shell_levels")), + shared_shells=bool(payload.get("shared_shells", False)), + include_shell_atoms_in_stoichiometry=bool( + payload.get("include_shell_atoms_in_stoichiometry", False) + ), + search_mode=normalize_search_mode( + str(payload.get("search_mode", SEARCH_MODE_KDTREE)) + ), + folder_start_time_fs=optional_float( + payload.get("folder_start_time_fs") + ), + first_frame_time_fs=float(payload.get("first_frame_time_fs", 0.0)), + frame_timestep_fs=float(payload.get("frame_timestep_fs", 0.5)), + frames_per_colormap_timestep=max( + int(payload.get("frames_per_colormap_timestep", 1)), + 1, + ), + analysis_start_fs=optional_float(payload.get("analysis_start_fs")), + analysis_stop_fs=optional_float(payload.get("analysis_stop_fs")), + created_at=str(payload.get("created_at", "")).strip() + or datetime.now().isoformat(timespec="seconds"), + ) + + +@dataclass(slots=True, frozen=True) +class ClusterDynamicsRunExecutionSummary: + project_dir: Path + run_file_path: Path | None + frames_dir: Path + output_file: Path + result: ClusterDynamicsResult + saved_dataset: SavedClusterDynamicsDataset + project_file: Path | None + + @property + def written_count(self) -> int: + return len(self.saved_dataset.written_files) + + +def default_clusterdynamics_run_file_path(project_dir: str | Path) -> Path: + return Path(project_dir).expanduser().resolve() / DEFAULT_RUN_FILE_NAME + + +def 
def path_text_for_run_config(
    path: str | Path | None,
    *,
    project_dir: str | Path,
) -> str | None:
    """Return the text stored in a run file for ``path``.

    Paths inside the project folder are stored relative to it with forward
    slashes; all other paths are stored as resolved absolute paths.

    Args:
        path: Path to store. ``None`` or a blank string means "no path".
        project_dir: Project folder that relative entries refer to.

    Returns:
        The path text, or ``None`` when no path was given.
    """
    if path is None:
        return None
    # A blank string would otherwise resolve to the current working
    # directory; treat it as "no path", like resolve_run_config_path does.
    if isinstance(path, str) and not path.strip():
        return None
    resolved_project_dir = Path(project_dir).expanduser().resolve()
    resolved_path = Path(path).expanduser().resolve()
    try:
        return resolved_path.relative_to(resolved_project_dir).as_posix()
    except ValueError:
        # Not located under the project folder: keep the absolute path.
        return str(resolved_path)
def suggest_clusterdynamics_output_file(
    *,
    project_dir: str | Path,
    frames_dir: str | Path | None,
) -> Path:
    """Suggest a dataset file path inside the project's export folder.

    The file is placed in the ``clusterdynamics`` subfolder of the
    project's exported-data directory and named
    ``<frames folder name>_cluster_dynamics.json``. When no usable folder
    name is available the label ``cluster_dynamics`` is used instead.

    Args:
        project_dir: Project folder.
        frames_dir: Frames folder whose name labels the file; may be
            ``None`` or blank.

    Returns:
        The suggested output file path (nothing is created on disk).
    """
    resolved_project_dir = Path(project_dir).expanduser().resolve()
    dataset_dir = (
        build_project_paths(resolved_project_dir).exported_data_dir
        / "clusterdynamics"
    )
    folder_label = "cluster_dynamics"
    if frames_dir is not None and str(frames_dir).strip():
        frames_path = Path(frames_dir).expanduser()
        name = frames_path.name
        if name in ("", ".."):
            # "." has no lexical name and "x/.." would yield the label
            # "..": resolve these to recover the real folder name.
            name = frames_path.resolve().name
        if name:
            folder_label = name
    return dataset_dir / f"{folder_label}_cluster_dynamics.json"
def serialize_atom_type_definitions(
    definitions: AtomTypeDefinitions,
) -> dict[str, list[dict[str, str | None]]]:
    """Convert atom-type definitions into a JSON-friendly mapping.

    The definitions are normalized first. Each ``(element, residue)`` pair
    then becomes an ``{"element": ..., "residue": ...}`` dictionary listed
    under its atom-type name.
    """
    payload: dict[str, list[dict[str, str | None]]] = {}
    normalized = normalize_atom_type_definitions(definitions)
    for atom_type, entries in normalized.items():
        rows: list[dict[str, str | None]] = []
        for element, residue in entries:
            rows.append({"element": element, "residue": residue})
        payload[atom_type] = rows
    return payload
def coerce_pair_cutoff_definitions(value: object) -> PairCutoffDefinitions:
    """Build pair-cutoff definitions from a decoded run-file value.

    ``value`` is expected to be a list of dictionaries carrying ``atom1``,
    ``atom2`` and ``shell_cutoffs`` keys. A value that is not a list yields
    the example definitions. Entries that are not dictionaries, that have
    an empty or missing atom label, or whose ``shell_cutoffs`` is not a
    dictionary are skipped. The collected mapping is passed through
    ``normalize_pair_cutoffs`` before being returned.

    Raises:
        ValueError, TypeError: if a shell level or cutoff cannot be
            converted with ``int`` / ``float``.
    """
    if not isinstance(value, list):
        return example_pair_cutoff_definitions()
    definitions: PairCutoffDefinitions = {}
    for entry in value:
        if not isinstance(entry, dict):
            continue
        # ``or ""`` stops an explicit null (decoded as None) from being
        # stringified into the bogus atom label "None"; the atom-type
        # coercion in this module guards its element values the same way.
        atom1 = str(entry.get("atom1") or "").strip().title()
        atom2 = str(entry.get("atom2") or "").strip().title()
        cutoffs = entry.get("shell_cutoffs")
        if not atom1 or not atom2 or not isinstance(cutoffs, dict):
            continue
        parsed: dict[int, float] = {
            int(level): float(cutoff) for level, cutoff in cutoffs.items()
        }
        if parsed:
            definitions[(atom1, atom2)] = parsed
    return normalize_pair_cutoffs(definitions)
def optional_float(value: object) -> float | None:
    """Parse ``value`` as a float, treating blank input as missing.

    ``None`` and values whose string form is empty after stripping
    whitespace give ``None``. Anything else is converted with ``float``,
    which raises ``ValueError`` for non-numeric text.
    """
    cleaned = "" if value is None else str(value).strip()
    return float(cleaned) if cleaned else None
"load_clusterdynamics_run_config", + "optional_float", + "optional_positive_float", + "optional_text", + "path_text_for_run_config", + "preview_clusterdynamics_run_config", + "resolve_run_config_path", + "run_clusterdynamics_run_config", + "save_clusterdynamics_run_config", + "serialize_atom_type_definitions", + "serialize_pair_cutoff_definitions", + "suggest_clusterdynamics_output_file", + "workflow_from_clusterdynamics_run_config", +] diff --git a/src/saxshell/clusterdynamics/ui/__init__.py b/src/saxshell/clusterdynamics/ui/__init__.py index 2682292..00bae7b 100644 --- a/src/saxshell/clusterdynamics/ui/__init__.py +++ b/src/saxshell/clusterdynamics/ui/__init__.py @@ -1,5 +1,14 @@ """Qt UI for the clusterdynamics application.""" from .main_window import ClusterDynamicsMainWindow, launch_clusterdynamics_ui +from .run_file_window import ( + ClusterDynamicsRunFileWindow, + launch_clusterdynamics_run_file_ui, +) -__all__ = ["ClusterDynamicsMainWindow", "launch_clusterdynamics_ui"] +__all__ = [ + "ClusterDynamicsMainWindow", + "ClusterDynamicsRunFileWindow", + "launch_clusterdynamics_run_file_ui", + "launch_clusterdynamics_ui", +] diff --git a/src/saxshell/clusterdynamics/ui/run_file_window.py b/src/saxshell/clusterdynamics/ui/run_file_window.py new file mode 100644 index 0000000..f9b1989 --- /dev/null +++ b/src/saxshell/clusterdynamics/ui/run_file_window.py @@ -0,0 +1,530 @@ +from __future__ import annotations + +import json +import sys +from pathlib import Path + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import ( + QApplication, + QFileDialog, + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QMainWindow, + QMessageBox, + QPlainTextEdit, + QPushButton, + QScrollArea, + QSplitter, + QVBoxLayout, + QWidget, +) + +from saxshell.cluster import DEFAULT_CLUSTER_EXTRACTION_PRESET_NAME +from saxshell.cluster.ui.definitions_panel import ClusterDefinitionsPanel +from saxshell.cluster.workflow import ClusterWorkflow, format_box_dimensions +from 
saxshell.clusterdynamics.run_config import ( + build_clusterdynamics_run_config, + default_clusterdynamics_run_file_path, + preview_clusterdynamics_run_config, + save_clusterdynamics_run_config, + suggest_clusterdynamics_output_file, +) +from saxshell.clusterdynamics.ui.main_window import ClusterDynamicsTimePanel +from saxshell.saxs.project_manager import SAXSProjectManager +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) + + +class ClusterDynamicsRunFileWindow(QMainWindow): + def __init__( + self, + *, + initial_project_dir: str | Path | None = None, + initial_frames_dir: str | Path | None = None, + initial_energy_file: str | Path | None = None, + ) -> None: + super().__init__() + self._browse_start_dir = Path.home() + self._last_summary: dict[str, object] | None = None + self._last_suggested_output_file: str | None = None + + project_dir = ( + None + if initial_project_dir is None + else Path(initial_project_dir).expanduser().resolve() + ) + frames_dir = ( + None + if initial_frames_dir is None + else Path(initial_frames_dir).expanduser().resolve() + ) + energy_file = ( + None + if initial_energy_file is None + else Path(initial_energy_file).expanduser().resolve() + ) + if project_dir is not None: + self._browse_start_dir = project_dir + defaults = self._project_defaults(project_dir) + if frames_dir is None: + frames_dir = defaults.get("frames_dir") + if energy_file is None: + energy_file = defaults.get("energy_file") + + self.setWindowTitle("Cluster Dynamics CLI Setup (Beta)") + self.setWindowIcon(load_saxshell_icon()) + self.resize(1120, 800) + self._build_ui() + self.definitions_panel.load_preset( + DEFAULT_CLUSTER_EXTRACTION_PRESET_NAME + ) + + if project_dir is not None: + self.project_dir_edit.setText(str(project_dir)) + self._refresh_run_file_path() + if frames_dir is not None and frames_dir.is_dir(): + self.frames_dir_edit.setText(str(frames_dir)) + 
self._browse_start_dir = frames_dir + if energy_file is not None and energy_file.is_file(): + self.energy_file_edit.setText(str(energy_file)) + self._inspect_frames() + self._update_preview() + + def _build_ui(self) -> None: + central = QWidget(self) + root = QVBoxLayout(central) + root.setContentsMargins(10, 10, 10, 10) + root.setSpacing(8) + self.setCentralWidget(central) + + splitter = QSplitter(Qt.Orientation.Horizontal, self) + splitter.setChildrenCollapsible(False) + root.addWidget(splitter, stretch=1) + + left_scroll = QScrollArea(self) + left_scroll.setWidgetResizable(True) + left_panel = QWidget() + self.left_layout = QVBoxLayout(left_panel) + self.left_layout.setContentsMargins(10, 10, 10, 10) + self.left_layout.setSpacing(10) + left_scroll.setWidget(left_panel) + + right_scroll = QScrollArea(self) + right_scroll.setWidgetResizable(True) + right_panel = QWidget() + self.right_layout = QVBoxLayout(right_panel) + self.right_layout.setContentsMargins(10, 10, 10, 10) + self.right_layout.setSpacing(10) + right_scroll.setWidget(right_panel) + + splitter.addWidget(left_scroll) + splitter.addWidget(right_scroll) + splitter.setSizes([580, 540]) + + self.left_layout.addWidget(self._build_project_group()) + self.left_layout.addWidget(self._build_input_group()) + self.definitions_panel = ClusterDefinitionsPanel() + self.definitions_panel.settings_changed.connect(self._update_preview) + self.left_layout.addWidget(self.definitions_panel) + self.time_panel = ClusterDynamicsTimePanel() + self.time_panel.settings_changed.connect(self._update_preview) + self.left_layout.addWidget(self.time_panel) + self.left_layout.addWidget(self._build_save_group()) + self.left_layout.addStretch(1) + + self.right_layout.addWidget(self._build_inspection_group()) + self.right_layout.addWidget(self._build_command_group()) + self.right_layout.addStretch(1) + self.statusBar().showMessage("Ready") + + def _build_project_group(self) -> QGroupBox: + group = QGroupBox("Project") + form = 
QFormLayout(group) + project_row = QHBoxLayout() + self.project_dir_edit = QLineEdit() + self.project_dir_edit.editingFinished.connect( + self._on_project_dir_changed + ) + project_row.addWidget(self.project_dir_edit, stretch=1) + browse_button = QPushButton("Browse...") + browse_button.clicked.connect(self._browse_project_dir) + project_row.addWidget(browse_button) + project_widget = QWidget() + project_widget.setLayout(project_row) + form.addRow("Project folder", project_widget) + + self.run_file_edit = QLineEdit() + self.run_file_edit.setReadOnly(True) + form.addRow("Run file", self.run_file_edit) + return group + + def _build_input_group(self) -> QGroupBox: + group = QGroupBox("Input / Output") + form = QFormLayout(group) + self.frames_dir_edit = QLineEdit() + self.frames_dir_edit.editingFinished.connect(self._inspect_frames) + form.addRow( + "Frames folder", + self._make_path_row( + self.frames_dir_edit, + self._browse_frames_dir, + ), + ) + + self.energy_file_edit = QLineEdit() + self.energy_file_edit.editingFinished.connect(self._update_preview) + form.addRow( + "CP2K .ener file", + self._make_path_row( + self.energy_file_edit, + self._browse_energy_file, + ), + ) + + self.output_file_edit = QLineEdit() + self.output_file_edit.editingFinished.connect(self._update_preview) + form.addRow( + "Output dataset", + self._make_path_row( + self.output_file_edit, + self._browse_output_file, + ), + ) + return group + + def _build_save_group(self) -> QGroupBox: + group = QGroupBox("Save") + layout = QHBoxLayout(group) + inspect_button = QPushButton("Inspect Frames") + inspect_button.clicked.connect(self._inspect_frames) + layout.addWidget(inspect_button) + save_button = QPushButton("Save Run File") + save_button.clicked.connect(self._save_run_file) + layout.addWidget(save_button) + layout.addStretch(1) + return group + + def _build_inspection_group(self) -> QGroupBox: + group = QGroupBox("Inspection") + layout = QVBoxLayout(group) + self.inspection_box = 
QPlainTextEdit() + self.inspection_box.setReadOnly(True) + self.inspection_box.setMinimumHeight(210) + layout.addWidget(self.inspection_box) + return group + + def _build_command_group(self) -> QGroupBox: + group = QGroupBox("CLI Command / JSON") + layout = QVBoxLayout(group) + layout.addWidget(QLabel("Commands")) + self.command_box = QPlainTextEdit() + self.command_box.setReadOnly(True) + self.command_box.setMinimumHeight(150) + layout.addWidget(self.command_box) + layout.addWidget(QLabel("Run file preview")) + self.json_preview_box = QPlainTextEdit() + self.json_preview_box.setReadOnly(True) + self.json_preview_box.setMinimumHeight(300) + layout.addWidget(self.json_preview_box) + return group + + def _make_path_row(self, line_edit: QLineEdit, callback) -> QWidget: + widget = QWidget() + row = QHBoxLayout(widget) + row.setContentsMargins(0, 0, 0, 0) + row.addWidget(line_edit, stretch=1) + button = QPushButton("Browse...") + button.clicked.connect(callback) + row.addWidget(button) + return widget + + def _browse_project_dir(self, *_args: object) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select project folder", + str(self._browse_start_dir), + ) + if not selected: + return + self.project_dir_edit.setText(selected) + self._on_project_dir_changed() + + def _browse_frames_dir(self, *_args: object) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select extracted frames folder", + str(self._browse_start_dir), + ) + if not selected: + return + self.frames_dir_edit.setText(selected) + self._browse_start_dir = Path(selected).expanduser().resolve() + self._inspect_frames() + + def _browse_energy_file(self, *_args: object) -> None: + path, _selected_filter = QFileDialog.getOpenFileName( + self, + "Select CP2K .ener file", + self.energy_file_edit.text().strip() + or str(self._browse_start_dir), + "Energy Files (*.ener);;All Files (*)", + ) + if path: + self.energy_file_edit.setText(path) + self._update_preview() + + def 
_browse_output_file(self, *_args: object) -> None: + path, _selected_filter = QFileDialog.getSaveFileName( + self, + "Select output dataset", + self.output_file_edit.text().strip() + or str(self._browse_start_dir), + "JSON Files (*.json);;All Files (*)", + ) + if path: + self.output_file_edit.setText(path) + self._update_preview() + + def _on_project_dir_changed(self, *_args: object) -> None: + project_dir = self._project_dir() + if project_dir is None: + return + self._browse_start_dir = project_dir + self._refresh_run_file_path() + defaults = self._project_defaults(project_dir) + if not self.frames_dir_edit.text().strip(): + frames_dir = defaults.get("frames_dir") + if frames_dir is not None and frames_dir.is_dir(): + self.frames_dir_edit.setText(str(frames_dir)) + if not self.energy_file_edit.text().strip(): + energy_file = defaults.get("energy_file") + if energy_file is not None and energy_file.is_file(): + self.energy_file_edit.setText(str(energy_file)) + self._inspect_frames() + + def _inspect_frames(self, *_args: object) -> None: + frames_text = self.frames_dir_edit.text().strip() + if not frames_text: + self._last_summary = None + self.inspection_box.setPlainText("No frames folder selected.") + self._update_preview() + return + try: + workflow = ClusterWorkflow( + frames_dir=frames_text, + atom_type_definitions={}, + pair_cutoff_definitions={}, + ) + summary = workflow.inspect() + except Exception as exc: + self._last_summary = None + self.inspection_box.setPlainText(str(exc)) + self.statusBar().showMessage("Frames inspection failed") + self._update_preview() + return + self._last_summary = summary + self.definitions_panel.set_frame_mode( + str(summary.get("frame_format", "") or "") + ) + self.inspection_box.setPlainText(_summary_text(summary)) + self._refresh_suggested_output_file() + self.statusBar().showMessage( + f"Discovered {int(summary.get('n_frames', 0))} frame(s)" + ) + self._update_preview() + + def _refresh_run_file_path(self) -> None: + 
project_dir = self._project_dir() + if project_dir is None: + self.run_file_edit.clear() + return + self.run_file_edit.setText( + str(default_clusterdynamics_run_file_path(project_dir)) + ) + + def _refresh_suggested_output_file(self) -> None: + project_dir = self._project_dir() + frames_text = self.frames_dir_edit.text().strip() + if project_dir is None or not frames_text: + return + try: + suggested = suggest_clusterdynamics_output_file( + project_dir=project_dir, + frames_dir=frames_text, + ) + except Exception: + return + current = self.output_file_edit.text().strip() + if not current or current == self._last_suggested_output_file: + self.output_file_edit.setText(str(suggested)) + self._last_suggested_output_file = str(suggested) + + def _save_run_file(self, *_args: object) -> None: + try: + project_dir = self._require_project_dir() + config = self._current_config(project_dir) + except Exception as exc: + QMessageBox.warning(self, "Cluster Dynamics CLI Setup", str(exc)) + return + run_file_path = default_clusterdynamics_run_file_path(project_dir) + save_clusterdynamics_run_config(run_file_path, config) + self.run_file_edit.setText(str(run_file_path)) + self.json_preview_box.setPlainText(save_preview_text(config.to_dict())) + self._update_preview() + self.statusBar().showMessage(f"Saved run file: {run_file_path}") + QMessageBox.information( + self, + "Cluster Dynamics CLI Setup", + f"Saved cluster dynamics CLI run file:\n{run_file_path}", + ) + + def _update_preview(self, *_args: object) -> None: + project_dir = self._project_dir() + if project_dir is None: + self.command_box.setPlainText( + "Select a project folder before saving the CLI run file." 
+ ) + self.json_preview_box.clear() + return + self._refresh_run_file_path() + self.command_box.setPlainText( + f'clusterdynamics run "{project_dir}"\n' + f'saxshell clusterdynamics run "{project_dir}"' + ) + try: + config = self._current_config(project_dir) + payload = config.to_dict() + try: + payload["selection_preview"] = ( + preview_clusterdynamics_run_config( + project_dir=project_dir, + config=config, + ) + ) + except Exception as exc: + payload["selection_preview_error"] = str(exc) + except Exception as exc: + self.json_preview_box.setPlainText(str(exc)) + return + self.json_preview_box.setPlainText(save_preview_text(payload)) + + def _current_config(self, project_dir: Path): + frames_text = self.frames_dir_edit.text().strip() + if not frames_text: + raise ValueError("Choose a frames folder before saving.") + output_text = self.output_file_edit.text().strip() + energy_text = self.energy_file_edit.text().strip() + return build_clusterdynamics_run_config( + project_dir=project_dir, + frames_dir=frames_text, + output_file=output_text or None, + energy_file=energy_text or None, + atom_type_definitions=self.definitions_panel.atom_type_definitions(), + pair_cutoff_definitions=( + self.definitions_panel.pair_cutoff_definitions() + ), + box_dimensions=self.definitions_panel.box_dimensions(), + use_pbc=self.definitions_panel.use_pbc(), + default_cutoff=self.definitions_panel.default_cutoff(), + shell_levels=self.definitions_panel.shell_growth_levels(), + shared_shells=self.definitions_panel.shared_shells(), + include_shell_atoms_in_stoichiometry=( + self.definitions_panel.include_shell_atoms_in_stoichiometry() + ), + search_mode=self.definitions_panel.search_mode(), + folder_start_time_fs=self.time_panel.folder_start_time_fs(), + first_frame_time_fs=self.time_panel.first_frame_time_fs(), + frame_timestep_fs=self.time_panel.frame_timestep_fs(), + frames_per_colormap_timestep=( + self.time_panel.frames_per_colormap_timestep() + ), + 
def _summary_text(summary: dict[str, object]) -> str:
    """Render a frames-inspection summary as newline-separated text.

    The box-dimension line uses ``box_dimensions`` and falls back to
    ``estimated_box_dimensions`` when that is None. Its label says
    "Source" only when the source kind is ``source_filename``. A
    "Box source" line is appended when ``box_dimensions_source`` is set.
    """
    dimensions = summary.get("box_dimensions")
    if dimensions is None:
        dimensions = summary.get("estimated_box_dimensions")

    if summary.get("box_dimensions_source_kind") == "source_filename":
        heading = "Source box dimensions"
    else:
        heading = "Estimated box dimensions"

    report = [
        f"Frames folder: {summary.get('input_dir')}",
        f"Mode: {summary.get('mode_label')}",
        f"Frames: {summary.get('n_frames')}",
        f"{heading}: {format_box_dimensions(dimensions)}",
    ]
    box_source = summary.get("box_dimensions_source")
    if box_source is not None:
        report.append(f"Box source: {box_source}")
    return "\n".join(report)
ClusterDynamicsRunFileWindow: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication(sys.argv) + configure_saxshell_application(app) + window = ClusterDynamicsRunFileWindow( + initial_project_dir=initial_project_dir, + initial_frames_dir=initial_frames_dir, + initial_energy_file=initial_energy_file, + ) + window.show() + window.raise_() + return window + + +__all__ = [ + "ClusterDynamicsRunFileWindow", + "launch_clusterdynamics_run_file_ui", +] diff --git a/src/saxshell/clusterdynamicsml/__init__.py b/src/saxshell/clusterdynamicsml/__init__.py index c638865..d5a855d 100644 --- a/src/saxshell/clusterdynamicsml/__init__.py +++ b/src/saxshell/clusterdynamicsml/__init__.py @@ -6,6 +6,18 @@ load_cluster_dynamicsai_dataset, save_cluster_dynamicsai_dataset, ) +from .run_config import ( + ClusterDynamicsMLRunConfig, + ClusterDynamicsMLRunExecutionSummary, + build_clusterdynamicsml_run_config, + default_clusterdynamicsml_run_file_path, + load_clusterdynamicsml_run_config, + preview_clusterdynamicsml_run_config, + run_clusterdynamicsml_run_config, + save_clusterdynamicsml_run_config, + suggest_clusterdynamicsml_output_file, + workflow_from_clusterdynamicsml_run_config, +) from .workflow import ( ClusterDynamicsMLPreview, ClusterDynamicsMLResult, @@ -30,6 +42,16 @@ "SAXSComponentWeight", "LoadedClusterDynamicsMLDataset", "SavedClusterDynamicsMLDataset", + "ClusterDynamicsMLRunConfig", + "ClusterDynamicsMLRunExecutionSummary", + "build_clusterdynamicsml_run_config", + "default_clusterdynamicsml_run_file_path", "load_cluster_dynamicsai_dataset", + "load_clusterdynamicsml_run_config", + "preview_clusterdynamicsml_run_config", + "run_clusterdynamicsml_run_config", "save_cluster_dynamicsai_dataset", + "save_clusterdynamicsml_run_config", + "suggest_clusterdynamicsml_output_file", + "workflow_from_clusterdynamicsml_run_config", ] diff --git a/src/saxshell/clusterdynamicsml/cli.py b/src/saxshell/clusterdynamicsml/cli.py index 
9bf2e4f..cff71d2 100644 --- a/src/saxshell/clusterdynamicsml/cli.py +++ b/src/saxshell/clusterdynamicsml/cli.py @@ -1,5 +1,358 @@ from __future__ import annotations -from .ui.main_window import main +import argparse +import sys +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path -__all__ = ["main"] +from saxshell.version import __version__ + +from .run_config import ( + default_clusterdynamicsml_run_file_path, + load_clusterdynamicsml_run_config, + run_clusterdynamicsml_run_config, +) + +_COMMANDS = {"setup-ui", "ui", "run", "batch-run"} +_TOP_LEVEL_OPTIONS = {"-h", "--help", "--version"} + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="clusterdynamicsml", + description=( + "Predict larger-cluster stoichiometries, representative " + "structures, and cluster-only SAXS traces. Running without a " + "subcommand launches the Qt UI." + ), + ) + parser.add_argument( + "--version", + action="store_true", + help="Show the clusterdynamicsml version number and exit.", + ) + subparsers = parser.add_subparsers(dest="command") + + setup_ui_parser = subparsers.add_parser( + "setup-ui", + help="Launch the beta project-backed run-file setup interface.", + ) + setup_ui_parser.add_argument( + "project_dir", + nargs="?", + type=Path, + help="Optional SAXSShell project folder.", + ) + setup_ui_parser.add_argument( + "--frames-dir", + type=Path, + default=None, + help="Optional extracted frames folder to prefill.", + ) + setup_ui_parser.add_argument( + "--energy-file", + type=Path, + default=None, + help="Optional CP2K .ener file to prefill.", + ) + setup_ui_parser.add_argument( + "--clusters-dir", + type=Path, + default=None, + help="Optional smaller-cluster structure directory to prefill.", + ) + setup_ui_parser.add_argument( + "--experimental-data", + type=Path, + default=None, + help="Optional experimental SAXS data file to prefill.", + ) + setup_ui_parser.set_defaults(handler=_handle_setup_ui) + 
def main(argv: list[str] | None = None) -> int:
    """Entry point for the ``clusterdynamicsml`` command.

    Arguments default to ``sys.argv[1:]``. When there are none, or the
    first one is neither a known subcommand nor a top-level option, the
    call is routed to the legacy UI launcher. Otherwise the subcommand
    parser is used: ``--version`` prints the version and returns 0, and
    any exception raised by the selected handler ends the process through
    ``parser.exit`` with status 2.
    """
    cli_args = list(argv) if argv is not None else list(sys.argv[1:])
    leading = cli_args[0] if cli_args else None
    recognised = leading is not None and (
        leading in _COMMANDS or leading in _TOP_LEVEL_OPTIONS
    )
    if not recognised:
        return _handle_legacy_ui(cli_args)

    parser = build_parser()
    namespace = parser.parse_args(cli_args)
    if namespace.version:
        print(f"clusterdynamicsml {__version__}")
        return 0
    try:
        status = namespace.handler(namespace)
        return int(status)
    except Exception as exc:
        parser.exit(2, f"Error: {exc}\n")
def _handle_setup_ui(args: argparse.Namespace) -> int:
    """Open the run-file setup window for the ``setup-ui`` subcommand.

    Qt is imported lazily so the non-GUI subcommands do not need it. When
    no ``QApplication`` existed before the launch, this function runs the
    event loop and returns its exit code; otherwise it returns 0 without
    starting a loop.
    """
    from PySide6.QtWidgets import QApplication

    from .ui.run_file_window import launch_clusterdynamicsml_run_file_ui

    owns_app = QApplication.instance() is None
    # Keep the launcher's return value referenced while the event loop
    # runs. NOTE(review): assumes the launcher returns the window object,
    # as the sibling clusterdynamics launcher does; an unreferenced
    # top-level window can be garbage-collected before it is shown.
    window = launch_clusterdynamicsml_run_file_ui(  # noqa: F841
        initial_project_dir=getattr(args, "project_dir", None),
        initial_frames_dir=getattr(args, "frames_dir", None),
        initial_energy_file=getattr(args, "energy_file", None),
        initial_clusters_dir=getattr(args, "clusters_dir", None),
        initial_experimental_data_file=getattr(
            args, "experimental_data", None
        ),
    )
    app = QApplication.instance()
    if owns_app and app is not None:
        return app.exec()
    return 0
_handle_run(args: argparse.Namespace) -> int: + project_dir = Path(args.project_dir).expanduser().resolve() + run_file = _resolve_run_file(project_dir, args.run_file) + config = load_clusterdynamicsml_run_config(run_file) + summary = run_clusterdynamicsml_run_config( + project_dir, + config, + run_file_path=run_file, + log_callback=print, + ) + _print_summary(summary) + return 0 + + +def _handle_batch_run(args: argparse.Namespace) -> int: + workers = max(int(getattr(args, "workers", 1)), 1) + project_dirs = [ + Path(project_dir_value).expanduser().resolve() + for project_dir_value in args.project_dirs + ] + if workers > 1: + return _handle_parallel_batch_run( + project_dirs, + workers=workers, + keep_going=bool(args.keep_going), + ) + + failures: list[tuple[Path, str]] = [] + for project_dir in project_dirs: + try: + run_file = default_clusterdynamicsml_run_file_path(project_dir) + config = load_clusterdynamicsml_run_config(run_file) + summary = run_clusterdynamicsml_run_config( + project_dir, + config, + run_file_path=run_file, + log_callback=print, + ) + _print_summary(summary) + except Exception as exc: + failures.append((project_dir, str(exc))) + print(f"FAILED {project_dir}: {exc}") + if not bool(args.keep_going): + break + if failures: + print("") + print( + "Cluster dynamics ML batch completed with " + f"{len(failures)} failure(s)." 
+ ) + return 1 + print("") + print("Cluster dynamics ML batch complete") + return 0 + + +def _handle_parallel_batch_run( + project_dirs: list[Path], + *, + workers: int, + keep_going: bool, +) -> int: + failures: list[tuple[Path, str]] = [] + with ThreadPoolExecutor(max_workers=workers) as executor: + future_to_project = { + executor.submit( + _run_project_collecting_logs, + project_dir, + ): project_dir + for project_dir in project_dirs + } + for future in as_completed(future_to_project): + project_dir = future_to_project[future] + try: + summary, log_lines = future.result() + except Exception as exc: + failures.append((project_dir, str(exc))) + print(f"FAILED {project_dir}: {exc}") + if not keep_going: + for pending in future_to_project: + if pending is not future: + pending.cancel() + break + continue + for line in log_lines: + print(line) + _print_summary(summary) + if failures: + print("") + print( + "Cluster dynamics ML batch completed with " + f"{len(failures)} failure(s)." + ) + return 1 + print("") + print("Cluster dynamics ML batch complete") + return 0 + + +def _run_project_collecting_logs(project_dir: Path): + log_lines: list[str] = [] + + def log(message: str) -> None: + log_lines.append(f"[{project_dir.name}] {message}") + + run_file = default_clusterdynamicsml_run_file_path(project_dir) + config = load_clusterdynamicsml_run_config(run_file) + summary = run_clusterdynamicsml_run_config( + project_dir, + config, + run_file_path=run_file, + log_callback=log, + ) + return summary, log_lines + + +def _resolve_run_file(project_dir: Path, run_file: Path | None) -> Path: + if run_file is None: + return default_clusterdynamicsml_run_file_path(project_dir) + return Path(run_file).expanduser().resolve() + + +def _print_summary(summary) -> None: + print("") + print("Cluster dynamics ML CLI run complete") + print(f"Frames folder: {summary.frames_dir}") + print(f"Output dataset: {summary.output_file}") + print(f"Frames analyzed: 
{summary.result.dynamics_result.analyzed_frames}") + print(f"Time bins: {summary.result.dynamics_result.bin_count}") + print( + f"Training observations: {len(summary.result.training_observations)}" + ) + print(f"Predictions: {len(summary.result.predictions)}") + print(f"Files written: {summary.written_count}") + if summary.project_file is not None: + print(f"Project file: {summary.project_file}") + + +__all__ = ["build_parser", "main"] diff --git a/src/saxshell/clusterdynamicsml/run_config.py b/src/saxshell/clusterdynamicsml/run_config.py new file mode 100644 index 0000000..8db0c73 --- /dev/null +++ b/src/saxshell/clusterdynamicsml/run_config.py @@ -0,0 +1,628 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Callable + +from saxshell.cluster import ( + SEARCH_MODE_KDTREE, + PairCutoffDefinitions, + PDBShellReferenceDefinition, + normalize_pair_cutoffs, + normalize_search_mode, +) +from saxshell.cluster.workflow import ( + example_atom_type_definitions, + example_pair_cutoff_definitions, +) +from saxshell.clusterdynamics.run_config import ( + coerce_atom_type_definitions, + coerce_box_dimensions, + coerce_int_tuple, + coerce_pair_cutoff_definitions, + optional_float, + optional_positive_float, + optional_text, + path_text_for_run_config, + resolve_run_config_path, + serialize_atom_type_definitions, + serialize_pair_cutoff_definitions, +) +from saxshell.saxs.project_manager import ( + SAXSProjectManager, + build_project_paths, +) +from saxshell.structure import ( + AtomTypeDefinitions, + normalize_atom_type_definitions, +) + +from .dataset import ( + SavedClusterDynamicsMLDataset, + save_cluster_dynamicsai_dataset, +) +from .workflow import ClusterDynamicsMLResult, ClusterDynamicsMLWorkflow + +DEFAULT_RUN_FILE_NAME = "cluster_dynamics_ml_cli_run.json" +RUN_CONFIG_VERSION = 1 +ClusterDynamicsMLRunLogCallback = Callable[[str], None] + + 
+@dataclass(slots=True) +class ClusterDynamicsMLRunConfig: + frames_dir: str + output_file: str | None + clusters_dir: str | None = None + project_dir: str | None = None + experimental_data_file: str | None = None + energy_file: str | None = None + atom_type_definitions: AtomTypeDefinitions = field( + default_factory=example_atom_type_definitions + ) + pair_cutoff_definitions: PairCutoffDefinitions = field( + default_factory=example_pair_cutoff_definitions + ) + box_dimensions: tuple[float, float, float] | None = None + use_pbc: bool = False + default_cutoff: float | None = None + shell_levels: tuple[int, ...] = () + shared_shells: bool = False + include_shell_atoms_in_stoichiometry: bool = False + search_mode: str = SEARCH_MODE_KDTREE + shell_reference_definitions: tuple[PDBShellReferenceDefinition, ...] = () + folder_start_time_fs: float | None = None + first_frame_time_fs: float = 0.0 + frame_timestep_fs: float = 0.5 + frames_per_colormap_timestep: int = 1 + analysis_start_fs: float | None = None + analysis_stop_fs: float | None = None + target_node_counts: tuple[int, ...] 
= (4, 5) + candidates_per_size: int = 3 + prediction_population_share_threshold: float = 0.02 + q_min: float | None = 0.02 + q_max: float | None = 1.20 + q_points: int = 250 + created_at: str = field( + default_factory=lambda: datetime.now().isoformat(timespec="seconds") + ) + + def to_dict(self) -> dict[str, object]: + return { + "version": RUN_CONFIG_VERSION, + "created_at": self.created_at, + "frames_dir": self.frames_dir, + "output_file": self.output_file, + "clusters_dir": self.clusters_dir, + "project_dir": self.project_dir, + "experimental_data_file": self.experimental_data_file, + "energy_file": self.energy_file, + "atom_type_definitions": serialize_atom_type_definitions( + self.atom_type_definitions + ), + "pair_cutoff_definitions": serialize_pair_cutoff_definitions( + self.pair_cutoff_definitions + ), + "box_dimensions": self.box_dimensions, + "use_pbc": bool(self.use_pbc), + "default_cutoff": self.default_cutoff, + "shell_levels": [int(level) for level in self.shell_levels], + "shared_shells": bool(self.shared_shells), + "include_shell_atoms_in_stoichiometry": bool( + self.include_shell_atoms_in_stoichiometry + ), + "search_mode": normalize_search_mode(self.search_mode), + "shell_reference_definitions": [ + serialize_shell_reference_definition(definition) + for definition in self.shell_reference_definitions + ], + "folder_start_time_fs": self.folder_start_time_fs, + "first_frame_time_fs": float(self.first_frame_time_fs), + "frame_timestep_fs": float(self.frame_timestep_fs), + "frames_per_colormap_timestep": int( + self.frames_per_colormap_timestep + ), + "analysis_start_fs": self.analysis_start_fs, + "analysis_stop_fs": self.analysis_stop_fs, + "target_node_counts": [ + int(value) for value in self.target_node_counts + ], + "candidates_per_size": int(self.candidates_per_size), + "prediction_population_share_threshold": float( + self.prediction_population_share_threshold + ), + "q_min": self.q_min, + "q_max": self.q_max, + "q_points": int(self.q_points), 
+ } + + @classmethod + def from_dict( + cls, + payload: dict[str, object], + ) -> "ClusterDynamicsMLRunConfig": + frames_dir = str(payload.get("frames_dir", "")).strip() + if not frames_dir: + raise ValueError( + "Cluster dynamics ML run file is missing frames_dir." + ) + return cls( + frames_dir=frames_dir, + output_file=optional_text(payload.get("output_file")), + clusters_dir=optional_text(payload.get("clusters_dir")), + project_dir=optional_text(payload.get("project_dir")), + experimental_data_file=optional_text( + payload.get("experimental_data_file") + ), + energy_file=optional_text(payload.get("energy_file")), + atom_type_definitions=coerce_atom_type_definitions( + payload.get("atom_type_definitions") + ), + pair_cutoff_definitions=coerce_pair_cutoff_definitions( + payload.get("pair_cutoff_definitions") + ), + box_dimensions=coerce_box_dimensions( + payload.get("box_dimensions") + ), + use_pbc=bool(payload.get("use_pbc", False)), + default_cutoff=optional_positive_float( + payload.get("default_cutoff") + ), + shell_levels=coerce_int_tuple(payload.get("shell_levels")), + shared_shells=bool(payload.get("shared_shells", False)), + include_shell_atoms_in_stoichiometry=bool( + payload.get("include_shell_atoms_in_stoichiometry", False) + ), + search_mode=normalize_search_mode( + str(payload.get("search_mode", SEARCH_MODE_KDTREE)) + ), + shell_reference_definitions=coerce_shell_reference_definitions( + payload.get("shell_reference_definitions") + ), + folder_start_time_fs=optional_float( + payload.get("folder_start_time_fs") + ), + first_frame_time_fs=float(payload.get("first_frame_time_fs", 0.0)), + frame_timestep_fs=float(payload.get("frame_timestep_fs", 0.5)), + frames_per_colormap_timestep=max( + int(payload.get("frames_per_colormap_timestep", 1)), + 1, + ), + analysis_start_fs=optional_float(payload.get("analysis_start_fs")), + analysis_stop_fs=optional_float(payload.get("analysis_stop_fs")), + target_node_counts=coerce_int_tuple( + 
payload.get("target_node_counts") + ) + or (4, 5), + candidates_per_size=max( + int(payload.get("candidates_per_size", 3)), 1 + ), + prediction_population_share_threshold=max( + float( + payload.get("prediction_population_share_threshold", 0.02) + ), + 0.0, + ), + q_min=optional_float(payload.get("q_min")), + q_max=optional_float(payload.get("q_max")), + q_points=max(int(payload.get("q_points", 250)), 10), + created_at=str(payload.get("created_at", "")).strip() + or datetime.now().isoformat(timespec="seconds"), + ) + + +@dataclass(slots=True, frozen=True) +class ClusterDynamicsMLRunExecutionSummary: + project_dir: Path + run_file_path: Path | None + frames_dir: Path + output_file: Path + result: ClusterDynamicsMLResult + saved_dataset: SavedClusterDynamicsMLDataset + project_file: Path | None + + @property + def written_count(self) -> int: + return len(self.saved_dataset.written_files) + + +def default_clusterdynamicsml_run_file_path(project_dir: str | Path) -> Path: + return Path(project_dir).expanduser().resolve() / DEFAULT_RUN_FILE_NAME + + +def save_clusterdynamicsml_run_config( + output_path: str | Path, + config: ClusterDynamicsMLRunConfig, +) -> Path: + path = Path(output_path).expanduser().resolve() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(config.to_dict(), indent=2) + "\n", + encoding="utf-8", + ) + return path + + +def load_clusterdynamicsml_run_config( + run_file_path: str | Path, +) -> ClusterDynamicsMLRunConfig: + path = Path(run_file_path).expanduser().resolve() + payload = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + raise ValueError( + "Cluster dynamics ML run file must contain a JSON object: " + f"{path}" + ) + return ClusterDynamicsMLRunConfig.from_dict(payload) + + +def build_clusterdynamicsml_run_config( + *, + project_dir: str | Path, + frames_dir: str | Path, + output_file: str | Path | None, + clusters_dir: str | Path | None = None, + experimental_data_file: str | 
Path | None = None, + energy_file: str | Path | None = None, + atom_type_definitions: AtomTypeDefinitions, + pair_cutoff_definitions: PairCutoffDefinitions, + box_dimensions: tuple[float, float, float] | None = None, + use_pbc: bool = False, + default_cutoff: float | None = None, + shell_levels: tuple[int, ...] = (), + shared_shells: bool = False, + include_shell_atoms_in_stoichiometry: bool = False, + search_mode: str = SEARCH_MODE_KDTREE, + shell_reference_definitions: tuple[PDBShellReferenceDefinition, ...] = (), + folder_start_time_fs: float | None = None, + first_frame_time_fs: float = 0.0, + frame_timestep_fs: float = 0.5, + frames_per_colormap_timestep: int = 1, + analysis_start_fs: float | None = None, + analysis_stop_fs: float | None = None, + target_node_counts: tuple[int, ...] = (4, 5), + candidates_per_size: int = 3, + prediction_population_share_threshold: float = 0.02, + q_min: float | None = 0.02, + q_max: float | None = 1.20, + q_points: int = 250, +) -> ClusterDynamicsMLRunConfig: + return ClusterDynamicsMLRunConfig( + frames_dir=path_text_for_run_config( + frames_dir, + project_dir=project_dir, + ) + or "", + output_file=path_text_for_run_config( + output_file, + project_dir=project_dir, + ), + clusters_dir=path_text_for_run_config( + clusters_dir, + project_dir=project_dir, + ), + project_dir=path_text_for_run_config( + project_dir, + project_dir=project_dir, + ), + experimental_data_file=path_text_for_run_config( + experimental_data_file, + project_dir=project_dir, + ), + energy_file=path_text_for_run_config( + energy_file, + project_dir=project_dir, + ), + atom_type_definitions=normalize_atom_type_definitions( + atom_type_definitions + ), + pair_cutoff_definitions=normalize_pair_cutoffs( + pair_cutoff_definitions + ), + box_dimensions=box_dimensions, + use_pbc=bool(use_pbc), + default_cutoff=default_cutoff, + shell_levels=tuple(sorted({int(level) for level in shell_levels})), + shared_shells=bool(shared_shells), + 
include_shell_atoms_in_stoichiometry=bool( + include_shell_atoms_in_stoichiometry + ), + search_mode=normalize_search_mode(search_mode), + shell_reference_definitions=tuple(shell_reference_definitions), + folder_start_time_fs=folder_start_time_fs, + first_frame_time_fs=float(first_frame_time_fs), + frame_timestep_fs=float(frame_timestep_fs), + frames_per_colormap_timestep=max( + int(frames_per_colormap_timestep), + 1, + ), + analysis_start_fs=analysis_start_fs, + analysis_stop_fs=analysis_stop_fs, + target_node_counts=tuple( + sorted({int(value) for value in target_node_counts}) + ) + or (4, 5), + candidates_per_size=max(int(candidates_per_size), 1), + prediction_population_share_threshold=max( + float(prediction_population_share_threshold), + 0.0, + ), + q_min=q_min, + q_max=q_max, + q_points=max(int(q_points), 10), + ) + + +def suggest_clusterdynamicsml_output_file( + *, + project_dir: str | Path, + frames_dir: str | Path | None, +) -> Path: + resolved_project_dir = Path(project_dir).expanduser().resolve() + dataset_dir = ( + build_project_paths(resolved_project_dir).exported_data_dir + / "clusterdynamicsml" + ) + folder_label = "cluster_dynamics_ml" + if frames_dir is not None: + frames_path = Path(frames_dir).expanduser() + if frames_path.name: + folder_label = frames_path.name + return dataset_dir / f"{folder_label}_clusterdynamicsml.json" + + +def workflow_from_clusterdynamicsml_run_config( + *, + project_dir: str | Path, + config: ClusterDynamicsMLRunConfig, +) -> ClusterDynamicsMLWorkflow: + resolved_project_dir = Path(project_dir).expanduser().resolve() + frames_dir = resolve_run_config_path( + config.frames_dir, + project_dir=resolved_project_dir, + ) + if frames_dir is None: + raise ValueError("Cluster dynamics ML run file is missing frames_dir.") + return ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=config.atom_type_definitions, + pair_cutoff_definitions=config.pair_cutoff_definitions, + clusters_dir=resolve_run_config_path( + 
config.clusters_dir, + project_dir=resolved_project_dir, + ), + project_dir=resolve_run_config_path( + config.project_dir, + project_dir=resolved_project_dir, + ) + or resolved_project_dir, + experimental_data_file=resolve_run_config_path( + config.experimental_data_file, + project_dir=resolved_project_dir, + ), + box_dimensions=config.box_dimensions, + use_pbc=config.use_pbc, + default_cutoff=config.default_cutoff, + shell_levels=config.shell_levels, + shared_shells=config.shared_shells, + include_shell_atoms_in_stoichiometry=( + config.include_shell_atoms_in_stoichiometry + ), + search_mode=config.search_mode, + pdb_shell_reference_definitions=config.shell_reference_definitions, + folder_start_time_fs=config.folder_start_time_fs, + first_frame_time_fs=config.first_frame_time_fs, + frame_timestep_fs=config.frame_timestep_fs, + frames_per_colormap_timestep=config.frames_per_colormap_timestep, + analysis_start_fs=config.analysis_start_fs, + analysis_stop_fs=config.analysis_stop_fs, + energy_file=resolve_run_config_path( + config.energy_file, + project_dir=resolved_project_dir, + ), + target_node_counts=config.target_node_counts, + candidates_per_size=config.candidates_per_size, + prediction_population_share_threshold=( + config.prediction_population_share_threshold + ), + q_min=config.q_min, + q_max=config.q_max, + q_points=config.q_points, + ) + + +def preview_clusterdynamicsml_run_config( + *, + project_dir: str | Path, + config: ClusterDynamicsMLRunConfig, +) -> dict[str, object]: + workflow = workflow_from_clusterdynamicsml_run_config( + project_dir=project_dir, + config=config, + ) + preview = workflow.preview_selection() + return { + "dynamics": preview.dynamics_preview.to_dict(), + "clusters_dir": ( + None if preview.clusters_dir is None else str(preview.clusters_dir) + ), + "project_dir": ( + None if preview.project_dir is None else str(preview.project_dir) + ), + "experimental_data_path": ( + None + if preview.experimental_data_path is None + else 
str(preview.experimental_data_path) + ), + "structure_label_count": int(preview.structure_label_count), + "total_structure_files": int(preview.total_structure_files), + "observed_node_counts": list(preview.observed_node_counts), + "target_node_counts": list(preview.target_node_counts), + "warnings": list(preview.warnings), + } + + +def run_clusterdynamicsml_run_config( + project_dir: str | Path, + config: ClusterDynamicsMLRunConfig, + *, + run_file_path: str | Path | None = None, + log_callback: ClusterDynamicsMLRunLogCallback | None = None, +) -> ClusterDynamicsMLRunExecutionSummary: + resolved_project_dir = Path(project_dir).expanduser().resolve() + frames_dir = resolve_run_config_path( + config.frames_dir, + project_dir=resolved_project_dir, + ) + if frames_dir is None: + raise ValueError("Cluster dynamics ML run file is missing frames_dir.") + output_file = resolve_run_config_path( + config.output_file, + project_dir=resolved_project_dir, + ) + if output_file is None: + output_file = suggest_clusterdynamicsml_output_file( + project_dir=resolved_project_dir, + frames_dir=frames_dir, + ) + workflow = workflow_from_clusterdynamicsml_run_config( + project_dir=resolved_project_dir, + config=config, + ) + _emit_log(log_callback, f"Frames folder: {frames_dir}") + _emit_log(log_callback, f"Output dataset: {output_file}") + result = workflow.analyze(progress_callback=log_callback) + saved = save_cluster_dynamicsai_dataset( + result, + output_file, + analysis_settings=config.to_dict(), + ) + project_file = _register_project_references( + resolved_project_dir, + frames_dir=frames_dir, + clusters_dir=resolve_run_config_path( + config.clusters_dir, + project_dir=resolved_project_dir, + ), + energy_file=resolve_run_config_path( + config.energy_file, + project_dir=resolved_project_dir, + ), + experimental_data_file=resolve_run_config_path( + config.experimental_data_file, + project_dir=resolved_project_dir, + ), + ) + _emit_log( + log_callback, + f"Saved cluster-dynamics ML 
dataset: {saved.dataset_file}", + ) + return ClusterDynamicsMLRunExecutionSummary( + project_dir=resolved_project_dir, + run_file_path=( + None if run_file_path is None else Path(run_file_path).resolve() + ), + frames_dir=frames_dir, + output_file=saved.dataset_file, + result=result, + saved_dataset=saved, + project_file=project_file, + ) + + +def serialize_shell_reference_definition( + definition: PDBShellReferenceDefinition, +) -> dict[str, str | None]: + return { + "shell_element": definition.shell_element, + "shell_residue": definition.shell_residue, + "reference_name": definition.reference_name, + "backbone_atom1_name": definition.backbone_atom1_name, + "backbone_atom2_name": definition.backbone_atom2_name, + } + + +def coerce_shell_reference_definitions( + value: object, +) -> tuple[PDBShellReferenceDefinition, ...]: + if not isinstance(value, list): + return () + definitions: list[PDBShellReferenceDefinition] = [] + for entry in value: + if not isinstance(entry, dict): + continue + shell_element = str(entry.get("shell_element", "")).strip().title() + reference_name = str(entry.get("reference_name", "")).strip() + if not shell_element or not reference_name: + continue + definitions.append( + PDBShellReferenceDefinition( + shell_element=shell_element, + shell_residue=optional_text(entry.get("shell_residue")), + reference_name=reference_name, + backbone_atom1_name=optional_text( + entry.get("backbone_atom1_name") + ), + backbone_atom2_name=optional_text( + entry.get("backbone_atom2_name") + ), + ) + ) + return tuple(definitions) + + +def _register_project_references( + project_dir: Path, + *, + frames_dir: Path, + clusters_dir: Path | None, + energy_file: Path | None, + experimental_data_file: Path | None, +) -> Path | None: + project_file = build_project_paths(project_dir).project_file + if not project_file.is_file(): + return None + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.frames_dir = 
str(Path(frames_dir).expanduser().resolve()) + settings.clusters_dir = ( + None + if clusters_dir is None + else str(Path(clusters_dir).expanduser().resolve()) + ) + settings.energy_file = ( + None + if energy_file is None + else str(Path(energy_file).expanduser().resolve()) + ) + if experimental_data_file is not None: + settings.experimental_data_path = str( + Path(experimental_data_file).expanduser().resolve() + ) + return manager.save_project(settings) + + +def _emit_log( + callback: ClusterDynamicsMLRunLogCallback | None, + message: str, +) -> None: + if callback is not None: + callback(str(message).strip()) + + +__all__ = [ + "DEFAULT_RUN_FILE_NAME", + "ClusterDynamicsMLRunConfig", + "ClusterDynamicsMLRunExecutionSummary", + "build_clusterdynamicsml_run_config", + "coerce_shell_reference_definitions", + "default_clusterdynamicsml_run_file_path", + "load_clusterdynamicsml_run_config", + "preview_clusterdynamicsml_run_config", + "run_clusterdynamicsml_run_config", + "save_clusterdynamicsml_run_config", + "serialize_shell_reference_definition", + "suggest_clusterdynamicsml_output_file", + "workflow_from_clusterdynamicsml_run_config", +] diff --git a/src/saxshell/clusterdynamicsml/ui/__init__.py b/src/saxshell/clusterdynamicsml/ui/__init__.py index c53ce12..3e8dcf4 100644 --- a/src/saxshell/clusterdynamicsml/ui/__init__.py +++ b/src/saxshell/clusterdynamicsml/ui/__init__.py @@ -2,5 +2,14 @@ ClusterDynamicsMLMainWindow, launch_clusterdynamicsml_ui, ) +from .run_file_window import ( + ClusterDynamicsMLRunFileWindow, + launch_clusterdynamicsml_run_file_ui, +) -__all__ = ["ClusterDynamicsMLMainWindow", "launch_clusterdynamicsml_ui"] +__all__ = [ + "ClusterDynamicsMLMainWindow", + "ClusterDynamicsMLRunFileWindow", + "launch_clusterdynamicsml_run_file_ui", + "launch_clusterdynamicsml_ui", +] diff --git a/src/saxshell/clusterdynamicsml/ui/main_window.py b/src/saxshell/clusterdynamicsml/ui/main_window.py index c834960..a11316f 100644 --- 
a/src/saxshell/clusterdynamicsml/ui/main_window.py +++ b/src/saxshell/clusterdynamicsml/ui/main_window.py @@ -90,6 +90,7 @@ ) from .plot_panel import ( ClusterDynamicsMLHistogramPanel, + ClusterDynamicsMLLifetimeDistributionWindow, ClusterDynamicsMLPlotPanel, ) @@ -674,6 +675,9 @@ def __init__( self._active_job_config: ClusterDynamicsMLJobConfig | None = None self._active_job_preview: ClusterDynamicsMLPreview | None = None self._auto_detected_energy_file: Path | None = None + self._lifetime_distribution_window: ( + ClusterDynamicsMLLifetimeDistributionWindow | None + ) = None self._history_panel_expanded = True self._history_expanded_splitter_size = _HISTORY_EXPANDED_DEFAULT_HEIGHT self._suspend_preview_refresh = False @@ -744,6 +748,9 @@ def closeEvent(self, event) -> None: ) event.ignore() return + if self._lifetime_distribution_window is not None: + self._lifetime_distribution_window.close() + self._lifetime_distribution_window = None app = QApplication.instance() if self._app_event_filter_installed and app is not None: app.removeEventFilter(self) @@ -974,6 +981,26 @@ def _build_ui(self) -> None: history_button_row.addStretch(1) history_content_layout.addLayout(history_button_row) history_layout.addWidget(self.history_content, stretch=1) + self.lifetime_tab = QWidget() + lifetime_layout = QVBoxLayout(self.lifetime_tab) + lifetime_layout.setContentsMargins(0, 0, 0, 0) + lifetime_layout.setSpacing(8) + lifetime_button_row = QHBoxLayout() + lifetime_button_row.setContentsMargins(0, 0, 0, 0) + self.lifetime_distribution_button = QPushButton( + "Plot Lifetime Distributions" + ) + self.lifetime_distribution_button.setToolTip( + "Open completed cluster lifetime histograms for each observed " + "stoichiometry, with spread metrics and Lorentzian/Cauchy shape " + "diagnostics." 
+ ) + self.lifetime_distribution_button.clicked.connect( + self.open_lifetime_distribution_window + ) + lifetime_button_row.addWidget(self.lifetime_distribution_button) + lifetime_button_row.addStretch(1) + lifetime_layout.addLayout(lifetime_button_row) self.lifetime_table = self._build_table( ( "Type", @@ -996,6 +1023,7 @@ def _build_ui(self) -> None: "Notes", ) ) + lifetime_layout.addWidget(self.lifetime_table, stretch=1) self.debye_waller_table = self._build_table( ( "Type", @@ -1018,7 +1046,7 @@ def _build_ui(self) -> None: self.combined_histogram_panel = self.histogram_panel self.predicted_structures_plot_panel = self.saxs_panel self.results_tabs.addTab(self.summary_tab, "Summary") - self.results_tabs.addTab(self.lifetime_table, "Lifetimes") + self.results_tabs.addTab(self.lifetime_tab, "Lifetimes") self.results_tabs.addTab(self.debye_waller_table, "Debye-Waller") self.results_tabs.addTab(self.histogram_panel, "Histograms") self.results_tabs.addTab(self.saxs_panel, "SAXS") @@ -1093,6 +1121,7 @@ def _build_ui(self) -> None: "representative structures." 
) self._set_frame_format(None) + self._set_lifetime_distribution_result(None) self._update_history_controls() def _load_shell_reference_library_entries(self) -> None: @@ -1152,6 +1181,7 @@ def run_analysis(self) -> None: self.dynamics_plot_panel.set_result(None) self.histogram_panel.set_result(None) self.saxs_panel.set_result(None) + self._set_lifetime_distribution_result(None) self.summary_box.clear() self.lifetime_table.setRowCount(0) self.debye_waller_table.setRowCount(0) @@ -1410,6 +1440,7 @@ def _on_run_finished(self, result: ClusterDynamicsMLResult) -> None: self.dynamics_plot_panel.set_result(result.dynamics_result) self.histogram_panel.set_result(result) self.saxs_panel.set_result(result) + self._set_lifetime_distribution_result(result) self.run_panel.progress_bar.setRange( 0, max(result.dynamics_result.analyzed_frames, 1) ) @@ -1603,6 +1634,30 @@ def save_lifetime_table(self) -> None: ) self.statusBar().showMessage(f"Saved lifetime table to {saved_path}") + def open_lifetime_distribution_window(self) -> None: + if self._last_result is None: + self._show_error( + "Run an analysis or load a saved result before plotting " + "lifetime distributions." 
+ ) + return + if self._lifetime_distribution_window is None: + self._lifetime_distribution_window = ( + ClusterDynamicsMLLifetimeDistributionWindow(parent=self) + ) + self._lifetime_distribution_window.set_result(self._last_result) + self._lifetime_distribution_window.show() + self._lifetime_distribution_window.raise_() + self._lifetime_distribution_window.activateWindow() + + def _set_lifetime_distribution_result( + self, + result: ClusterDynamicsMLResult | None, + ) -> None: + self.lifetime_distribution_button.setEnabled(result is not None) + if self._lifetime_distribution_window is not None: + self._lifetime_distribution_window.set_result(result) + def save_powerpoint_report(self) -> None: if self._last_result is None: self._show_error( @@ -1805,6 +1860,7 @@ def _on_frames_dir_changed(self, frames_dir: Path | None) -> None: self.dynamics_plot_panel.set_result(None) self.histogram_panel.set_result(None) self.saxs_panel.set_result(None) + self._set_lifetime_distribution_result(None) self.summary_box.clear() self.lifetime_table.setRowCount(0) self.debye_waller_table.setRowCount(0) @@ -2926,6 +2982,7 @@ def _apply_loaded_dataset( self.dynamics_plot_panel.set_result(loaded.result.dynamics_result) self.histogram_panel.set_result(loaded.result) self.saxs_panel.set_result(loaded.result) + self._set_lifetime_distribution_result(loaded.result) self.run_panel.set_selection_summary( self._format_preview_text(loaded.result.preview) ) diff --git a/src/saxshell/clusterdynamicsml/ui/plot_panel.py b/src/saxshell/clusterdynamicsml/ui/plot_panel.py index 7b86f04..663b3ea 100644 --- a/src/saxshell/clusterdynamicsml/ui/plot_panel.py +++ b/src/saxshell/clusterdynamicsml/ui/plot_panel.py @@ -1,6 +1,8 @@ from __future__ import annotations +import math from collections import Counter +from dataclasses import dataclass from pathlib import Path import numpy as np @@ -15,12 +17,16 @@ QComboBox, QHBoxLayout, QLabel, + QMainWindow, QPushButton, + QSpinBox, + QTextEdit, QVBoxLayout, QWidget, 
) from saxshell.cluster.clusternetwork import stoichiometry_label +from saxshell.clusterdynamics.workflow import _summarize_series_lifetimes from saxshell.clusterdynamicsml.workflow import ( ClusterDynamicsMLResult, _resolved_population_weights, @@ -31,7 +37,10 @@ list_secondary_filter_elements, plot_md_prior_histogram, ) -from saxshell.saxs.stoichiometry import parse_stoich_label +from saxshell.saxs.stoichiometry import ( + format_stoich_for_axis, + parse_stoich_label, +) _EXPERIMENTAL_COLOR = "#111111" _OBSERVED_MODEL_COLOR = "#1f77b4" @@ -44,6 +53,29 @@ ("Solvent Sort - Atom Fraction", "solvent_sort_atom_fraction"), ) _STRUCTURE_FILE_SUFFIXES = {".xyz", ".pdb"} +_COMPLETED_LIFETIME_COLOR = "#2e86ab" +_TRUNCATED_LIFETIME_COLOR = "#d95f02" +_LORENTZIAN_COLOR = "#7b3294" + + +@dataclass(slots=True) +class ClusterLifetimeDistribution: + label: str + cluster_size: int + completed_lifetimes_fs: np.ndarray + window_truncated_lifetimes_fs: np.ndarray + mean_lifetime_fs: float | None + std_lifetime_fs: float | None + completed_lifetime_count: int + window_truncated_lifetime_count: int + + +@dataclass(slots=True) +class _LorentzianFit: + center: float + gamma: float + r_squared: float | None + interpretation: str class ClusterDynamicsMLHistogramPanel(QWidget): @@ -268,6 +300,217 @@ def _update_population_control_state(self) -> None: self.population_combo.setVisible(visible) +class ClusterDynamicsMLLifetimeDistributionPanel(QWidget): + """Plot completed lifetime distributions for each observed + stoichiometry.""" + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self._result: ClusterDynamicsMLResult | None = None + self._distributions: tuple[ClusterLifetimeDistribution, ...] 
= () + self._build_ui() + self.refresh_plot() + + def _build_ui(self) -> None: + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(8) + + controls_widget = QWidget() + controls = QHBoxLayout(controls_widget) + controls.setContentsMargins(0, 0, 0, 0) + controls.setSpacing(8) + + controls.addWidget(QLabel("Units")) + self.unit_combo = QComboBox() + self.unit_combo.addItem("fs", "fs") + self.unit_combo.addItem("ps", "ps") + self.unit_combo.currentIndexChanged.connect( + lambda _index: self.refresh_plot() + ) + controls.addWidget(self.unit_combo) + + controls.addWidget(QLabel("Bins")) + self.bin_spin = QSpinBox() + self.bin_spin.setRange(1, 80) + self.bin_spin.setValue(12) + self.bin_spin.valueChanged.connect(lambda _value: self.refresh_plot()) + controls.addWidget(self.bin_spin) + + self.include_truncated_checkbox = QCheckBox( + "Include window-truncated lifetimes" + ) + self.include_truncated_checkbox.setToolTip( + "Window-truncated lifetimes are right- or left-censored by the " + "selected time window. They are useful for seeing persistence, " + "but can bias the distribution shape." 
+ ) + self.include_truncated_checkbox.toggled.connect( + lambda _checked: self.refresh_plot() + ) + controls.addWidget(self.include_truncated_checkbox) + controls.addStretch(1) + layout.addWidget(controls_widget) + + self.figure = Figure(figsize=(10.8, 7.2)) + self.canvas = FigureCanvas(self.figure) + layout.addWidget(NavigationToolbar(self.canvas, self)) + layout.addWidget(self.canvas, stretch=1) + + self.summary_box = QTextEdit() + self.summary_box.setReadOnly(True) + self.summary_box.setMaximumHeight(170) + self.summary_box.setMinimumHeight(96) + layout.addWidget(self.summary_box) + + def set_result(self, result: ClusterDynamicsMLResult | None) -> None: + self._result = result + self._distributions = ( + () + if result is None + else build_cluster_lifetime_distributions(result) + ) + self.refresh_plot() + + def refresh_plot(self) -> None: + self.figure.clear() + if self._result is None: + axis = self.figure.add_subplot(111) + axis.text( + 0.5, + 0.5, + "Run the prediction workflow or open a saved result to plot\n" + "cluster lifetime distributions by stoichiometry.", + ha="center", + va="center", + transform=axis.transAxes, + ) + axis.set_axis_off() + self.summary_box.setPlainText("") + self.canvas.draw_idle() + return + + if not self._distributions: + axis = self.figure.add_subplot(111) + axis.text( + 0.5, + 0.5, + "No lifetime distributions are available for this result.", + ha="center", + va="center", + transform=axis.transAxes, + ) + axis.set_axis_off() + self.summary_box.setPlainText("") + self.canvas.draw_idle() + return + + unit = str(self.unit_combo.currentData() or "fs") + include_truncated = bool(self.include_truncated_checkbox.isChecked()) + requested_bins = int(self.bin_spin.value()) + distribution_count = len(self._distributions) + column_count = min(3, max(1, math.ceil(math.sqrt(distribution_count)))) + row_count = int(math.ceil(distribution_count / column_count)) + self.figure.set_size_inches( + max(10.8, 4.2 * column_count), + max(6.4, 3.3 * 
row_count), + forward=True, + ) + axes = self.figure.subplots(row_count, column_count, squeeze=False) + summary_lines: list[str] = [] + for index, distribution in enumerate(self._distributions): + row = index // column_count + column = index % column_count + summary_lines.append( + _plot_lifetime_distribution_axis( + axes[row][column], + distribution, + unit=unit, + requested_bins=requested_bins, + include_truncated=include_truncated, + ) + ) + + for index in range(distribution_count, row_count * column_count): + row = index // column_count + column = index % column_count + axes[row][column].set_axis_off() + + self.figure.suptitle( + "Cluster Lifetime Distributions by Stoichiometry", + y=0.995, + ) + self.figure.tight_layout(rect=(0.0, 0.0, 1.0, 0.965)) + self.summary_box.setPlainText("\n".join(summary_lines)) + self.canvas.draw_idle() + + +class ClusterDynamicsMLLifetimeDistributionWindow(QMainWindow): + """Separate window for lifetime-distribution diagnostics.""" + + def __init__( + self, + result: ClusterDynamicsMLResult | None = None, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self.setWindowTitle("Cluster Lifetime Distributions") + self.resize(1220, 850) + self.panel = ClusterDynamicsMLLifetimeDistributionPanel(self) + self.setCentralWidget(self.panel) + self.set_result(result) + + def set_result(self, result: ClusterDynamicsMLResult | None) -> None: + self.panel.set_result(result) + + +def build_cluster_lifetime_distributions( + result: ClusterDynamicsMLResult, +) -> tuple[ClusterLifetimeDistribution, ...]: + dynamics_result = result.dynamics_result + frame_count_matrix = np.asarray( + dynamics_result.frame_count_matrix, + dtype=float, + ) + frame_times_fs = dynamics_result.frame_times_fs + distributions: list[ClusterLifetimeDistribution] = [] + if frame_count_matrix.ndim != 2: + return () + + for row_index, label in enumerate(dynamics_result.cluster_labels): + if row_index >= frame_count_matrix.shape[0]: + continue + metrics = 
_summarize_series_lifetimes( + frame_count_matrix[row_index, :], + frame_times_fs=frame_times_fs, + observation_start_fs=dynamics_result.preview.analysis_start_fs, + observation_stop_fs=dynamics_result.preview.analysis_stop_fs, + ) + distributions.append( + ClusterLifetimeDistribution( + label=str(label), + cluster_size=int( + dynamics_result.cluster_sizes.get(str(label), 0) + ), + completed_lifetimes_fs=np.asarray( + metrics.completed_lifetimes_fs, + dtype=float, + ), + window_truncated_lifetimes_fs=np.asarray( + metrics.window_truncated_lifetimes_fs, + dtype=float, + ), + mean_lifetime_fs=metrics.mean_lifetime_fs, + std_lifetime_fs=metrics.std_lifetime_fs, + completed_lifetime_count=len(metrics.completed_lifetimes_fs), + window_truncated_lifetime_count=len( + metrics.window_truncated_lifetimes_fs + ), + ) + ) + return tuple(distributions) + + class ClusterDynamicsMLPlotPanel(QWidget): """Plot observed-only, Predicted Structures, and component SAXS traces.""" @@ -1367,6 +1610,325 @@ def _predicted_structure_label( return stoichiometry_label(primary_counts) +def _plot_lifetime_distribution_axis( + axis, + distribution: ClusterLifetimeDistribution, + *, + unit: str, + requested_bins: int, + include_truncated: bool, +) -> str: + completed = _scaled_lifetimes( + distribution.completed_lifetimes_fs, + unit=unit, + ) + truncated = _scaled_lifetimes( + distribution.window_truncated_lifetimes_fs, + unit=unit, + ) + fitted_values = ( + np.concatenate([completed, truncated]) + if include_truncated and truncated.size + else completed + ) + finite_values = fitted_values[np.isfinite(fitted_values)] + unit_label = "ps" if unit == "ps" else "fs" + axis.set_title( + f"{format_stoich_for_axis(distribution.label)} " + f"(n={completed.size})" + ) + axis.set_xlabel(f"Lifetime ({unit_label})") + axis.set_ylabel("Probability density") + + if finite_values.size == 0: + axis.text( + 0.5, + 0.5, + "No completed lifetimes", + ha="center", + va="center", + transform=axis.transAxes, + ) + 
return _lifetime_distribution_summary_line( + distribution, + values=finite_values, + unit=unit, + fit=None, + ) + + bin_edges = _histogram_bin_edges(finite_values, requested_bins) + if completed.size: + axis.hist( + completed, + bins=bin_edges, + density=True, + alpha=0.72, + color=_COMPLETED_LIFETIME_COLOR, + edgecolor="white", + linewidth=0.8, + label="Completed", + ) + if include_truncated and truncated.size: + axis.hist( + truncated, + bins=bin_edges, + density=True, + alpha=0.38, + color=_TRUNCATED_LIFETIME_COLOR, + edgecolor="white", + linewidth=0.8, + hatch="//", + label="Window-truncated", + ) + + mean_value = float(np.mean(finite_values)) + median_value = float(np.median(finite_values)) + q1, q3 = np.percentile(finite_values, [25.0, 75.0]) + axis.axvspan( + float(q1), + float(q3), + color="#4d4d4d", + alpha=0.10, + label="IQR", + ) + axis.axvline( + mean_value, + color="#222222", + linewidth=1.1, + linestyle="--", + label="Mean", + ) + axis.axvline( + median_value, + color="#005f73", + linewidth=1.1, + linestyle=":", + label="Median", + ) + + fit = _fit_lorentzian_to_lifetimes(finite_values) + if fit is not None: + x_min = float(np.min(finite_values)) + x_max = float(np.max(finite_values)) + if x_max <= x_min: + x_min = max(0.0, fit.center - fit.gamma * 3.0) + x_max = fit.center + fit.gamma * 3.0 + else: + padding = (x_max - x_min) * 0.10 + x_min = max(0.0, x_min - padding) + x_max += padding + x_values = np.linspace(x_min, x_max, 256) + axis.plot( + x_values, + _lorentzian_pdf(x_values, fit.center, fit.gamma), + color=_LORENTZIAN_COLOR, + linewidth=1.45, + label="Lorentzian fit", + ) + + info_lines = [ + f"completed {completed.size}", + f"truncated {truncated.size}", + f"mean {_format_float(mean_value)} +/- " + f"{_format_float(float(np.std(finite_values, ddof=0)))} {unit_label}", + f"median {_format_float(median_value)} {unit_label}", + f"IQR {_format_float(float(q3 - q1))} {unit_label}", + ] + if fit is None: + info_lines.append("Lorentzian: too 
few/degenerate") + else: + r_squared = ( + "n/a" + if fit.r_squared is None + else _format_float(float(fit.r_squared)) + ) + info_lines.append( + f"Lorentzian R2 {r_squared}; gamma " + f"{_format_float(fit.gamma)} {unit_label}" + ) + axis.text( + 0.02, + 0.98, + "\n".join(info_lines), + ha="left", + va="top", + fontsize=8, + transform=axis.transAxes, + bbox={ + "boxstyle": "round,pad=0.35", + "facecolor": "white", + "alpha": 0.78, + "edgecolor": "#cccccc", + }, + ) + axis.legend(fontsize=7, loc="best", framealpha=0.85) + return _lifetime_distribution_summary_line( + distribution, + values=finite_values, + unit=unit, + fit=fit, + ) + + +def _scaled_lifetimes(values: np.ndarray, *, unit: str) -> np.ndarray: + scale = 1000.0 if unit == "ps" else 1.0 + array = np.asarray(values, dtype=float) + return array[np.isfinite(array)] / scale + + +def _histogram_bin_edges(values: np.ndarray, requested_bins: int): + finite_values = np.asarray(values, dtype=float) + finite_values = finite_values[np.isfinite(finite_values)] + if finite_values.size <= 1: + return 1 + lower = float(np.min(finite_values)) + upper = float(np.max(finite_values)) + if upper <= lower: + return 1 + bin_count = min( + max(int(requested_bins), 1), max(int(finite_values.size), 1) + ) + padding = (upper - lower) * 0.04 + return np.linspace(lower - padding, upper + padding, bin_count + 1) + + +def _fit_lorentzian_to_lifetimes( + values: np.ndarray, +) -> _LorentzianFit | None: + finite_values = np.asarray(values, dtype=float) + finite_values = finite_values[np.isfinite(finite_values)] + if finite_values.size < 3: + return None + if float(np.max(finite_values)) <= float(np.min(finite_values)): + return None + + q1, center, q3 = np.percentile(finite_values, [25.0, 50.0, 75.0]) + gamma = float((q3 - q1) / 2.0) + if gamma <= 0.0: + mad = float(np.median(np.abs(finite_values - center))) + gamma = mad if mad > 0.0 else float(np.std(finite_values, ddof=0)) + if gamma <= 0.0: + return None + + bin_edges = 
_histogram_bin_edges( + finite_values, + requested_bins=max(3, int(math.sqrt(finite_values.size))), + ) + if isinstance(bin_edges, int): + return None + density, edges = np.histogram( + finite_values, + bins=bin_edges, + density=True, + ) + centers = (edges[:-1] + edges[1:]) / 2.0 + fitted_density = _lorentzian_pdf(centers, float(center), gamma) + finite_mask = np.isfinite(density) & np.isfinite(fitted_density) + r_squared: float | None + if np.count_nonzero(finite_mask) >= 2: + y_true = density[finite_mask] + y_fit = fitted_density[finite_mask] + residual_sum = float(np.sum((y_true - y_fit) ** 2)) + total_sum = float(np.sum((y_true - float(np.mean(y_true))) ** 2)) + r_squared = ( + None if total_sum <= 0.0 else 1.0 - residual_sum / total_sum + ) + else: + r_squared = None + return _LorentzianFit( + center=float(center), + gamma=float(gamma), + r_squared=r_squared, + interpretation=_lorentzian_interpretation( + sample_count=int(finite_values.size), + r_squared=r_squared, + ), + ) + + +def _lorentzian_pdf( + x_values: np.ndarray, + center: float, + gamma: float, +) -> np.ndarray: + safe_gamma = max(float(gamma), 1.0e-12) + normalized = ( + np.asarray(x_values, dtype=float) - float(center) + ) / safe_gamma + return 1.0 / (np.pi * safe_gamma * (1.0 + normalized**2)) + + +def _lorentzian_interpretation( + *, + sample_count: int, + r_squared: float | None, +) -> str: + if sample_count < 5: + return "limited support" + if r_squared is None: + return "shape unresolved" + if r_squared >= 0.80: + return "Lorentzian-like heavy tail" + if r_squared >= 0.50: + return "partly Lorentzian-like" + return "not clearly Lorentzian" + + +def _lifetime_distribution_summary_line( + distribution: ClusterLifetimeDistribution, + *, + values: np.ndarray, + unit: str, + fit: _LorentzianFit | None, +) -> str: + unit_label = "ps" if unit == "ps" else "fs" + if values.size == 0: + spread = "no completed lifetimes" + else: + mean_value = float(np.mean(values)) + std_value = float(np.std(values, 
ddof=0)) + median_value = float(np.median(values)) + q1, q3 = np.percentile(values, [25.0, 75.0]) + cv_text = ( + "n/a" + if mean_value <= 0.0 + else _format_float(std_value / mean_value) + ) + spread = ( + f"mean {_format_float(mean_value)} {unit_label}, " + f"std {_format_float(std_value)} {unit_label}, " + f"median {_format_float(median_value)} {unit_label}, " + f"IQR {_format_float(float(q3 - q1))} {unit_label}, " + f"CV {cv_text}" + ) + if fit is None: + lorentzian = "Lorentzian: too few or degenerate values" + else: + r_squared = ( + "n/a" + if fit.r_squared is None + else _format_float(float(fit.r_squared)) + ) + lorentzian = ( + "Lorentzian: " + f"center {_format_float(fit.center)} {unit_label}, " + f"gamma {_format_float(fit.gamma)} {unit_label}, " + f"R2 {r_squared}, {fit.interpretation}" + ) + return ( + f"{distribution.label}: completed " + f"{distribution.completed_lifetime_count}, window-truncated " + f"{distribution.window_truncated_lifetime_count}; {spread}; " + f"{lorentzian}" + ) + + +def _format_float(value: float) -> str: + if not np.isfinite(value): + return "n/a" + return f"{float(value):.4g}" + + def _build_saxs_model( result: ClusterDynamicsMLResult, *, diff --git a/src/saxshell/clusterdynamicsml/ui/run_file_window.py b/src/saxshell/clusterdynamicsml/ui/run_file_window.py new file mode 100644 index 0000000..491a3bf --- /dev/null +++ b/src/saxshell/clusterdynamicsml/ui/run_file_window.py @@ -0,0 +1,610 @@ +from __future__ import annotations + +import json +import sys +from pathlib import Path + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import ( + QApplication, + QFileDialog, + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QMainWindow, + QMessageBox, + QPlainTextEdit, + QPushButton, + QScrollArea, + QSplitter, + QVBoxLayout, + QWidget, +) + +from saxshell.cluster import DEFAULT_CLUSTER_EXTRACTION_PRESET_NAME +from saxshell.cluster.ui.definitions_panel import ClusterDefinitionsPanel +from 
saxshell.cluster.workflow import ClusterWorkflow, format_box_dimensions +from saxshell.clusterdynamics.ui.main_window import ClusterDynamicsTimePanel +from saxshell.clusterdynamicsml.run_config import ( + build_clusterdynamicsml_run_config, + default_clusterdynamicsml_run_file_path, + preview_clusterdynamicsml_run_config, + save_clusterdynamicsml_run_config, + suggest_clusterdynamicsml_output_file, +) +from saxshell.clusterdynamicsml.ui.main_window import ( + ClusterDynamicsMLSettingsPanel, +) +from saxshell.saxs.project_manager import SAXSProjectManager +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) +from saxshell.xyz2pdb import list_reference_library + + +class ClusterDynamicsMLRunFileWindow(QMainWindow): + def __init__( + self, + *, + initial_project_dir: str | Path | None = None, + initial_frames_dir: str | Path | None = None, + initial_energy_file: str | Path | None = None, + initial_clusters_dir: str | Path | None = None, + initial_experimental_data_file: str | Path | None = None, + ) -> None: + super().__init__() + self._browse_start_dir = Path.home() + self._last_summary: dict[str, object] | None = None + self._last_suggested_output_file: str | None = None + + project_dir = _optional_resolved_path(initial_project_dir) + frames_dir = _optional_resolved_path(initial_frames_dir) + energy_file = _optional_resolved_path(initial_energy_file) + clusters_dir = _optional_resolved_path(initial_clusters_dir) + experimental_data_file = _optional_resolved_path( + initial_experimental_data_file + ) + if project_dir is not None: + self._browse_start_dir = project_dir + defaults = self._project_defaults(project_dir) + if frames_dir is None: + frames_dir = defaults.get("frames_dir") + if energy_file is None: + energy_file = defaults.get("energy_file") + if clusters_dir is None: + clusters_dir = defaults.get("clusters_dir") + if experimental_data_file is None: + 
experimental_data_file = defaults.get("experimental_data_file") + + self.setWindowTitle("Cluster Dynamics ML CLI Setup (Beta)") + self.setWindowIcon(load_saxshell_icon()) + self.resize(1120, 840) + self._build_ui() + self.definitions_panel.load_preset( + DEFAULT_CLUSTER_EXTRACTION_PRESET_NAME + ) + self.definitions_panel.set_shell_reference_editor_enabled(True) + self._load_shell_reference_library_entries() + + if project_dir is not None: + self.project_dir_edit.setText(str(project_dir)) + self._refresh_run_file_path() + if frames_dir is not None and frames_dir.is_dir(): + self.frames_dir_edit.setText(str(frames_dir)) + self._browse_start_dir = frames_dir + if energy_file is not None and energy_file.is_file(): + self.energy_file_edit.setText(str(energy_file)) + if clusters_dir is not None and clusters_dir.is_dir(): + self.prediction_panel.set_clusters_dir( + clusters_dir, + emit_signal=False, + ) + if ( + experimental_data_file is not None + and experimental_data_file.is_file() + ): + self.prediction_panel.set_experimental_data_file( + experimental_data_file, + emit_signal=False, + ) + self._inspect_frames() + self._update_preview() + + def _build_ui(self) -> None: + central = QWidget(self) + root = QVBoxLayout(central) + root.setContentsMargins(10, 10, 10, 10) + root.setSpacing(8) + self.setCentralWidget(central) + + splitter = QSplitter(Qt.Orientation.Horizontal, self) + splitter.setChildrenCollapsible(False) + root.addWidget(splitter, stretch=1) + + left_scroll = QScrollArea(self) + left_scroll.setWidgetResizable(True) + left_panel = QWidget() + self.left_layout = QVBoxLayout(left_panel) + self.left_layout.setContentsMargins(10, 10, 10, 10) + self.left_layout.setSpacing(10) + left_scroll.setWidget(left_panel) + + right_scroll = QScrollArea(self) + right_scroll.setWidgetResizable(True) + right_panel = QWidget() + self.right_layout = QVBoxLayout(right_panel) + self.right_layout.setContentsMargins(10, 10, 10, 10) + self.right_layout.setSpacing(10) + 
right_scroll.setWidget(right_panel) + + splitter.addWidget(left_scroll) + splitter.addWidget(right_scroll) + splitter.setSizes([600, 520]) + + self.left_layout.addWidget(self._build_project_group()) + self.left_layout.addWidget(self._build_input_group()) + self.prediction_panel = ClusterDynamicsMLSettingsPanel() + self.prediction_panel.settings_changed.connect(self._update_preview) + self.left_layout.addWidget(self.prediction_panel) + self.definitions_panel = ClusterDefinitionsPanel() + self.definitions_panel.settings_changed.connect(self._update_preview) + self.left_layout.addWidget(self.definitions_panel) + self.time_panel = ClusterDynamicsTimePanel() + self.time_panel.settings_changed.connect(self._update_preview) + self.left_layout.addWidget(self.time_panel) + self.left_layout.addWidget(self._build_save_group()) + self.left_layout.addStretch(1) + + self.right_layout.addWidget(self._build_inspection_group()) + self.right_layout.addWidget(self._build_command_group()) + self.right_layout.addStretch(1) + self.statusBar().showMessage("Ready") + + def _build_project_group(self) -> QGroupBox: + group = QGroupBox("Project") + form = QFormLayout(group) + project_row = QHBoxLayout() + self.project_dir_edit = QLineEdit() + self.project_dir_edit.editingFinished.connect( + self._on_project_dir_changed + ) + project_row.addWidget(self.project_dir_edit, stretch=1) + browse_button = QPushButton("Browse...") + browse_button.clicked.connect(self._browse_project_dir) + project_row.addWidget(browse_button) + project_widget = QWidget() + project_widget.setLayout(project_row) + form.addRow("Project folder", project_widget) + + self.run_file_edit = QLineEdit() + self.run_file_edit.setReadOnly(True) + form.addRow("Run file", self.run_file_edit) + return group + + def _build_input_group(self) -> QGroupBox: + group = QGroupBox("Input / Output") + form = QFormLayout(group) + self.frames_dir_edit = QLineEdit() + self.frames_dir_edit.editingFinished.connect(self._inspect_frames) + 
form.addRow( + "Frames folder", + self._make_path_row( + self.frames_dir_edit, + self._browse_frames_dir, + ), + ) + + self.energy_file_edit = QLineEdit() + self.energy_file_edit.editingFinished.connect(self._update_preview) + form.addRow( + "CP2K .ener file", + self._make_path_row( + self.energy_file_edit, + self._browse_energy_file, + ), + ) + + self.output_file_edit = QLineEdit() + self.output_file_edit.editingFinished.connect(self._update_preview) + form.addRow( + "Output dataset", + self._make_path_row( + self.output_file_edit, + self._browse_output_file, + ), + ) + return group + + def _build_save_group(self) -> QGroupBox: + group = QGroupBox("Save") + layout = QHBoxLayout(group) + inspect_button = QPushButton("Inspect Frames") + inspect_button.clicked.connect(self._inspect_frames) + layout.addWidget(inspect_button) + save_button = QPushButton("Save Run File") + save_button.clicked.connect(self._save_run_file) + layout.addWidget(save_button) + layout.addStretch(1) + return group + + def _build_inspection_group(self) -> QGroupBox: + group = QGroupBox("Inspection") + layout = QVBoxLayout(group) + self.inspection_box = QPlainTextEdit() + self.inspection_box.setReadOnly(True) + self.inspection_box.setMinimumHeight(210) + layout.addWidget(self.inspection_box) + return group + + def _build_command_group(self) -> QGroupBox: + group = QGroupBox("CLI Command / JSON") + layout = QVBoxLayout(group) + layout.addWidget(QLabel("Commands")) + self.command_box = QPlainTextEdit() + self.command_box.setReadOnly(True) + self.command_box.setMinimumHeight(150) + layout.addWidget(self.command_box) + layout.addWidget(QLabel("Run file preview")) + self.json_preview_box = QPlainTextEdit() + self.json_preview_box.setReadOnly(True) + self.json_preview_box.setMinimumHeight(300) + layout.addWidget(self.json_preview_box) + return group + + def _make_path_row(self, line_edit: QLineEdit, callback) -> QWidget: + widget = QWidget() + row = QHBoxLayout(widget) + row.setContentsMargins(0, 0, 0, 
0) + row.addWidget(line_edit, stretch=1) + button = QPushButton("Browse...") + button.clicked.connect(callback) + row.addWidget(button) + return widget + + def _browse_project_dir(self, *_args: object) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select project folder", + str(self._browse_start_dir), + ) + if not selected: + return + self.project_dir_edit.setText(selected) + self._on_project_dir_changed() + + def _browse_frames_dir(self, *_args: object) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select extracted frames folder", + str(self._browse_start_dir), + ) + if not selected: + return + self.frames_dir_edit.setText(selected) + self._browse_start_dir = Path(selected).expanduser().resolve() + self._inspect_frames() + + def _browse_energy_file(self, *_args: object) -> None: + path, _selected_filter = QFileDialog.getOpenFileName( + self, + "Select CP2K .ener file", + self.energy_file_edit.text().strip() + or str(self._browse_start_dir), + "Energy Files (*.ener);;All Files (*)", + ) + if path: + self.energy_file_edit.setText(path) + self._update_preview() + + def _browse_output_file(self, *_args: object) -> None: + path, _selected_filter = QFileDialog.getSaveFileName( + self, + "Select output dataset", + self.output_file_edit.text().strip() + or str(self._browse_start_dir), + "JSON Files (*.json);;All Files (*)", + ) + if path: + self.output_file_edit.setText(path) + self._update_preview() + + def _on_project_dir_changed(self, *_args: object) -> None: + project_dir = self._project_dir() + if project_dir is None: + return + self._browse_start_dir = project_dir + self._refresh_run_file_path() + defaults = self._project_defaults(project_dir) + if not self.frames_dir_edit.text().strip(): + frames_dir = defaults.get("frames_dir") + if frames_dir is not None and frames_dir.is_dir(): + self.frames_dir_edit.setText(str(frames_dir)) + if not self.energy_file_edit.text().strip(): + energy_file = defaults.get("energy_file") + if 
energy_file is not None and energy_file.is_file(): + self.energy_file_edit.setText(str(energy_file)) + if self.prediction_panel.clusters_dir() is None: + self.prediction_panel.set_clusters_dir( + defaults.get("clusters_dir"), + emit_signal=False, + ) + if self.prediction_panel.experimental_data_file() is None: + self.prediction_panel.set_experimental_data_file( + defaults.get("experimental_data_file"), + emit_signal=False, + ) + self._inspect_frames() + + def _inspect_frames(self, *_args: object) -> None: + frames_text = self.frames_dir_edit.text().strip() + if not frames_text: + self._last_summary = None + self.inspection_box.setPlainText("No frames folder selected.") + self._update_preview() + return + try: + workflow = ClusterWorkflow( + frames_dir=frames_text, + atom_type_definitions={}, + pair_cutoff_definitions={}, + ) + summary = workflow.inspect() + except Exception as exc: + self._last_summary = None + self.inspection_box.setPlainText(str(exc)) + self.statusBar().showMessage("Frames inspection failed") + self._update_preview() + return + self._last_summary = summary + frame_format = str(summary.get("frame_format", "") or "") + self.definitions_panel.set_frame_mode(frame_format) + self.inspection_box.setPlainText(_summary_text(summary)) + self._refresh_suggested_output_file() + self.statusBar().showMessage( + f"Discovered {int(summary.get('n_frames', 0))} frame(s)" + ) + self._update_preview() + + def _refresh_run_file_path(self) -> None: + project_dir = self._project_dir() + if project_dir is None: + self.run_file_edit.clear() + return + self.run_file_edit.setText( + str(default_clusterdynamicsml_run_file_path(project_dir)) + ) + + def _refresh_suggested_output_file(self) -> None: + project_dir = self._project_dir() + frames_text = self.frames_dir_edit.text().strip() + if project_dir is None or not frames_text: + return + try: + suggested = suggest_clusterdynamicsml_output_file( + project_dir=project_dir, + frames_dir=frames_text, + ) + except Exception: + 
return + current = self.output_file_edit.text().strip() + if not current or current == self._last_suggested_output_file: + self.output_file_edit.setText(str(suggested)) + self._last_suggested_output_file = str(suggested) + + def _save_run_file(self, *_args: object) -> None: + try: + project_dir = self._require_project_dir() + config = self._current_config(project_dir) + except Exception as exc: + QMessageBox.warning( + self, "Cluster Dynamics ML CLI Setup", str(exc) + ) + return + run_file_path = default_clusterdynamicsml_run_file_path(project_dir) + save_clusterdynamicsml_run_config(run_file_path, config) + self.run_file_edit.setText(str(run_file_path)) + self.json_preview_box.setPlainText(save_preview_text(config.to_dict())) + self._update_preview() + self.statusBar().showMessage(f"Saved run file: {run_file_path}") + QMessageBox.information( + self, + "Cluster Dynamics ML CLI Setup", + f"Saved cluster dynamics ML CLI run file:\n{run_file_path}", + ) + + def _update_preview(self, *_args: object) -> None: + project_dir = self._project_dir() + if project_dir is None: + self.command_box.setPlainText( + "Select a project folder before saving the CLI run file." 
+ ) + self.json_preview_box.clear() + return + self._refresh_run_file_path() + self.command_box.setPlainText( + f'clusterdynamicsml run "{project_dir}"\n' + f'saxshell clusterdynamicsml run "{project_dir}"' + ) + try: + config = self._current_config(project_dir) + payload = config.to_dict() + try: + payload["selection_preview"] = ( + preview_clusterdynamicsml_run_config( + project_dir=project_dir, + config=config, + ) + ) + except Exception as exc: + payload["selection_preview_error"] = str(exc) + except Exception as exc: + self.json_preview_box.setPlainText(str(exc)) + return + self.json_preview_box.setPlainText(save_preview_text(payload)) + + def _current_config(self, project_dir: Path): + frames_text = self.frames_dir_edit.text().strip() + if not frames_text: + raise ValueError("Choose a frames folder before saving.") + output_text = self.output_file_edit.text().strip() + energy_text = self.energy_file_edit.text().strip() + frame_format = "" + if self._last_summary is not None: + frame_format = str( + self._last_summary.get("frame_format", "") or "" + ) + shell_references = ( + self.definitions_panel.shell_reference_definitions() + if frame_format == "pdb" + else () + ) + return build_clusterdynamicsml_run_config( + project_dir=project_dir, + frames_dir=frames_text, + output_file=output_text or None, + clusters_dir=self.prediction_panel.clusters_dir(), + experimental_data_file=( + self.prediction_panel.experimental_data_file() + ), + energy_file=energy_text or None, + atom_type_definitions=self.definitions_panel.atom_type_definitions(), + pair_cutoff_definitions=( + self.definitions_panel.pair_cutoff_definitions() + ), + box_dimensions=self.definitions_panel.box_dimensions(), + use_pbc=self.definitions_panel.use_pbc(), + default_cutoff=self.definitions_panel.default_cutoff(), + shell_levels=self.definitions_panel.shell_growth_levels(), + shared_shells=self.definitions_panel.shared_shells(), + include_shell_atoms_in_stoichiometry=( + 
self.definitions_panel.include_shell_atoms_in_stoichiometry() + ), + search_mode=self.definitions_panel.search_mode(), + shell_reference_definitions=shell_references, + folder_start_time_fs=self.time_panel.folder_start_time_fs(), + first_frame_time_fs=self.time_panel.first_frame_time_fs(), + frame_timestep_fs=self.time_panel.frame_timestep_fs(), + frames_per_colormap_timestep=( + self.time_panel.frames_per_colormap_timestep() + ), + analysis_start_fs=self.time_panel.analysis_start_fs(), + analysis_stop_fs=self.time_panel.analysis_stop_fs(), + target_node_counts=self.prediction_panel.target_node_counts(), + candidates_per_size=self.prediction_panel.candidates_per_size(), + prediction_population_share_threshold=( + self.prediction_panel.prediction_population_share_threshold() + ), + q_min=self.prediction_panel.q_min(), + q_max=self.prediction_panel.q_max(), + q_points=self.prediction_panel.q_points(), + ) + + def _project_dir(self) -> Path | None: + text = self.project_dir_edit.text().strip() + if not text: + return None + return Path(text).expanduser().resolve() + + def _require_project_dir(self) -> Path: + project_dir = self._project_dir() + if project_dir is None: + raise ValueError("Choose a project folder before saving.") + if not project_dir.is_dir(): + raise ValueError(f"Project folder does not exist: {project_dir}") + return project_dir + + def _load_shell_reference_library_entries(self) -> None: + try: + entries = list(list_reference_library()) + except Exception: + entries = [] + self.definitions_panel.set_shell_reference_library_entries( + entries, + emit_signal=False, + ) + + @staticmethod + def _project_defaults(project_dir: Path) -> dict[str, Path | None]: + defaults: dict[str, Path | None] = { + "frames_dir": None, + "energy_file": None, + "clusters_dir": None, + "experimental_data_file": None, + } + try: + settings = SAXSProjectManager().load_project(project_dir) + except Exception: + return defaults + defaults["frames_dir"] = 
settings.resolved_frames_dir + defaults["energy_file"] = settings.resolved_energy_file + defaults["clusters_dir"] = settings.resolved_clusters_dir + defaults["experimental_data_file"] = ( + settings.resolved_experimental_data_path + ) + return defaults + + +def _summary_text(summary: dict[str, object]) -> str: + box_dimensions = summary.get("box_dimensions") + if box_dimensions is None: + box_dimensions = summary.get("estimated_box_dimensions") + source_kind = summary.get("box_dimensions_source_kind") + label = ( + "Source box dimensions" + if source_kind == "source_filename" + else "Estimated box dimensions" + ) + lines = [ + f"Frames folder: {summary.get('input_dir')}", + f"Mode: {summary.get('mode_label')}", + f"Frames: {summary.get('n_frames')}", + f"{label}: {format_box_dimensions(box_dimensions)}", + ] + if summary.get("box_dimensions_source") is not None: + lines.append(f"Box source: {summary.get('box_dimensions_source')}") + return "\n".join(lines) + + +def save_preview_text(payload: dict[str, object]) -> str: + return json.dumps(payload, indent=2) + + +def _optional_resolved_path(value: str | Path | None) -> Path | None: + if value is None: + return None + return Path(value).expanduser().resolve() + + +def launch_clusterdynamicsml_run_file_ui( + *, + initial_project_dir: str | Path | None = None, + initial_frames_dir: str | Path | None = None, + initial_energy_file: str | Path | None = None, + initial_clusters_dir: str | Path | None = None, + initial_experimental_data_file: str | Path | None = None, +) -> ClusterDynamicsMLRunFileWindow: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication(sys.argv) + configure_saxshell_application(app) + window = ClusterDynamicsMLRunFileWindow( + initial_project_dir=initial_project_dir, + initial_frames_dir=initial_frames_dir, + initial_energy_file=initial_energy_file, + initial_clusters_dir=initial_clusters_dir, + 
initial_experimental_data_file=initial_experimental_data_file, + ) + window.show() + window.raise_() + return window + + +__all__ = [ + "ClusterDynamicsMLRunFileWindow", + "launch_clusterdynamicsml_run_file_ui", +] diff --git a/src/saxshell/fullrmc/__init__.py b/src/saxshell/fullrmc/__init__.py index 42d8072..df4f817 100644 --- a/src/saxshell/fullrmc/__init__.py +++ b/src/saxshell/fullrmc/__init__.py @@ -19,6 +19,9 @@ PackmolPlanningEntry, PackmolPlanningMetadata, PackmolPlanningSettings, + PackmolSupplementalAllocation, + PackmolSupplementalAllocationEntry, + PackmolSupplementalComponentSettings, build_packmol_plan, load_packmol_planning_metadata, save_packmol_planning_metadata, @@ -27,6 +30,7 @@ PackmolSetupEntry, PackmolSetupMetadata, PackmolSetupSettings, + PackmolSetupSupplementalEntry, build_packmol_setup, load_packmol_setup_metadata, save_packmol_setup_metadata, @@ -146,6 +150,10 @@ "PackmolSetupEntry", "PackmolSetupMetadata", "PackmolSetupSettings", + "PackmolSetupSupplementalEntry", + "PackmolSupplementalAllocation", + "PackmolSupplementalAllocationEntry", + "PackmolSupplementalComponentSettings", "RMCSetupMainWindow", "RepresentativePreviewCluster", "RepresentativePreviewSeries", diff --git a/src/saxshell/fullrmc/packmol_planning.py b/src/saxshell/fullrmc/packmol_planning.py index db9f91e..3500f73 100644 --- a/src/saxshell/fullrmc/packmol_planning.py +++ b/src/saxshell/fullrmc/packmol_planning.py @@ -2,6 +2,8 @@ import csv import json +import re +from collections import Counter from dataclasses import asdict, dataclass from datetime import datetime from pathlib import Path @@ -31,14 +33,49 @@ from .project_loader import RMCDreamProjectSource +@dataclass(slots=True) +class PackmolSupplementalComponentSettings: + role: str = "solute" + reference: str | None = None + element: str | None = None + residue_name: str = "" + name: str = "" + + def to_dict(self) -> dict[str, object]: + return asdict(self) + + @classmethod + def from_dict( + cls, + payload: 
dict[str, object] | None, + ) -> "PackmolSupplementalComponentSettings": + source = dict(payload or {}) + return cls( + role=_normalized_component_role(source.get("role")), + reference=_optional_text(source.get("reference")), + element=_optional_text(source.get("element")), + residue_name=_normalized_optional_residue_name( + source.get("residue_name") + ), + name=str(source.get("name", "") or "").strip(), + ) + + @dataclass(slots=True) class PackmolPlanningSettings: planning_mode: str = "per_element" box_side_length_a: float = 100.0 free_solvent_reference: str | None = None + supplemental_components: tuple[ + PackmolSupplementalComponentSettings, ... + ] = () def to_dict(self) -> dict[str, object]: - return asdict(self) + payload = asdict(self) + payload["supplemental_components"] = [ + component.to_dict() for component in self.supplemental_components + ] + return payload @classmethod def from_dict( @@ -59,6 +96,11 @@ def from_dict( free_solvent_reference=_optional_text( source.get("free_solvent_reference") ), + supplemental_components=tuple( + PackmolSupplementalComponentSettings.from_dict(entry) + for entry in source.get("supplemental_components", []) + if isinstance(entry, dict) + ), ) @@ -141,6 +183,142 @@ def from_dict( ) +@dataclass(slots=True) +class PackmolSupplementalAllocationEntry: + role: str + name: str + source_type: str + reference_name: str | None + reference_path: str | None + residue_name: str + planned_count: int + atom_count: int + element_counts: dict[str, int] + + def to_dict(self) -> dict[str, object]: + return asdict(self) + + @classmethod + def from_dict( + cls, + payload: dict[str, object], + ) -> "PackmolSupplementalAllocationEntry": + return cls( + role=_normalized_component_role(payload.get("role")), + name=str(payload.get("name", "") or "").strip(), + source_type=str(payload.get("source_type", "") or "").strip(), + reference_name=_optional_text(payload.get("reference_name")), + 
reference_path=_optional_text(payload.get("reference_path")), + residue_name=_normalized_residue_name( + str(payload.get("residue_name", "") or "") + ), + planned_count=int(payload.get("planned_count", 0)), + atom_count=int(payload.get("atom_count", 0)), + element_counts={ + _normalized_element_symbol(key): int(value) + for key, value in dict( + payload.get("element_counts", {}) + ).items() + if int(value) > 0 + }, + ) + + +@dataclass(slots=True) +class PackmolSupplementalAllocation: + target_solute_formula_units: int + formula_unit_basis: dict[str, float] + cluster_solute_element_totals: dict[str, int] + target_solute_element_totals: dict[str, int] + missing_solute_element_totals: dict[str, int] + added_solute_element_totals: dict[str, int] + unfilled_solute_element_totals: dict[str, int] + entries: list[PackmolSupplementalAllocationEntry] + warnings: tuple[str, ...] = () + + def to_dict(self) -> dict[str, object]: + return { + "target_solute_formula_units": self.target_solute_formula_units, + "formula_unit_basis": dict(self.formula_unit_basis), + "cluster_solute_element_totals": dict( + self.cluster_solute_element_totals + ), + "target_solute_element_totals": dict( + self.target_solute_element_totals + ), + "missing_solute_element_totals": dict( + self.missing_solute_element_totals + ), + "added_solute_element_totals": dict( + self.added_solute_element_totals + ), + "unfilled_solute_element_totals": dict( + self.unfilled_solute_element_totals + ), + "entries": [entry.to_dict() for entry in self.entries], + "warnings": list(self.warnings), + } + + @classmethod + def from_dict( + cls, + payload: dict[str, object] | None, + ) -> "PackmolSupplementalAllocation | None": + if not payload: + return None + return cls( + target_solute_formula_units=int( + payload.get("target_solute_formula_units", 0) + ), + formula_unit_basis={ + str(key): float(value) + for key, value in dict( + payload.get("formula_unit_basis", {}) + ).items() + }, + cluster_solute_element_totals={ + 
str(key): int(value) + for key, value in dict( + payload.get("cluster_solute_element_totals", {}) + ).items() + }, + target_solute_element_totals={ + str(key): int(value) + for key, value in dict( + payload.get("target_solute_element_totals", {}) + ).items() + }, + missing_solute_element_totals={ + str(key): int(value) + for key, value in dict( + payload.get("missing_solute_element_totals", {}) + ).items() + }, + added_solute_element_totals={ + str(key): int(value) + for key, value in dict( + payload.get("added_solute_element_totals", {}) + ).items() + }, + unfilled_solute_element_totals={ + str(key): int(value) + for key, value in dict( + payload.get("unfilled_solute_element_totals", {}) + ).items() + }, + entries=[ + PackmolSupplementalAllocationEntry.from_dict(dict(entry)) + for entry in payload.get("entries", []) + if isinstance(entry, dict) + ], + warnings=tuple( + str(value) + for value in payload.get("warnings", []) + if str(value).strip() + ), + ) + + @dataclass(slots=True) class PackmolPlanningEntry: structure: str @@ -196,6 +374,7 @@ class PackmolPlanningMetadata: achieved_total_number_density_a3: float achieved_element_number_density_a3: dict[str, float] solvent_allocation: PackmolSolventAllocation | None + supplemental_allocation: PackmolSupplementalAllocation | None entries: list[PackmolPlanningEntry] report_text: str @@ -225,6 +404,11 @@ def to_dict(self) -> dict[str, object]: if self.solvent_allocation is None else self.solvent_allocation.to_dict() ), + "supplemental_allocation": ( + None + if self.supplemental_allocation is None + else self.supplemental_allocation.to_dict() + ), "entries": [entry.to_dict() for entry in self.entries], "report_text": self.report_text, } @@ -272,6 +456,11 @@ def from_dict( if isinstance(payload.get("solvent_allocation"), dict) else None ), + supplemental_allocation=PackmolSupplementalAllocation.from_dict( + payload.get("supplemental_allocation") + if isinstance(payload.get("supplemental_allocation"), dict) + else None 
+ ), entries=[ PackmolPlanningEntry.from_dict(dict(entry)) for entry in payload.get("entries", []) @@ -309,6 +498,22 @@ def summary_text(self) -> str: ), ] ) + if self.supplemental_allocation is not None: + supplemental_count = sum( + entry.planned_count + for entry in self.supplemental_allocation.entries + ) + if supplemental_count > 0: + lines.append( + f"Supplemental solute components: {supplemental_count}" + ) + if self.supplemental_allocation.unfilled_solute_element_totals: + lines.append( + "Unfilled supplemental solute: " + + _format_element_counts( + self.supplemental_allocation.unfilled_solute_element_totals + ) + ) lines.extend( [ ( @@ -429,15 +634,6 @@ def build_packmol_plan( planned_count_weights = _normalized_weights(counts.astype(float)) planned_atom_weights = _normalized_weights(counts * atom_vector) - volume_a3 = settings.box_side_length_a**3 - achieved_element_nd = { - element: float( - np.dot(counts, element_matrix[element_index]) / volume_a3 - ) - for element_index, element in enumerate(ordered_elements) - } - achieved_total_nd = float(np.dot(counts, atom_vector) / volume_a3) - representative_entries = active_representatives.representative_entries entries: list[PackmolPlanningEntry] = [] for index, representative_entry in enumerate(representative_entries): @@ -463,6 +659,30 @@ def build_packmol_plan( ) ) + volume_a3 = settings.box_side_length_a**3 + achieved_element_counts = _planned_element_totals(entries) + supplemental_allocation = _build_supplemental_allocation( + settings=settings, + solution=solution, + box_targets=box_targets, + planning_entries=entries, + ) + if supplemental_allocation is not None: + for ( + element, + count, + ) in supplemental_allocation.added_solute_element_totals.items(): + achieved_element_counts[element] = achieved_element_counts.get( + element, 0 + ) + int(count) + achieved_element_nd = { + element: float(count / volume_a3) + for element, count in sorted(achieved_element_counts.items()) + } + achieved_total_nd = 
float( + sum(achieved_element_counts.values()) / volume_a3 + ) + solvent_allocation = _build_solvent_allocation( settings=settings, box_targets=box_targets, @@ -480,6 +700,7 @@ def build_packmol_plan( achieved_total_nd=achieved_total_nd, achieved_element_nd=achieved_element_nd, solvent_allocation=solvent_allocation, + supplemental_allocation=supplemental_allocation, ) metadata = PackmolPlanningMetadata( settings=settings, @@ -491,6 +712,7 @@ def build_packmol_plan( achieved_total_number_density_a3=achieved_total_nd, achieved_element_number_density_a3=achieved_element_nd, solvent_allocation=solvent_allocation, + supplemental_allocation=supplemental_allocation, entries=entries, report_text=report_text, ) @@ -688,6 +910,237 @@ def _vector_error(target: np.ndarray, achieved: np.ndarray) -> float: return float(np.linalg.norm(target - achieved, ord=2)) +def _planned_element_totals( + entries: list[PackmolPlanningEntry], +) -> dict[str, int]: + totals: dict[str, int] = {} + for entry in entries: + if entry.planned_count <= 0: + continue + for element, count in entry.element_counts.items(): + totals[element] = totals.get(element, 0) + ( + int(entry.planned_count) * int(count) + ) + return dict(sorted(totals.items())) + + +def _build_supplemental_allocation( + *, + settings: PackmolPlanningSettings, + solution: SolutionProperties, + box_targets: dict[str, object], + planning_entries: list[PackmolPlanningEntry], +) -> PackmolSupplementalAllocation | None: + solute_formula = { + _normalized_element_symbol(element): int(count) + for element, count in solution.solute_dict.items() + if int(count) > 0 + } + cluster_totals = _planned_element_totals(planning_entries) + if not solute_formula: + return None + + formula_units, formula_basis = _estimate_solute_formula_units( + solute_formula=solute_formula, + cluster_totals=cluster_totals, + box_targets=box_targets, + ) + target_totals = { + element: int(formula_units) * int(count) + for element, count in sorted(solute_formula.items()) 
+ } + missing_totals = { + element: max( + 0, int(target_count) - int(cluster_totals.get(element, 0)) + ) + for element, target_count in sorted(target_totals.items()) + if max(0, int(target_count) - int(cluster_totals.get(element, 0))) > 0 + } + if not missing_totals and not settings.supplemental_components: + return None + + warnings: list[str] = [] + resolved_components = [ + _resolve_supplemental_component(component) + for component in settings.supplemental_components + ] + remaining = Counter(missing_totals) + added_totals: Counter[str] = Counter() + entries: list[PackmolSupplementalAllocationEntry] = [] + for component in resolved_components: + planned_count = 0 + if component.role == "solute": + planned_count = _supplemental_component_count( + component, + remaining, + ) + if planned_count > 0: + for element, count in component.element_counts.items(): + total = int(planned_count) * int(count) + remaining[element] -= total + added_totals[element] += total + else: + warnings.append( + f"{component.name} is marked as solvent, so it was not used " + "to satisfy missing solute stoichiometry." 
+ ) + entries.append( + PackmolSupplementalAllocationEntry( + role=component.role, + name=component.name, + source_type=component.source_type, + reference_name=component.reference_name, + reference_path=component.reference_path, + residue_name=component.residue_name, + planned_count=int(planned_count), + atom_count=component.atom_count, + element_counts=dict(component.element_counts), + ) + ) + + unfilled = { + element: int(count) + for element, count in sorted(remaining.items()) + if int(count) > 0 + } + unrepresented_unfilled = { + element: count + for element, count in unfilled.items() + if int(cluster_totals.get(element, 0)) <= 0 + } + if unrepresented_unfilled: + raise ValueError( + "Supplemental solute components are required to supply missing " + "solute stoichiometry not present in the weighted cluster " + "structures: " + + _format_element_counts(unrepresented_unfilled) + + ". Add reference-molecule or single-atom solute components " + "before computing the Packmol plan." + ) + if unfilled: + warnings.append( + "Some solute stoichiometry remains unfilled after supplemental " + "component allocation: " + _format_element_counts(unfilled) + ) + + return PackmolSupplementalAllocation( + target_solute_formula_units=int(formula_units), + formula_unit_basis=formula_basis, + cluster_solute_element_totals=cluster_totals, + target_solute_element_totals=target_totals, + missing_solute_element_totals=missing_totals, + added_solute_element_totals=dict(sorted(added_totals.items())), + unfilled_solute_element_totals=unfilled, + entries=entries, + warnings=tuple(warnings), + ) + + +@dataclass(slots=True) +class _ResolvedSupplementalComponent: + role: str + name: str + source_type: str + reference_name: str | None + reference_path: str | None + residue_name: str + atom_count: int + element_counts: dict[str, int] + + +def _estimate_solute_formula_units( + *, + solute_formula: dict[str, int], + cluster_totals: dict[str, int], + box_targets: dict[str, object], +) -> 
tuple[int, dict[str, float]]: + ratios = { + element: float(cluster_totals[element]) / float(count) + for element, count in solute_formula.items() + if count > 0 and cluster_totals.get(element, 0) > 0 + } + if ratios: + ratio_values = np.asarray(list(ratios.values()), dtype=float) + return max(0, int(round(float(np.median(ratio_values))))), dict( + sorted(ratios.items()) + ) + return ( + max( + 0, + int(round(float(box_targets.get("solute_molecules", 0) or 0))), + ), + {}, + ) + + +def _resolve_supplemental_component( + settings: PackmolSupplementalComponentSettings, +) -> _ResolvedSupplementalComponent: + role = _normalized_component_role(settings.role) + reference_identifier = _optional_text(settings.reference) + element_identifier = _optional_text(settings.element) + if reference_identifier is not None: + reference_path = ( + resolve_reference_path(reference_identifier).expanduser().resolve() + ) + structure = PDBStructure.from_file(reference_path) + counts = _count_elements(structure) + residue_name = _normalized_residue_name( + settings.residue_name + or (structure.atoms[0].residue_name if structure.atoms else "") + or reference_path.stem + ) + name = settings.name.strip() or reference_path.stem + return _ResolvedSupplementalComponent( + role=role, + name=name, + source_type="reference", + reference_name=reference_path.stem, + reference_path=str(reference_path), + residue_name=residue_name, + atom_count=len(structure.atoms), + element_counts=counts, + ) + + if element_identifier is not None: + element = _normalized_element_symbol(element_identifier) + residue_name = _normalized_residue_name( + settings.residue_name or element + ) + name = settings.name.strip() or element + return _ResolvedSupplementalComponent( + role=role, + name=name, + source_type="single_atom", + reference_name=None, + reference_path=None, + residue_name=residue_name, + atom_count=1, + element_counts={element: 1}, + ) + + raise ValueError( + "Supplemental Packmol components must define 
either a reference " + "molecule or a single atom element." + ) + + +def _supplemental_component_count( + component: _ResolvedSupplementalComponent, + remaining: Counter[str], +) -> int: + if not component.element_counts: + return 0 + limiting_counts = [ + int(remaining.get(element, 0)) // int(count) + for element, count in component.element_counts.items() + if int(count) > 0 + ] + if not limiting_counts: + return 0 + return max(0, min(limiting_counts)) + + def _build_plan_report( *, settings: PackmolPlanningSettings, @@ -698,6 +1151,7 @@ def _build_plan_report( achieved_total_nd: float, achieved_element_nd: dict[str, float], solvent_allocation: PackmolSolventAllocation | None, + supplemental_allocation: PackmolSupplementalAllocation | None, ) -> str: lines = [ "== Packmol Planning ==", @@ -724,6 +1178,19 @@ def _build_plan_report( ), ] ) + if supplemental_allocation is not None: + supplemental_count = sum( + entry.planned_count for entry in supplemental_allocation.entries + ) + lines.extend( + [ + ( + "Supplemental solute formula units: " + f"{supplemental_allocation.target_solute_formula_units}" + ), + ("Supplemental solute components: " f"{supplemental_count}"), + ] + ) lines.extend( [ ( @@ -770,6 +1237,50 @@ def _build_plan_report( f" - {element}: target={target_element_nd[element]:.6f}, " f"achieved={achieved_element_nd.get(element, 0.0):.6f}" ) + if supplemental_allocation is not None: + lines.extend(["", "Supplemental solute accounting:"]) + lines.append( + " Cluster solute totals: " + + _format_element_counts( + supplemental_allocation.cluster_solute_element_totals + ) + ) + lines.append( + " Target solute totals: " + + _format_element_counts( + supplemental_allocation.target_solute_element_totals + ) + ) + lines.append( + " Missing before supplemental components: " + + _format_element_counts( + supplemental_allocation.missing_solute_element_totals + ) + ) + lines.append( + " Added by supplemental components: " + + _format_element_counts( + 
supplemental_allocation.added_solute_element_totals + ) + ) + if supplemental_allocation.unfilled_solute_element_totals: + lines.append( + " Unfilled after supplemental components: " + + _format_element_counts( + supplemental_allocation.unfilled_solute_element_totals + ) + ) + if supplemental_allocation.entries: + lines.append(" Components:") + for entry in supplemental_allocation.entries: + lines.append( + " - " + f"{entry.name}: {entry.planned_count} x " + f"{_format_element_counts(entry.element_counts)} " + f"({entry.role}, residue {entry.residue_name})" + ) + for warning in supplemental_allocation.warnings: + lines.append(f" Warning: {warning}") return "\n".join(lines) @@ -1110,6 +1621,40 @@ def _normalized_residue_name(text: str) -> str: return (collapsed or "CLU")[:3] +def _normalized_optional_residue_name(value: object) -> str: + text = str(value or "").strip() + if not text: + return "" + return _normalized_residue_name(text) + + +def _normalized_component_role(value: object) -> str: + text = str(value or "solute").strip().lower() + return text if text in {"solute", "solvent"} else "solute" + + +def _normalized_element_symbol(value: object) -> str: + text = re.sub(r"[^A-Za-z]", "", str(value or "")).strip() + if not text: + raise ValueError("Element symbols must contain at least one letter.") + if len(text) == 1: + return text.upper() + return text[0].upper() + text[1:].lower() + + +def _format_element_counts(counts: dict[str, int]) -> str: + if not counts: + return "none" + return ( + ", ".join( + f"{element} x{int(count)}" + for element, count in sorted(counts.items()) + if int(count) != 0 + ) + or "none" + ) + + def _optional_text(value: object) -> str | None: if value is None: return None @@ -1123,6 +1668,9 @@ def _optional_text(value: object) -> str | None: "PackmolPlanningSettings", "PackmolSolventAllocation", "PackmolSolventAllocationEntry", + "PackmolSupplementalAllocation", + "PackmolSupplementalAllocationEntry", + 
"PackmolSupplementalComponentSettings", "build_packmol_plan", "load_packmol_planning_metadata", "save_packmol_planning_metadata", diff --git a/src/saxshell/fullrmc/packmol_setup.py b/src/saxshell/fullrmc/packmol_setup.py index 41cc50e..7044b9c 100644 --- a/src/saxshell/fullrmc/packmol_setup.py +++ b/src/saxshell/fullrmc/packmol_setup.py @@ -21,6 +21,7 @@ resolved_representative_structure_mode, ) from saxshell.saxs.debye import load_structure_file +from saxshell.saxs.stoichiometry import parse_stoich_label from saxshell.structure import PDBAtom, PDBStructure from saxshell.xyz2pdb import resolve_reference_path @@ -28,6 +29,9 @@ from .project_loader import RMCDreamProjectSource +_PACKMOL_SOLUTE_MATCH_TOLERANCE_A = 0.05 + + @dataclass(slots=True) class PackmolSetupSettings: tolerance_angstrom: float = 2.0 @@ -87,6 +91,10 @@ class PackmolSetupEntry: source_pdb: str packmol_pdb: str atom_count: int + solute_atom_count: int = 0 + solvent_atom_count: int = 0 + solvent_residue_count: int = 0 + solvent_residue_names: tuple[str, ...] 
= () def to_dict(self) -> dict[str, object]: return asdict(self) @@ -110,6 +118,56 @@ def from_dict( source_pdb=str(payload.get("source_pdb", "")).strip(), packmol_pdb=str(payload.get("packmol_pdb", "")).strip(), atom_count=int(payload.get("atom_count", 0)), + solute_atom_count=int(payload.get("solute_atom_count", 0)), + solvent_atom_count=int(payload.get("solvent_atom_count", 0)), + solvent_residue_count=int(payload.get("solvent_residue_count", 0)), + solvent_residue_names=tuple( + str(value).strip() + for value in payload.get("solvent_residue_names", []) + if str(value).strip() + ), + ) + + +@dataclass(slots=True) +class PackmolSetupSupplementalEntry: + role: str + name: str + source_type: str + reference_name: str | None + reference_path: str | None + residue_name: str + planned_count: int + atom_count: int + element_counts: dict[str, int] + packmol_pdb: str + + def to_dict(self) -> dict[str, object]: + return asdict(self) + + @classmethod + def from_dict( + cls, + payload: dict[str, object], + ) -> "PackmolSetupSupplementalEntry": + return cls( + role=str(payload.get("role", "solute") or "solute").strip(), + name=str(payload.get("name", "") or "").strip(), + source_type=str(payload.get("source_type", "") or "").strip(), + reference_name=_optional_text(payload.get("reference_name")), + reference_path=_optional_text(payload.get("reference_path")), + residue_name=_normalized_residue_name( + str(payload.get("residue_name", "") or "") + ), + planned_count=int(payload.get("planned_count", 0)), + atom_count=int(payload.get("atom_count", 0)), + element_counts={ + str(key): int(value) + for key, value in dict( + payload.get("element_counts", {}) + ).items() + }, + packmol_pdb=str(payload.get("packmol_pdb", "") or "").strip(), ) @@ -130,7 +188,9 @@ class PackmolSetupMetadata: solvent_molecules_in_clusters: int free_solvent_molecules: int audit_report_path: str + build_report_path: str entries: list[PackmolSetupEntry] + supplemental_entries: 
list[PackmolSetupSupplementalEntry] def to_dict(self) -> dict[str, object]: return { @@ -156,7 +216,11 @@ def to_dict(self) -> dict[str, object]: ), "free_solvent_molecules": self.free_solvent_molecules, "audit_report_path": self.audit_report_path, + "build_report_path": self.build_report_path, "entries": [entry.to_dict() for entry in self.entries], + "supplemental_entries": [ + entry.to_dict() for entry in self.supplemental_entries + ], } @classmethod @@ -206,11 +270,22 @@ def from_dict( audit_report_path=str( payload.get("audit_report_path", "") ).strip(), + build_report_path=str( + payload.get( + "build_report_path", + payload.get("audit_report_path", ""), + ) + ).strip(), entries=[ PackmolSetupEntry.from_dict(dict(entry)) for entry in payload.get("entries", []) if isinstance(entry, dict) ], + supplemental_entries=[ + PackmolSetupSupplementalEntry.from_dict(dict(entry)) + for entry in payload.get("supplemental_entries", []) + if isinstance(entry, dict) + ], ) def summary_text(self) -> str: @@ -225,8 +300,16 @@ def summary_text(self) -> str: f"Box side: {self.box_side_length_a:.3f} A", f"Packmol tolerance: {self.settings.tolerance_angstrom:.3f} A", f"Packmol input: {Path(self.packmol_input_path).name}", + f"Build report: {Path(self.build_report_path).name}", f"Representative PDBs copied: {len(self.entries)}", ] + supplemental_count = sum( + entry.planned_count for entry in self.supplemental_entries + ) + if supplemental_count > 0: + lines.append( + f"Supplemental solute components: {supplemental_count}" + ) if self.free_solvent_reference_name: lines.append( "Free solvent structure: " @@ -259,6 +342,15 @@ def summary_text(self) -> str: return "\n".join(lines) +@dataclass(slots=True) +class _PreparedPackmolStructure: + structure: PDBStructure + solute_atom_count: int + solvent_atom_count: int + solvent_residue_count: int + solvent_residue_names: tuple[str, ...] 
+ + def build_packmol_setup( project_source: "RMCDreamProjectSource", settings: PackmolSetupSettings | None = None, @@ -314,6 +406,10 @@ def build_packmol_setup( (entry.structure, entry.motif, entry.param): entry for entry in active_solvent.entries } + known_solvent_residue_names = _solvent_residue_names_for_packmol_source( + active_solvent, + free_solvent_reference_path=free_solvent_reference_path, + ) representative_structure_mode = resolved_representative_structure_mode( active_representatives, active_solvent, @@ -331,9 +427,10 @@ def build_packmol_setup( "Packmol planning referenced a cluster bin without a representative: " f"{plan_entry.structure}/{plan_entry.motif}" ) + solvent_entry = solvent_lookup.get(key) source_structure, source_pdb_path = _resolve_structure_for_packmol( representative_entry, - solvent_lookup.get(key), + solvent_entry, representative_structure_mode=representative_structure_mode, use_completed=active_settings.use_completed_representatives, ) @@ -350,8 +447,18 @@ def build_packmol_setup( prepared_structure = _prepare_packmol_structure( source_structure, residue_name=residue_name, + solvent_residue_names=known_solvent_residue_names, + solute_reference_structure=( + _solute_reference_structure_for_packmol_source(solvent_entry) + ), + expected_solute_element_counts=parse_stoich_label( + plan_entry.structure + ), + solute_atom_count=_solute_atom_count_for_packmol_source( + solvent_entry + ), ) - prepared_structure.write_pdb_file(packmol_path) + prepared_structure.structure.write_pdb_file(packmol_path) entries.append( PackmolSetupEntry( structure=plan_entry.structure, @@ -364,7 +471,15 @@ def build_packmol_setup( residue_name=residue_name, source_pdb=str(source_pdb_path), packmol_pdb=str(packmol_path), - atom_count=len(prepared_structure.atoms), + atom_count=len(prepared_structure.structure.atoms), + solute_atom_count=prepared_structure.solute_atom_count, + solvent_atom_count=prepared_structure.solvent_atom_count, + solvent_residue_count=( + 
prepared_structure.solvent_residue_count + ), + solvent_residue_names=( + prepared_structure.solvent_residue_names + ), ) ) @@ -407,9 +522,15 @@ def build_packmol_setup( shutil.copy2(source_solvent, destination) solvent_pdb_path = str(destination) + supplemental_entries = _write_supplemental_packmol_structures( + project_source.rmcsetup_paths.packmol_inputs_dir, + active_plan, + ) + input_path = _write_packmol_input( project_source.rmcsetup_paths.packmol_inputs_dir, entries, + supplemental_entries=supplemental_entries, solvent_pdb_path=solvent_pdb_path, free_solvent_molecules=free_solvent_molecules, box_side_length_a=box_side_length_a, @@ -426,6 +547,23 @@ def build_packmol_setup( target_solvent_molecules=target_solvent_molecules, solvent_molecules_in_clusters=solvent_molecules_in_clusters, free_solvent_molecules=free_solvent_molecules, + supplemental_entries=supplemental_entries, + ) + build_report_path = _write_packmol_build_report( + project_source, + active_plan, + entries, + input_path=input_path, + solvent_pdb_path=solvent_pdb_path, + free_solvent_reference_name=free_solvent_reference_name, + free_solvent_reference_path=free_solvent_reference_path, + target_solvent_molecules=target_solvent_molecules, + solvent_molecules_in_clusters=solvent_molecules_in_clusters, + free_solvent_molecules=free_solvent_molecules, + representative_structure_mode=representative_structure_mode, + representative_selection_mode=active_representatives.selection_mode, + settings=active_settings, + supplemental_entries=supplemental_entries, ) metadata = PackmolSetupMetadata( settings=active_settings, @@ -443,7 +581,9 @@ def build_packmol_setup( solvent_molecules_in_clusters=solvent_molecules_in_clusters, free_solvent_molecules=free_solvent_molecules, audit_report_path=str(audit_path), + build_report_path=str(build_report_path), entries=entries, + supplemental_entries=supplemental_entries, ) save_packmol_setup_metadata( project_source.rmcsetup_paths.packmol_setup_path, @@ -532,23 
def _prepare_packmol_structure(
    structure: PDBStructure,
    *,
    residue_name: str,
    solvent_residue_names: set[str] | frozenset[str] | None = None,
    solute_reference_structure: PDBStructure | None = None,
    expected_solute_element_counts: dict[str, int] | None = None,
    solute_atom_count: int | None = None,
) -> _PreparedPackmolStructure:
    """Copy ``structure`` and relabel its atoms for use as a Packmol input.

    Atoms classified as solute by ``_packmol_solute_atom_indices`` are placed
    in residue 1 under ``residue_name`` and renamed ``<Element><n>`` using a
    per-element counter. Every other atom is treated as embedded solvent: its
    residue name is passed through ``_normalized_residue_name`` (``"SOL"`` is
    used when the atom has none) and each distinct original
    ``(residue name, residue number)`` pair is renumbered sequentially,
    starting at 2 so that residue 1 stays reserved for the solute.

    Args:
        structure: Source structure; its atoms are copied, never mutated.
        residue_name: Residue name written onto every solute atom.
        solvent_residue_names: Hint forwarded unchanged to
            ``_packmol_solute_atom_indices``.
        solute_reference_structure: Hint forwarded unchanged to
            ``_packmol_solute_atom_indices``.
        expected_solute_element_counts: Hint forwarded unchanged to
            ``_packmol_solute_atom_indices``.
        solute_atom_count: Hint forwarded unchanged to
            ``_packmol_solute_atom_indices``.

    Returns:
        The relabelled structure together with the solute/solvent atom counts,
        the number of distinct embedded solvent residues, and the sorted tuple
        of solvent residue names that were actually encountered.
    """
    copied_atoms = [atom.copy() for atom in structure.atoms]
    solute_indices = _packmol_solute_atom_indices(
        copied_atoms,
        solvent_residue_names=solvent_residue_names,
        solute_reference_structure=solute_reference_structure,
        expected_solute_element_counts=expected_solute_element_counts,
        solute_atom_count=solute_atom_count,
    )
    solute_counters: dict[str, int] = {}
    solvent_residue_numbers: dict[tuple[str, int], int] = {}
    next_solvent_residue_number = 2
    # Kept under its own name: re-declaring ``solvent_residue_names`` here
    # would shadow the keyword parameter with a differently typed variable.
    embedded_solvent_residue_names: set[str] = set()
    for atom_index, atom in enumerate(copied_atoms):
        index = atom_index + 1
        atom.atom_id = index
        atom.element = str(atom.element).title()
        if atom_index in solute_indices:
            atom.residue_number = 1
            atom.residue_name = residue_name
            element_index = solute_counters.get(atom.element, 0) + 1
            solute_counters[atom.element] = element_index
            atom.atom_name = f"{atom.element}{element_index}"
            continue

        original_residue_name = _normalized_residue_name(
            atom.residue_name or "SOL"
        )
        original_residue_number = int(atom.residue_number)
        residue_key = (original_residue_name, original_residue_number)
        residue_number = solvent_residue_numbers.get(residue_key)
        if residue_number is None:
            # First atom of a solvent molecule not seen before.
            residue_number = next_solvent_residue_number
            solvent_residue_numbers[residue_key] = residue_number
            next_solvent_residue_number += 1
        atom.residue_name = original_residue_name
        atom.residue_number = residue_number
        if not str(atom.atom_name).strip():
            # Solvent atoms keep their names; only blank ones get a fallback.
            atom.atom_name = f"{atom.element}{index}"
        embedded_solvent_residue_names.add(original_residue_name)

    prepared_structure = PDBStructure(
        atoms=copied_atoms, source_name=structure.source_name
    )
    return _PreparedPackmolStructure(
        structure=prepared_structure,
        solute_atom_count=len(solute_indices),
        solvent_atom_count=len(copied_atoms) - len(solute_indices),
        solvent_residue_count=len(solvent_residue_numbers),
        solvent_residue_names=tuple(sorted(embedded_solvent_residue_names)),
    )
def _add_reference_residue_name(
    residue_names: set[str],
    reference_path: str | None,
) -> None:
    """Add the residue names found in a reference PDB file to ``residue_names``.

    Nothing is added when ``reference_path`` is empty, does not point at an
    existing ``.pdb`` file, or cannot be loaded; this lookup is best-effort.
    """
    normalized_path = _optional_text(reference_path)
    if normalized_path is not None:
        pdb_path = Path(normalized_path).expanduser()
        if pdb_path.is_file() and pdb_path.suffix.lower() == ".pdb":
            try:
                loaded_reference = PDBStructure.from_file(pdb_path)
            except Exception:
                # An unreadable reference simply contributes no names.
                return
            for reference_atom in loaded_reference.atoms:
                _add_residue_name_candidate(
                    residue_names,
                    reference_atom.residue_name,
                )
enumerate(atoms): + if index in matched_indices: + continue + if str(atom.element).title() != reference_element: + continue + distance = float( + np.linalg.norm(atom.coordinates - reference_atom.coordinates) + ) + if distance > _PACKMOL_SOLUTE_MATCH_TOLERANCE_A: + continue + if best_distance is None or distance < best_distance: + best_index = index + best_distance = distance + if best_index is None: + return set() + matched_indices.add(best_index) + return matched_indices + + +def _element_matched_solute_atom_indices( + atoms: list[PDBAtom], + expected_counts: dict[str, int], +) -> set[int]: + expected = { + str(element).title(): int(count) + for element, count in expected_counts.items() + if int(count) > 0 + } + if not expected: + return set() + selected_indices: set[int] = set() + for element, expected_count in expected.items(): + matching_indices = [ + index + for index, atom in enumerate(atoms) + if str(atom.element).title() == element + ] + if len(matching_indices) != expected_count: + return set() + selected_indices.update(matching_indices) + return selected_indices + + +def _normalized_residue_names( + residue_names: set[str] | frozenset[str] | tuple[str, ...] 
| list[str], +) -> frozenset[str]: + normalized: set[str] = set() + for residue_name in residue_names: + text = str(residue_name or "").strip() + if text: + normalized.add(_normalized_residue_name(text)) + return frozenset(normalized) + + +def _solute_atom_count_for_packmol_source( + solvent_entry: object | None, +) -> int | None: + if solvent_entry is None: + return None + try: + value = int(getattr(solvent_entry, "atom_count_no_solvent", 0)) + except (TypeError, ValueError): + return None + return value if value > 0 else None + + +def _write_supplemental_packmol_structures( + output_dir: Path, + plan_metadata: PackmolPlanningMetadata, +) -> list[PackmolSetupSupplementalEntry]: + allocation = plan_metadata.supplemental_allocation + if allocation is None: + return [] + entries: list[PackmolSetupSupplementalEntry] = [] + output_dir.mkdir(parents=True, exist_ok=True) + for index, component in enumerate(allocation.entries, start=1): + if component.planned_count <= 0: + continue + source_structure = _supplemental_source_structure(component) + prepared_structure = _prepare_supplemental_packmol_structure( + source_structure, + residue_name=component.residue_name, + ) + filename = ( + f"supplemental_{index:03d}_" + f"{_safe_name(component.name)}_" + f"{_safe_name(component.residue_name)}.pdb" + ) + output_path = output_dir / filename + prepared_structure.write_pdb_file(output_path) + entries.append( + PackmolSetupSupplementalEntry( + role=component.role, + name=component.name, + source_type=component.source_type, + reference_name=component.reference_name, + reference_path=component.reference_path, + residue_name=component.residue_name, + planned_count=component.planned_count, + atom_count=len(prepared_structure.atoms), + element_counts=dict(component.element_counts), + packmol_pdb=str(output_path), + ) + ) + return entries + + +def _supplemental_source_structure( + component: object, +) -> PDBStructure: + source_type = str(getattr(component, "source_type", "")).strip() + 
def _prepare_supplemental_packmol_structure(
    structure: PDBStructure,
    *,
    residue_name: str,
) -> PDBStructure:
    """Copy a supplemental component into a single Packmol-ready residue.

    Every copied atom is renumbered from 1, gets a title-cased element, and
    is moved into residue 1 under the residue name returned by
    ``_normalized_residue_name(residue_name)``. Existing atom names are
    preserved; atoms with a blank name receive ``<Element><n>``, where ``n``
    counts the unnamed atoms of that element and skips any value that would
    duplicate a name already present in the component.

    Args:
        structure: Source structure; its atoms are copied, never mutated.
        residue_name: Residue name to assign to all atoms.

    Returns:
        A new ``PDBStructure`` holding the relabelled copies and the source
        structure's ``source_name``.
    """
    copied_atoms = [atom.copy() for atom in structure.atoms]
    normalized_residue = _normalized_residue_name(residue_name)
    # Names the component already carries; fallbacks must not reuse them.
    used_names: set[str] = set()
    for atom in copied_atoms:
        existing_name = str(atom.atom_name).strip()
        if existing_name:
            used_names.add(existing_name)
    counters: dict[str, int] = {}
    for index, atom in enumerate(copied_atoms, start=1):
        atom.atom_id = index
        atom.element = str(atom.element).title()
        atom.residue_name = normalized_residue
        atom.residue_number = 1
        if str(atom.atom_name).strip():
            continue
        serial = counters.get(atom.element, 0) + 1
        while f"{atom.element}{serial}" in used_names:
            # Advance past names taken by atoms that were already named.
            serial += 1
        counters[atom.element] = serial
        atom.atom_name = f"{atom.element}{serial}"
        used_names.add(atom.atom_name)
    return PDBStructure(atoms=copied_atoms, source_name=structure.source_name)
int, box_side_length_a: float, @@ -568,6 +1053,14 @@ def _write_packmol_input( f"{box_side_length_a:.3f} {box_side_length_a:.3f} {box_side_length_a:.3f}\n" ) handle.write("end structure\n\n") + for entry in supplemental_entries: + handle.write(f"structure {Path(entry.packmol_pdb).name}\n") + handle.write(f" number {entry.planned_count}\n") + handle.write( + " inside box 0.0 0.0 0.0 " + f"{box_side_length_a:.3f} {box_side_length_a:.3f} {box_side_length_a:.3f}\n" + ) + handle.write("end structure\n\n") if solvent_pdb_path and free_solvent_molecules > 0: handle.write(f"structure {Path(solvent_pdb_path).name}\n") handle.write(f" number {free_solvent_molecules}\n") @@ -591,6 +1084,7 @@ def _write_packmol_audit_report( target_solvent_molecules: int, solvent_molecules_in_clusters: int, free_solvent_molecules: int, + supplemental_entries: list[PackmolSetupSupplementalEntry], ) -> Path: lines = [ "# Packmol Build Audit", @@ -621,6 +1115,10 @@ def _write_packmol_audit_report( f"- Box side: {plan_metadata.settings.box_side_length_a:.3f} A", f"- Cluster entries: {len(entries)}", f"- Total cluster count: {sum(entry.planned_count for entry in entries)}", + ( + "- Supplemental component count: " + f"{sum(entry.planned_count for entry in supplemental_entries)}" + ), "", "## Structure Table", "| Structure | Motif | Param | Count | Residue | File |", @@ -633,6 +1131,23 @@ def _write_packmol_audit_report( f"{entry.planned_count} | {entry.residue_name} | " f"{Path(entry.packmol_pdb).name} |" ) + if supplemental_entries: + lines.extend( + [ + "", + "## Supplemental Solute Components", + "| Name | Role | Residue | Count | Formula | File |", + "| --- | --- | --- | ---: | --- | --- |", + ] + ) + for entry in supplemental_entries: + lines.append( + "| " + f"{entry.name} | {entry.role} | {entry.residue_name} | " + f"{entry.planned_count} | " + f"{_format_element_counts(entry.element_counts)} | " + f"{Path(entry.packmol_pdb).name} |" + ) lines.extend( [ "", @@ -641,10 +1156,22 @@ def 
_write_packmol_audit_report( f"- Count-normalized weights: {project_source.rmcsetup_paths.planned_count_weights_csv_path}", f"- Atom-normalized weights: {project_source.rmcsetup_paths.planned_atom_weights_csv_path}", f"- Planning report: {project_source.rmcsetup_paths.packmol_plan_report_path}", + ( + "- Reproducibility report: " + f"{project_source.rmcsetup_paths.packmol_build_report_path}" + ), "", "## Notes", - "- Cluster PDBs were rewritten with unique residue names for Packmol use.", + ( + "- Cluster solute atoms were rewritten with unique residue " + "names for Packmol use." + ), + ( + "- Embedded solvent residues were preserved and reindexed as " + "separate solvent molecules." + ), "- Free solvent counts subtract solvent molecules already present in the cluster files from the bulk-solvent target.", + "- Supplemental solute components are placed as independent Packmol structures to complete solute stoichiometry not represented by the weighted cluster files.", "- If solvent-handling outputs are available, the completed full-solvent representative PDBs define the embedded cluster solvent counts.", ] ) @@ -653,6 +1180,276 @@ def _write_packmol_audit_report( return audit_path +def _write_packmol_build_report( + project_source: "RMCDreamProjectSource", + plan_metadata: PackmolPlanningMetadata, + entries: list[PackmolSetupEntry], + *, + input_path: Path, + solvent_pdb_path: str | None, + free_solvent_reference_name: str | None, + free_solvent_reference_path: str | None, + target_solvent_molecules: int, + solvent_molecules_in_clusters: int, + free_solvent_molecules: int, + representative_structure_mode: str, + representative_selection_mode: str, + settings: PackmolSetupSettings, + supplemental_entries: list[PackmolSetupSupplementalEntry], +) -> Path: + report_path = project_source.rmcsetup_paths.packmol_build_report_path + lines = [ + "SAXSShell rmcsetup Packmol build report", + f"Generated: {datetime.now().isoformat(timespec='seconds')}", + f"Project: 
{project_source.settings.project_dir}", + "", + "Source input information", + f" Packmol input file: {input_path}", + ( + " Packed output path: " + f"{input_path.parent / settings.packed_output_filename}" + ), + f" Planning metadata: {project_source.rmcsetup_paths.packmol_plan_path}", + f" Setup metadata: {project_source.rmcsetup_paths.packmol_setup_path}", + f" Representative selection mode: {representative_selection_mode}", + ( + " Representative structure set: " + f"{representative_structure_mode_label(representative_structure_mode)}" + ), + f" Planning mode: {plan_metadata.settings.planning_mode}", + f" Packmol tolerance: {settings.tolerance_angstrom:.6g} A", + "", + "Box and number-density targets", + ( + " Box side length: " + f"{plan_metadata.settings.box_side_length_a:.6g} A" + ), + ( + " Target total number density: " + f"{plan_metadata.target_total_number_density_a3:.8g} atoms/A^3" + ), + ( + " Achieved cluster number density: " + f"{plan_metadata.achieved_total_number_density_a3:.8g} atoms/A^3" + ), + " Target element number densities:", + ] + if plan_metadata.target_element_number_density_a3: + for element in sorted(plan_metadata.target_element_number_density_a3): + lines.append( + f" {element}: " + f"{plan_metadata.target_element_number_density_a3[element]:.8g} " + "atoms/A^3" + ) + else: + lines.append(" none") + lines.append(" Achieved element number densities:") + if plan_metadata.achieved_element_number_density_a3: + for element in sorted( + plan_metadata.achieved_element_number_density_a3 + ): + lines.append( + f" {element}: " + f"{plan_metadata.achieved_element_number_density_a3[element]:.8g} " + "atoms/A^3" + ) + else: + lines.append(" none") + + lines.extend( + [ + "", + "Solvent accounting", + ( + " Free solvent structure: " + f"{free_solvent_reference_name or '(none)'}" + ), + ( + " Free solvent source path: " + f"{free_solvent_reference_path or '(none)'}" + ), + f" Free solvent Packmol PDB: {solvent_pdb_path or '(none)'}", + f" Computed 
solvent molecules: {target_solvent_molecules}", + f" Cluster solvent molecules: {solvent_molecules_in_clusters}", + f" Free solvent molecules: {free_solvent_molecules}", + "", + "Representative cluster inputs", + ( + " Structure | Motif | Param | Source PDB | Packmol PDB | " + "Cluster residue | Count | Selected weight | Planned count " + "weight | Atom weight | Solute atoms | Solvent atoms | " + "Solvent residues" + ), + ] + ) + for entry in entries: + solvent_residues = ( + ",".join(entry.solvent_residue_names) + if entry.solvent_residue_names + else "none" + ) + lines.append( + " " + f"{entry.structure} | {entry.motif} | {entry.param} | " + f"{entry.source_pdb} | {entry.packmol_pdb} | " + f"{entry.residue_name} | {entry.planned_count} | " + f"{entry.selected_weight:.8g} | " + f"{entry.planned_count_weight:.8g} | " + f"{entry.planned_atom_weight:.8g} | " + f"{entry.solute_atom_count} | {entry.solvent_atom_count} | " + f"{solvent_residues}" + ) + + allocation = plan_metadata.solvent_allocation + if allocation is not None and allocation.entries: + lines.extend(["", "Cluster solvent allocation"]) + for allocation_entry in allocation.entries: + lines.append( + " " + f"{allocation_entry.structure}/{allocation_entry.motif} " + f"({allocation_entry.param}): " + f"{allocation_entry.planned_count} clusters x " + f"{allocation_entry.solvent_molecules_per_cluster} solvent " + "molecules per cluster = " + f"{allocation_entry.solvent_molecules_total}" + ) + + supplemental_allocation = plan_metadata.supplemental_allocation + if supplemental_allocation is not None: + lines.extend(["", "Supplemental solute accounting"]) + lines.append( + " Formula units represented by weighted clusters: " + f"{supplemental_allocation.target_solute_formula_units}" + ) + lines.append( + " Formula-unit basis: " + + _format_float_counts(supplemental_allocation.formula_unit_basis) + ) + lines.append( + " Cluster solute element totals: " + + _format_element_counts( + 
supplemental_allocation.cluster_solute_element_totals + ) + ) + lines.append( + " Target solute element totals: " + + _format_element_counts( + supplemental_allocation.target_solute_element_totals + ) + ) + lines.append( + " Missing solute elements before supplementals: " + + _format_element_counts( + supplemental_allocation.missing_solute_element_totals + ) + ) + lines.append( + " Added solute elements: " + + _format_element_counts( + supplemental_allocation.added_solute_element_totals + ) + ) + lines.append( + " Unfilled solute elements: " + + _format_element_counts( + supplemental_allocation.unfilled_solute_element_totals + ) + ) + for warning in supplemental_allocation.warnings: + lines.append(f" Warning: {warning}") + if supplemental_entries: + lines.extend( + [ + "", + "Supplemental Packmol components", + ( + " Name | Role | Source | Residue | Count | Atom count | " + "Formula | Packmol PDB" + ), + ] + ) + for entry in supplemental_entries: + source = ( + entry.reference_path + if entry.reference_path is not None + else entry.source_type + ) + lines.append( + " " + f"{entry.name} | {entry.role} | {source} | " + f"{entry.residue_name} | {entry.planned_count} | " + f"{entry.atom_count} | " + f"{_format_element_counts(entry.element_counts)} | " + f"{entry.packmol_pdb}" + ) + + lines.extend( + [ + "", + "Generated files", + f" Packmol input: {input_path}", + ( + " Supplemental PDBs: " + + ( + ", ".join( + Path(entry.packmol_pdb).name + for entry in supplemental_entries + ) + if supplemental_entries + else "none" + ) + ), + f" Free solvent PDB: {solvent_pdb_path or '(none)'}", + ( + " Packmol inputs directory: " + f"{project_source.rmcsetup_paths.packmol_inputs_dir}" + ), + ( + " Count report: " + f"{project_source.rmcsetup_paths.cluster_counts_csv_path}" + ), + ( + " Count-normalized weights: " + f"{project_source.rmcsetup_paths.planned_count_weights_csv_path}" + ), + ( + " Atom-normalized weights: " + 
f"{project_source.rmcsetup_paths.planned_atom_weights_csv_path}" + ), + ( + " Planning report: " + f"{project_source.rmcsetup_paths.packmol_plan_report_path}" + ), + ( + " Audit report: " + f"{project_source.rmcsetup_paths.packmol_audit_report_path}" + ), + "", + "Residue and constraint notes", + ( + " Cluster solute atoms are assigned the cluster-specific " + "residues listed above." + ), + ( + " Embedded solvent atoms keep solvent residue names and are " + "reindexed by solvent molecule." + ), + ( + " Constraint generation filters to the cluster-specific " + "residues, so embedded solvent residues are not used for " + "solute cluster constraints." + ), + ( + " Supplemental solute residues are independent Packmol " + "components and are not used for cluster-specific constraint " + "generation." + ), + ] + ) + report_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + return report_path + + def _packmol_residue_code(index: int) -> str: alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" if index < 0: @@ -723,6 +1520,28 @@ def _safe_filename(text: str) -> str: return name or "item" +def _format_element_counts(counts: dict[str, int]) -> str: + if not counts: + return "none" + return ( + ", ".join( + f"{element} x{int(count)}" + for element, count in sorted(counts.items()) + if int(count) != 0 + ) + or "none" + ) + + +def _format_float_counts(counts: dict[str, float]) -> str: + if not counts: + return "none" + return ", ".join( + f"{element}={float(value):.6g}" + for element, value in sorted(counts.items()) + ) + + def _resolve_free_solvent_reference( settings: PackmolSetupSettings, plan_metadata: PackmolPlanningMetadata, @@ -757,6 +1576,7 @@ def _resolve_free_solvent_reference( "PackmolSetupEntry", "PackmolSetupMetadata", "PackmolSetupSettings", + "PackmolSetupSupplementalEntry", "build_packmol_setup", "load_packmol_setup_metadata", "save_packmol_setup_metadata", diff --git a/src/saxshell/fullrmc/project_model.py b/src/saxshell/fullrmc/project_model.py index 
1aad673..ac9a6ab 100644 --- a/src/saxshell/fullrmc/project_model.py +++ b/src/saxshell/fullrmc/project_model.py @@ -36,6 +36,7 @@ class RMCSetupPaths: constraint_generation_path: Path packmol_plan_report_path: Path packmol_audit_report_path: Path + packmol_build_report_path: Path cluster_counts_csv_path: Path planned_count_weights_csv_path: Path planned_atom_weights_csv_path: Path @@ -92,6 +93,9 @@ def build_rmcsetup_paths( packmol_audit_report_path=rmcsetup_dir / "reports" / "packmol_audit.md", + packmol_build_report_path=rmcsetup_dir + / "reports" + / "packmol_build_report.txt", cluster_counts_csv_path=rmcsetup_dir / "reports" / "cluster_counts.csv", diff --git a/src/saxshell/fullrmc/solvent_handling.py b/src/saxshell/fullrmc/solvent_handling.py index 9206c33..c43e5d2 100644 --- a/src/saxshell/fullrmc/solvent_handling.py +++ b/src/saxshell/fullrmc/solvent_handling.py @@ -416,6 +416,7 @@ class RepresentativeSolventAnalysisEntry: source_file: str source_status: str analysis_result: "SolventShellAnalysisResult" + included_in_distribution_status: bool = True @property def representative_label(self) -> str: @@ -440,6 +441,18 @@ class RepresentativeSolventDistributionAnalysis: aggregate_solute_element_counts: dict[str, int] entries: list[RepresentativeSolventAnalysisEntry] + @property + def distribution_status_entry_count(self) -> int: + return sum( + 1 + for entry in self.entries + if entry.included_in_distribution_status + ) + + @property + def ignored_distribution_status_entry_count(self) -> int: + return len(self.entries) - self.distribution_status_entry_count + @property def build_required(self) -> bool: return self.distribution_status != "complete_solvent" @@ -453,8 +466,16 @@ def summary_text(self) -> str: "Detected representative distribution state: " f"{_solvent_state_text(self.distribution_status)}" ), - f"Representative entries analyzed: {len(self.entries)}", + ( + "Representative entries analyzed: " + f"{self.distribution_status_entry_count}" + ), ] + if 
self.ignored_distribution_status_entry_count: + lines.append( + "Single-atom representatives ignored for distribution state: " + f"{self.ignored_distribution_status_entry_count}" + ) if self.distribution_note: lines.append(self.distribution_note) if self.aggregate_solute_element_counts: @@ -470,9 +491,14 @@ def summary_text(self) -> str: if self.entries: lines.extend(["", "Detected representative states:"]) for entry in self.entries: + suffix = ( + " (ignored for distribution state)" + if not entry.included_in_distribution_status + else "" + ) lines.append( f" {entry.representative_label}: " - f"{entry.source_status_text}" + f"{entry.source_status_text}{suffix}" ) return "\n".join(lines) @@ -913,6 +939,12 @@ def analyze_representative_solvent_distribution( reference_match_tolerance_a=settings.reference_match_tolerance_a, ) source_status = _classify_source_solvent_status(analysis_result) + included_in_distribution_status = ( + not _representative_solvent_status_is_single_atom( + representative_entry, + analysis_result, + ) + ) entries.append( RepresentativeSolventAnalysisEntry( structure=representative_entry.structure, @@ -921,6 +953,9 @@ def analyze_representative_solvent_distribution( source_file=representative_entry.source_file, source_status=source_status, analysis_result=analysis_result, + included_in_distribution_status=( + included_in_distribution_status + ), ) ) aggregate_solute_counts.update(analysis_result.solute_element_counts) @@ -1250,37 +1285,87 @@ def _classify_source_solvent_status( def _resolve_distribution_status( entries: list[RepresentativeSolventAnalysisEntry], ) -> tuple[str, str]: + ignored_count = sum( + 1 for entry in entries if not entry.included_in_distribution_status + ) + status_entries = [ + entry for entry in entries if entry.included_in_distribution_status + ] + ignored_note = "" + if ignored_count: + ignored_note = ( + f"Ignored {ignored_count} single-atom representative " + "structure(s) when determining the overall solvent state." 
+ ) statuses = { entry.source_status - for entry in entries + for entry in status_entries if str(entry.source_status).strip() } if not statuses: return ( - "unknown", - "No representative solvent states were available.", + "no_solvent", + ignored_note or "No representative solvent states were available.", ) if statuses == {"complete_solvent"}: return ( "complete_solvent", - "Every representative structure already contains complete solvent molecules, so the existing solvent-decorated structures can be passed through.", + _join_distribution_notes( + "Every representative structure already contains complete " + "solvent molecules, so the existing solvent-decorated " + "structures can be passed through.", + ignored_note, + ), ) if statuses == {"partial_solvent"}: return ( "partial_solvent", - "Every representative structure contains partial solvent molecules, so the saved anchors will be used to rebuild complete solvent molecules.", + _join_distribution_notes( + "Every representative structure contains partial solvent " + "molecules, so the saved anchors will be used to rebuild " + "complete solvent molecules.", + ignored_note, + ), ) if statuses == {"no_solvent"}: return ( "no_solvent", - "No representative structure contains coordinated solvent molecules, so solvent shells will be built from the stripped solute structures.", + _join_distribution_notes( + "No representative structure contains coordinated solvent " + "molecules, so solvent shells will be built from the stripped " + "solute structures.", + ignored_note, + ), ) return ( "no_solvent", - "Representative solvent detections were inconsistent across the saved structures. Following the conservative workflow rule, the current cluster distribution is treated as having no coordinated solvent.", + _join_distribution_notes( + "Representative solvent detections were inconsistent across the " + "saved structures. 
Following the conservative workflow rule, the " + "current cluster distribution is treated as having no coordinated " + "solvent.", + ignored_note, + ), ) +def _representative_solvent_status_is_single_atom( + representative_entry: object, + analysis_result: "SolventShellAnalysisResult", +) -> bool: + try: + atom_count = int(getattr(representative_entry, "atom_count")) + except Exception: + atom_count = 0 + if atom_count > 0: + return atom_count <= 1 + return int(analysis_result.total_atoms) <= 1 + + +def _join_distribution_notes(*notes: str) -> str: + return " ".join(str(note).strip() for note in notes if str(note).strip()) + + def _solvent_state_text(status: str) -> str: mapping = { "complete_solvent": "Complete solvent molecules detected", diff --git a/src/saxshell/fullrmc/ui/main_window.py b/src/saxshell/fullrmc/ui/main_window.py index c759845..b813170 100644 --- a/src/saxshell/fullrmc/ui/main_window.py +++ b/src/saxshell/fullrmc/ui/main_window.py @@ -3,6 +3,7 @@ import argparse import json import re +import shutil import sys from datetime import datetime from pathlib import Path @@ -67,6 +68,7 @@ from saxshell.fullrmc.packmol_planning import ( PackmolPlanningMetadata, PackmolPlanningSettings, + PackmolSupplementalComponentSettings, build_packmol_plan, ) from saxshell.fullrmc.packmol_setup import ( @@ -82,8 +84,11 @@ from saxshell.fullrmc.representatives import ( RepresentativeSelectionMetadata, RepresentativeSelectionSettings, + build_distribution_selection, build_representative_preview_clusters, representative_source_solvent_mode_to_variant, + save_distribution_selection_metadata, + save_representative_selection_metadata, select_distribution_representatives, select_first_file_representatives, ) @@ -160,6 +165,7 @@ load_saxshell_icon, prepare_saxshell_application_identity, ) +from saxshell.ui.periodic_table import PeriodicTableElementDialog _OPEN_WINDOWS: list["RMCSetupMainWindow"] = [] @@ -667,6 +673,14 @@ def __init__( self.output_group = QGroupBox("RMCSetup 
Output Structure") output_layout = QVBoxLayout(self.output_group) + output_button_row = QHBoxLayout() + self.reset_rmcsetup_state_button = QPushButton("Reset RMCSetup State") + self.reset_rmcsetup_state_button.clicked.connect( + self._reset_complete_rmcsetup_state + ) + output_button_row.addWidget(self.reset_rmcsetup_state_button) + output_button_row.addStretch(1) + output_layout.addLayout(output_button_row) self.output_summary_box = QPlainTextEdit() self.output_summary_box.setReadOnly(True) self.output_summary_box.setMinimumHeight(160) @@ -1339,6 +1353,15 @@ def __init__( representative_button_row.addWidget( self.preview_representatives_button ) + self.reset_representative_state_button = QPushButton( + "Reset Representative Analysis" + ) + self.reset_representative_state_button.clicked.connect( + self._reset_representative_analysis_state + ) + representative_button_row.addWidget( + self.reset_representative_state_button + ) representative_button_row.addStretch(1) representative_content_layout.addLayout(representative_button_row) @@ -1580,7 +1603,7 @@ def __init__( generated_pdb_button_row.addStretch(1) generated_pdb_layout.addLayout(generated_pdb_button_row) - self.generated_pdb_table = QTableWidget(0, 6) + self.generated_pdb_table = QTableWidget(0, 8) self.generated_pdb_table.setHorizontalHeaderLabels( [ "Representative", @@ -1589,6 +1612,8 @@ def __init__( "Structure File", "Atoms", "Source", + "DREAM Weight", + "DREAM Value", ] ) self.generated_pdb_table.setSelectionBehavior( @@ -1688,6 +1713,55 @@ def __init__( ) packmol_content_layout.addLayout(packmol_form) + self.packmol_supplemental_table = QTableWidget(0, 4) + self.packmol_supplemental_table.setHorizontalHeaderLabels( + ["Role", "Source", "Reference/Element", "Residue"] + ) + self.packmol_supplemental_table.setSelectionBehavior( + QTableWidget.SelectionBehavior.SelectRows + ) + self.packmol_supplemental_table.setSelectionMode( + QTableWidget.SelectionMode.SingleSelection + ) + 
self.packmol_supplemental_table.setEditTriggers( + QTableWidget.EditTrigger.NoEditTriggers + ) + self.packmol_supplemental_table.horizontalHeader().setStretchLastSection( + True + ) + packmol_content_layout.addWidget(self.packmol_supplemental_table) + + supplemental_button_row = QHBoxLayout() + self.add_packmol_supplemental_reference_button = QPushButton( + "Add Reference Component" + ) + self.add_packmol_supplemental_reference_button.clicked.connect( + self._add_packmol_supplemental_reference_component + ) + supplemental_button_row.addWidget( + self.add_packmol_supplemental_reference_button + ) + self.add_packmol_supplemental_atom_button = QPushButton( + "Add Single Atom" + ) + self.add_packmol_supplemental_atom_button.clicked.connect( + self._add_packmol_supplemental_atom_component + ) + supplemental_button_row.addWidget( + self.add_packmol_supplemental_atom_button + ) + self.remove_packmol_supplemental_button = QPushButton( + "Remove Selected" + ) + self.remove_packmol_supplemental_button.clicked.connect( + self._remove_selected_packmol_supplemental_component + ) + supplemental_button_row.addWidget( + self.remove_packmol_supplemental_button + ) + supplemental_button_row.addStretch(1) + packmol_content_layout.addLayout(supplemental_button_row) + packmol_button_row = QHBoxLayout() self.compute_packmol_plan_button = QPushButton( "Compute Cluster Counts" @@ -1885,6 +1959,148 @@ def _reload_saved_representative_structures(self) -> None: self._append_run_log("Reloading saved representative structures.") self._refresh_project_source() + def _reset_representative_analysis_state(self) -> None: + self._reset_representative_dependent_state( + confirm=True, + refresh=True, + clear_reason="Representative analysis reset requested.", + ) + + def _reset_representative_dependent_state( + self, + *, + confirm: bool, + refresh: bool, + clear_reason: str, + ) -> bool: + state = self._project_source_state + if state is None: + if confirm: + QMessageBox.information( + self, + "No 
SAXS project loaded", + "Load a SAXS project before resetting representative " + "analysis.", + ) + return False + if confirm: + response = QMessageBox.question( + self, + "Reset representative analysis?", + ( + "Clear solvent-state analysis, generated PDB outputs " + "tracked by that analysis, Packmol planning, Packmol " + "setup, and generated constraints for this project? " + "Saved representative selections will be kept." + ), + ) + if response != QMessageBox.StandardButton.Yes: + return False + + self._delete_tracked_solvent_output_files(state) + for path in ( + state.rmcsetup_paths.solvent_handling_path, + state.rmcsetup_paths.packmol_plan_path, + state.rmcsetup_paths.packmol_setup_path, + state.rmcsetup_paths.constraint_generation_path, + ): + self._write_empty_json_file(path) + self._clear_directory_contents(state.rmcsetup_paths.packmol_inputs_dir) + self._clear_directory_contents(state.rmcsetup_paths.constraints_dir) + self._clear_directory_contents(state.rmcsetup_paths.reports_dir) + state.solvent_handling = None + state.packmol_planning = None + state.packmol_setup = None + state.constraint_generation = None + self._solvent_distribution_analysis = None + self._append_run_log( + "Cleared representative-dependent rmcsetup state. " + clear_reason + ) + if refresh: + self._refresh_project_source() + return True + + def _reset_complete_rmcsetup_state(self) -> None: + state = self._project_source_state + if state is None: + QMessageBox.information( + self, + "No SAXS project loaded", + "Load a SAXS project before resetting rmcsetup state.", + ) + return + response = QMessageBox.question( + self, + "Reset all rmcsetup state?", + ( + "Clear the entire rmcsetup folder for this project, including " + "saved representative structures, solution properties, " + "solvent handling, Packmol inputs, reports, and generated " + "constraints?" 
+ ), + ) + if response != QMessageBox.StandardButton.Yes: + return + self._clear_directory_contents(state.rmcsetup_paths.rmcsetup_dir) + self._append_run_log("Cleared the complete rmcsetup state.") + self._refresh_project_source() + + def _delete_tracked_solvent_output_files( + self, + state: RMCDreamProjectSource, + ) -> None: + solvent_metadata = state.solvent_handling + if solvent_metadata is None: + return + representative_sources = { + Path(entry.source_file).expanduser().resolve() + for entry in ( + state.representative_selection.representative_entries + if state.representative_selection is not None + else [] + ) + if str(entry.source_file).strip() + } + for entry in solvent_metadata.entries: + for raw_path in (entry.no_solvent_pdb, entry.completed_pdb): + path = Path(raw_path).expanduser().resolve() + if path in representative_sources: + continue + if path.is_file() or path.is_symlink(): + try: + path.unlink() + except OSError: + continue + self._remove_empty_parents( + path.parent, + stop_at=state.rmcsetup_paths.rmcsetup_dir, + ) + + @staticmethod + def _write_empty_json_file(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("{}\n", encoding="utf-8") + + @staticmethod + def _clear_directory_contents(path: Path) -> None: + path.mkdir(parents=True, exist_ok=True) + for child in list(path.iterdir()): + if child.is_symlink() or child.is_file(): + child.unlink() + elif child.is_dir(): + shutil.rmtree(child) + + @staticmethod + def _remove_empty_parents(path: Path, *, stop_at: Path) -> None: + stop = stop_at.resolve() + current = path.resolve() + while current != stop and stop in current.parents: + try: + current.rmdir() + except OSError: + break + current = current.parent + def _handle_representative_structure_results_changed( self, project_dir_text: str, @@ -2090,8 +2306,24 @@ def _refresh_project_source(self) -> None: ) except Exception as exc: self._append_run_log(f"Unable to load project source: {exc}") + 
self.reset_rmcsetup_state_button.setEnabled( + self._project_source_state is not None + ) self._populate_dream_controls() self._populate_favorite_controls() + self._initialize_selection_from_state() + representative_weights_changed = ( + self._sync_representatives_with_current_dream_weights() + ) + if representative_weights_changed: + self._reset_representative_dependent_state( + confirm=False, + refresh=False, + clear_reason=( + "DREAM weight mapping changed for saved representatives." + ), + ) + self._clear_stale_representative_dependent_state() self._populate_solution_properties_controls() self._populate_representative_controls() self._populate_solvent_controls() @@ -2100,11 +2332,201 @@ def _refresh_project_source(self) -> None: self.project_summary_box.setPlainText(self._project_summary_text()) self.output_summary_box.setPlainText(self._output_structure_text()) self.favorite_summary_box.setPlainText(self._favorite_summary_text()) - self._initialize_selection_from_state() self._refresh_dream_source_summary() self._update_readiness_progress() self._set_task_progress("Project source loaded.", 100) + def _sync_representatives_with_current_dream_weights(self) -> bool: + state = self._project_source_state + if state is None or state.representative_selection is None: + return False + selection = self.current_selection() + if selection is None: + return False + metadata = state.representative_selection + try: + distribution = build_distribution_selection( + state, + selection, + selection_mode=metadata.selection_mode or "rmcsetup", + ) + except Exception as exc: + self._append_run_log( + "Unable to map saved representative structures to the " + f"current DREAM weights: {exc}" + ) + return False + + lookup, duplicate_keys = self._dream_distribution_lookup( + distribution.entries + ) + changed = self._selection_signature( + metadata.selection + ) != self._selection_signature( + selection + ) or self._distribution_signature( + metadata.distribution_selection + ) != 
self._distribution_signature( + distribution + ) + unmatched_labels: list[str] = [] + ambiguous_labels: list[str] = [] + for entry in metadata.representative_entries: + key = (entry.structure, entry.motif) + if key in duplicate_keys: + ambiguous_labels.append(self._representative_label(entry)) + continue + distribution_entry = lookup.get(key) + if distribution_entry is None: + unmatched_labels.append(self._representative_label(entry)) + continue + if ( + entry.param != distribution_entry.param + or abs( + float(entry.selected_weight) + - float(distribution_entry.selected_weight) + ) + > 1e-12 + or int(entry.cluster_count) + != int(distribution_entry.cluster_count) + ): + changed = True + entry.param = distribution_entry.param + entry.selected_weight = float(distribution_entry.selected_weight) + entry.cluster_count = int(distribution_entry.cluster_count) + + if unmatched_labels: + self._append_run_log( + "Saved representatives without a current DREAM weight: " + + ", ".join(unmatched_labels) + ) + if ambiguous_labels: + self._append_run_log( + "Saved representatives with ambiguous DREAM weights: " + + ", ".join(ambiguous_labels) + ) + + if not changed: + return False + + metadata.selection = selection + metadata.distribution_selection = distribution + metadata.updated_at = datetime.now().isoformat(timespec="seconds") + save_distribution_selection_metadata( + state.rmcsetup_paths.distribution_selection_path, + distribution, + ) + save_representative_selection_metadata( + state.rmcsetup_paths.representative_selection_path, + metadata, + ) + state.representative_selection = metadata + self._append_run_log( + "Mapped saved representative structures to the current DREAM " + "weight parameters." 
+ ) + return True + + @staticmethod + def _dream_distribution_lookup( + entries: list[object], + ) -> tuple[dict[tuple[str, str], object], set[tuple[str, str]]]: + lookup: dict[tuple[str, str], object] = {} + duplicate_keys: set[tuple[str, str]] = set() + for entry in entries: + key = ( + str(getattr(entry, "structure", "")).strip(), + str(getattr(entry, "motif", "no_motif")).strip() or "no_motif", + ) + if not key[0]: + continue + if key in lookup: + duplicate_keys.add(key) + lookup[key] = entry + return lookup, duplicate_keys + + @staticmethod + def _selection_signature( + selection: DreamBestFitSelection, + ) -> tuple[object, ...]: + return ( + selection.run_name, + selection.run_relative_path, + selection.bestfit_method, + selection.posterior_filter_mode, + float(selection.posterior_top_percent), + int(selection.posterior_top_n), + float(selection.credible_interval_low), + float(selection.credible_interval_high), + selection.template_name, + selection.model_name, + ) + + @staticmethod + def _distribution_signature(metadata: object) -> tuple[object, ...]: + entries = getattr(metadata, "entries", ()) + return tuple( + ( + str(getattr(entry, "param", "")).strip(), + str(getattr(entry, "structure", "")).strip(), + str(getattr(entry, "motif", "no_motif")).strip() or "no_motif", + round(float(getattr(entry, "selected_weight", 0.0)), 12), + int(getattr(entry, "cluster_count", 0)), + bool(getattr(entry, "is_active", False)), + ) + for entry in entries + ) + + @staticmethod + def _representative_label(entry: object) -> str: + structure = str(getattr(entry, "structure", "")).strip() + motif = str(getattr(entry, "motif", "no_motif")).strip() + if not motif or motif == "no_motif": + return structure + return f"{structure}/{motif}" + + def _clear_stale_representative_dependent_state(self) -> None: + state = self._project_source_state + metadata = ( + state.representative_selection if state is not None else None + ) + if metadata is None: + return + metadata_updated_at = 
str(metadata.updated_at or "").strip() + downstream_metadata = ( + state.solvent_handling, + state.packmol_planning, + state.packmol_setup, + state.constraint_generation, + ) + if any( + self._timestamp_is_newer( + metadata_updated_at, + str(getattr(item, "updated_at", "") or "").strip(), + ) + for item in downstream_metadata + if item is not None + ): + self._reset_representative_dependent_state( + confirm=False, + refresh=False, + clear_reason=( + "Saved representative metadata is newer than downstream " + "rmcsetup outputs." + ), + ) + + @staticmethod + def _timestamp_is_newer(candidate: str, reference: str) -> bool: + if not candidate or not reference: + return False + try: + return datetime.fromisoformat(candidate) > datetime.fromisoformat( + reference + ) + except ValueError: + return candidate > reference + def _populate_dream_controls(self) -> None: self._updating_dream_controls = True try: @@ -2248,6 +2670,7 @@ def _output_structure_text(self) -> str: ("Constraint metadata: " f"{paths.constraint_generation_path}"), ("Packmol plan report: " f"{paths.packmol_plan_report_path}"), ("Packmol audit report: " f"{paths.packmol_audit_report_path}"), + ("Packmol build report: " f"{paths.packmol_build_report_path}"), ("Cluster counts report: " f"{paths.cluster_counts_csv_path}"), ("Packmol input file: " f"{paths.packmol_input_path}"), ("Merged constraints file: " f"{paths.merged_constraints_path}"), @@ -2565,6 +2988,11 @@ def _configure_tooltips(self) -> None: self.refresh_button, "Reload the selected project folder and rescan its saved data.", ) + self._set_widget_tooltip( + self.reset_rmcsetup_state_button, + "Clear the generated rmcsetup metadata and outputs for this " + "project after confirmation.", + ) self._set_widget_tooltip( self.dream_run_combo, "Choose the DREAM run whose model fit and weights should drive " @@ -2776,6 +3204,13 @@ def _configure_tooltips(self) -> None: self.preview_representatives_button, "Reload the saved representative structures from 
this project.", ) + self._set_widget_tooltip( + self.reset_representative_state_button, + "Clear solvent-state analysis, generated representative PDBs " + "tracked by that analysis, Packmol planning, Packmol setup, " + "and generated constraints while keeping the saved representative " + "selection.", + ) self._set_widget_tooltip( self.solvent_reference_source_combo, "Choose whether the solvent reference structure comes from a " @@ -2857,6 +3292,21 @@ def _configure_tooltips(self) -> None: "Choose the solvent structure file used for the free bulk " "solvent population in the Packmol box.", ) + self._set_widget_tooltip( + self.packmol_supplemental_table, + "List extra solute or solvent components that are not part of " + "the weighted representative cluster files.", + ) + self._set_widget_tooltip( + self.add_packmol_supplemental_reference_button, + "Add a reference molecule, such as an organic cation, for " + "Packmol stoichiometry completion.", + ) + self._set_widget_tooltip( + self.add_packmol_supplemental_atom_button, + "Add a single unclustered atom, such as Cs, for Packmol " + "stoichiometry completion.", + ) self._set_widget_tooltip( self.compute_packmol_plan_button, "Compute cluster counts and target weights for the current " @@ -2939,6 +3389,7 @@ def _populate_representative_controls(self) -> None: self._apply_representative_metadata(None) self.compute_representatives_button.setEnabled(False) self.preview_representatives_button.setEnabled(False) + self.reset_representative_state_button.setEnabled(False) self.representative_status_label.setText( "Representative structures: no SAXS project loaded." 
) @@ -2950,6 +3401,13 @@ def _populate_representative_controls(self) -> None: self._apply_representative_metadata(state.representative_selection) self.compute_representatives_button.setEnabled(True) self.preview_representatives_button.setEnabled(True) + self.reset_representative_state_button.setEnabled( + state.representative_selection is not None + or state.solvent_handling is not None + or state.packmol_planning is not None + or state.packmol_setup is not None + or state.constraint_generation is not None + ) if state.representative_selection is None: self.representative_status_label.setText( "Representative structures: no saved project set loaded." @@ -2994,6 +3452,10 @@ def _populate_packmol_planning_controls(self) -> None: self.compute_packmol_plan_button.setEnabled(False) self.build_packmol_setup_button.setEnabled(False) self.packmol_free_solvent_combo.setEnabled(False) + self.packmol_supplemental_table.setEnabled(False) + self.add_packmol_supplemental_reference_button.setEnabled(False) + self.add_packmol_supplemental_atom_button.setEnabled(False) + self.remove_packmol_supplemental_button.setEnabled(False) self.packmol_plan_summary_box.setPlainText( "Load a SAXS project, calculate solution properties, and " "save representative structures before planning Packmol " @@ -3021,6 +3483,12 @@ def _populate_packmol_planning_controls(self) -> None: self.packmol_free_solvent_combo.setEnabled( self.packmol_free_solvent_combo.count() > 0 ) + self.packmol_supplemental_table.setEnabled(True) + self.add_packmol_supplemental_reference_button.setEnabled(True) + self.add_packmol_supplemental_atom_button.setEnabled(True) + self.remove_packmol_supplemental_button.setEnabled( + self.packmol_supplemental_table.rowCount() > 0 + ) self.packmol_plan_summary_box.setPlainText( self._packmol_plan_summary_text(state.packmol_planning) ) @@ -3751,14 +4219,17 @@ def _refresh_generated_pdb_browser( preview_path.name, str(atom_count), source_text, + representative_entry.param or "n/a", + 
f"{representative_entry.selected_weight:.6g}", ] details_lines = [ f"Representative: {values[0]}", f"Detected source solvent state: {detected_state}", f"Active structure set: {active_mode_label}", f"Structure file: {preview_path}", + ("DREAM weight: " f"{representative_entry.param or 'n/a'}"), ( - "Selected weight: " + "DREAM-derived weight value: " f"{representative_entry.selected_weight:.6g}" ), f"Cluster count: {representative_entry.cluster_count}", @@ -4022,13 +4493,21 @@ def _update_solvent_status_panel( ) ) status_lines = [ - f"Representative entries analyzed: {len(analysis.entries)}", + ( + "Representative entries analyzed: " + f"{analysis.distribution_status_entry_count}" + ), ( "Saved full-solvent representatives are not available yet. " "Build solvent-decorated representative PDBs to store the " "Full solvent representative structure set." ), ] + if analysis.ignored_distribution_status_entry_count: + status_lines.append( + "Single-atom representatives ignored for distribution state: " + f"{analysis.ignored_distribution_status_entry_count}" + ) if analysis.distribution_note: status_lines.append(analysis.distribution_note) status_lines.append( @@ -4168,6 +4647,9 @@ def _apply_packmol_planning_metadata( settings.planning_mode, ) self.packmol_box_side_spin.setValue(settings.box_side_length_a) + self._set_packmol_supplemental_components( + settings.supplemental_components + ) selected_reference = settings.free_solvent_reference if ( selected_reference is None @@ -4551,6 +5033,7 @@ def _finish_representative_selection( return state.representative_selection = metadata + state.solvent_handling = None state.packmol_planning = None state.packmol_setup = None state.constraint_generation = None @@ -4808,6 +5291,156 @@ def _run_representative_solvent_analysis( self._append_run_log(log_completion) return analysis + def _add_packmol_supplemental_reference_component(self) -> None: + choices = [preset.name for preset in self._available_solvent_presets] + if not choices: + 
QMessageBox.information( + self, + "No reference molecules", + "No xyz2pdb reference molecules are available.", + ) + return + reference_name, accepted = QInputDialog.getItem( + self, + "Add Reference Component", + "Reference molecule", + choices, + 0, + False, + ) + if not accepted or not reference_name: + return + role_label, accepted = QInputDialog.getItem( + self, + "Component Role", + "Role", + ["solute", "solvent"], + 0, + False, + ) + if not accepted: + return + preset_lookup = { + preset.name: preset for preset in self._available_solvent_presets + } + preset = preset_lookup.get(reference_name) + default_residue = "" if preset is None else preset.residue_name + residue_name, accepted = QInputDialog.getText( + self, + "Component Residue", + "Residue name", + text=default_residue, + ) + if not accepted: + return + reference_path = ( + None if preset is None else str(Path(preset.path).resolve()) + ) + self._append_packmol_supplemental_component( + PackmolSupplementalComponentSettings( + role=str(role_label or "solute"), + reference=reference_path or reference_name, + residue_name=str(residue_name or default_residue).strip(), + name=reference_name, + ) + ) + + def _add_packmol_supplemental_atom_component(self) -> None: + element = PeriodicTableElementDialog.get_element_symbol( + parent=self, + title="Add Single Atom", + ) + if not element: + return + role_label, accepted = QInputDialog.getItem( + self, + "Component Role", + "Role", + ["solute", "solvent"], + 0, + False, + ) + if not accepted: + return + residue_name, accepted = QInputDialog.getText( + self, + "Component Residue", + "Residue name", + text=element.upper()[:3], + ) + if not accepted: + return + self._append_packmol_supplemental_component( + PackmolSupplementalComponentSettings( + role=str(role_label or "solute"), + element=element, + residue_name=residue_name.strip(), + name=element, + ) + ) + + def _remove_selected_packmol_supplemental_component(self) -> None: + selected_rows = sorted( + { + 
index.row() + for index in self.packmol_supplemental_table.selectedIndexes() + }, + reverse=True, + ) + for row in selected_rows: + self.packmol_supplemental_table.removeRow(row) + self.remove_packmol_supplemental_button.setEnabled( + self.packmol_supplemental_table.rowCount() > 0 + ) + + def _append_packmol_supplemental_component( + self, + component: PackmolSupplementalComponentSettings, + ) -> None: + row = self.packmol_supplemental_table.rowCount() + self.packmol_supplemental_table.insertRow(row) + source_label = "Reference" if component.reference else "Atom" + identifier = component.reference or component.element or "" + identifier_label = ( + Path(identifier).stem if component.reference else identifier + ) + values = [ + component.role, + source_label, + identifier_label, + component.residue_name, + ] + for column, value in enumerate(values): + item = QTableWidgetItem(str(value)) + if column == 0: + item.setData(Qt.ItemDataRole.UserRole, component) + self.packmol_supplemental_table.setItem(row, column, item) + self.remove_packmol_supplemental_button.setEnabled(True) + + def _set_packmol_supplemental_components( + self, + components: tuple[PackmolSupplementalComponentSettings, ...], + ) -> None: + self.packmol_supplemental_table.setRowCount(0) + for component in components: + self._append_packmol_supplemental_component(component) + self.remove_packmol_supplemental_button.setEnabled( + self.packmol_supplemental_table.rowCount() > 0 + ) + + def _current_packmol_supplemental_components( + self, + ) -> list[PackmolSupplementalComponentSettings]: + components: list[PackmolSupplementalComponentSettings] = [] + for row in range(self.packmol_supplemental_table.rowCount()): + item = self.packmol_supplemental_table.item(row, 0) + if item is None: + continue + component = item.data(Qt.ItemDataRole.UserRole) + if isinstance(component, PackmolSupplementalComponentSettings): + components.append(component) + return components + def _current_packmol_planning_settings(self) -> 
PackmolPlanningSettings: return PackmolPlanningSettings( planning_mode=str( @@ -4815,6 +5448,9 @@ def _current_packmol_planning_settings(self) -> PackmolPlanningSettings: ), box_side_length_a=float(self.packmol_box_side_spin.value()), free_solvent_reference=self._selected_packmol_free_solvent_reference(), + supplemental_components=tuple( + self._current_packmol_supplemental_components() + ), ) def _current_packmol_setup_settings(self) -> PackmolSetupSettings: @@ -5783,6 +6419,8 @@ def _packmol_setup_summary_text( + str(state.rmcsetup_paths.packmol_setup_path) + "\nAudit report:\n" + str(state.rmcsetup_paths.packmol_audit_report_path) + + "\nBuild report:\n" + + str(state.rmcsetup_paths.packmol_build_report_path) ) if state.packmol_docker_link is not None: text += ( diff --git a/src/saxshell/mdtrajectory/__init__.py b/src/saxshell/mdtrajectory/__init__.py index c7ec7d9..554628d 100644 --- a/src/saxshell/mdtrajectory/__init__.py +++ b/src/saxshell/mdtrajectory/__init__.py @@ -1,6 +1,7 @@ """Headless and Qt interfaces for mdtrajectory workflows.""" from .workflow import ( + MDTrajectoryAssertionResult, MDTrajectoryExportResult, MDTrajectorySelectionResult, MDTrajectoryWorkflow, @@ -10,6 +11,7 @@ ) __all__ = [ + "MDTrajectoryAssertionResult", "MDTrajectoryExportResult", "MDTrajectorySelectionResult", "MDTrajectoryWorkflow", diff --git a/src/saxshell/mdtrajectory/cli.py b/src/saxshell/mdtrajectory/cli.py index 2ac9b2a..fef9b0d 100644 --- a/src/saxshell/mdtrajectory/cli.py +++ b/src/saxshell/mdtrajectory/cli.py @@ -7,6 +7,7 @@ from saxshell.version import __version__ from .workflow import ( + MDTrajectoryAssertionResult, MDTrajectoryExportResult, MDTrajectorySelectionResult, MDTrajectoryWorkflow, @@ -38,6 +39,7 @@ def build_parser() -> argparse.ArgumentParser: help="Inspect the trajectory and optionally the CP2K energy file.", ) _add_common_input_arguments(inspect_parser) + _add_restart_duplicate_argument(inspect_parser) 
inspect_parser.set_defaults(handler=_handle_inspect) suggest_parser = subparsers.add_parser( @@ -55,6 +57,7 @@ def build_parser() -> argparse.ArgumentParser: _add_common_input_arguments(preview_parser) _add_selection_arguments(preview_parser) _add_cutoff_resolution_arguments(preview_parser) + _add_restart_duplicate_argument(preview_parser) preview_parser.add_argument( "--output-dir", type=Path, @@ -70,6 +73,7 @@ def build_parser() -> argparse.ArgumentParser: _add_common_input_arguments(export_parser) _add_selection_arguments(export_parser) _add_cutoff_resolution_arguments(export_parser) + _add_restart_duplicate_argument(export_parser) export_parser.add_argument( "--output-dir", type=Path, @@ -78,6 +82,53 @@ def build_parser() -> argparse.ArgumentParser: ) export_parser.set_defaults(handler=_handle_export) + validate_parser = subparsers.add_parser( + "validate-export", + help=( + "Assert that exported XYZ frames map back to the source " + "trajectory indices and coordinates." + ), + ) + validate_parser.add_argument( + "trajectory", + type=Path, + help="Path to the source trajectory file (.xyz).", + ) + validate_parser.add_argument( + "frame_dir", + type=Path, + help="Directory containing exported frame_.xyz files.", + ) + validate_parser.add_argument( + "--coordinate-lines", + type=int, + default=3, + help="Number of leading coordinate lines to compare. Default: 3.", + ) + validate_parser.add_argument( + "--coord-tol", + type=float, + default=1.0e-9, + help="Absolute coordinate comparison tolerance. Default: 1e-9.", + ) + validate_parser.add_argument( + "--expect-contiguous", + action="store_true", + help="Fail if exported filename indices have gaps within their range.", + ) + validate_parser.add_argument( + "--strict-source-duplicates", + action="store_true", + help="Fail if the source trajectory contains duplicate i = indices.", + ) + validate_parser.add_argument( + "--max-issues", + type=int, + default=20, + help="Maximum number of issue examples to print. 
Default: 20.", + ) + validate_parser.set_defaults(handler=_handle_validate_export) + return parser @@ -155,8 +206,11 @@ def _add_cutoff_analysis_arguments( parser.add_argument( "--window", type=int, - default=3, - help="Consecutive sample window used for the steady-state test.", + default=2, + help=( + "Consecutive sample window used for the steady-state test. " + "Default: 2." + ), ) @@ -182,6 +236,18 @@ def _add_cutoff_resolution_arguments( _add_cutoff_analysis_arguments(parser, required_target=False) +def _add_restart_duplicate_argument(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--include-restart-duplicates", + action="store_true", + help=( + "Include duplicate XYZ frames from overlapping simulation " + "restarts. By default, earlier overlap frames are skipped and " + "the later continuation frame is kept." + ), + ) + + def _handle_ui(_: argparse.Namespace) -> int: from PySide6.QtWidgets import QApplication @@ -206,6 +272,11 @@ def _build_workflow(args: argparse.Namespace) -> MDTrajectoryWorkflow: trajectory_file=args.trajectory, topology_file=getattr(args, "topology", None), energy_file=getattr(args, "energy_file", None), + include_restart_duplicates=getattr( + args, + "include_restart_duplicates", + False, + ), ) @@ -229,7 +300,7 @@ def _resolve_cli_cutoff( result = workflow.suggest_cutoff( temp_target_k=temp_target_k, temp_tol_k=getattr(args, "temp_tol_k", 1.0), - window=getattr(args, "window", 3), + window=getattr(args, "window", 2), ) cutoff_fs = result.cutoff_time_fs if cutoff_fs is None: @@ -252,6 +323,18 @@ def _handle_inspect(args: argparse.Namespace) -> int: f"File type: {summary['file_type']}", f"Frames: {summary['n_frames']}", ] + if "raw_frames" in summary: + include_duplicates = bool( + summary.get("include_restart_duplicates", False) + ) + duplicate_action = "included" if include_duplicates else "skipped" + lines.extend( + [ + f"Raw frames: {summary['raw_frames']}", + f"Duplicate source frames {duplicate_action}: " + 
f"{summary.get('duplicate_source_frames', 0)}", + ] + ) if workflow.topology_file is not None: lines.append(f"Topology file: {workflow.topology_file}") if workflow.energy_file is not None: @@ -324,6 +407,20 @@ def _handle_export(args: argparse.Namespace) -> int: return 0 +def _handle_validate_export(args: argparse.Namespace) -> int: + workflow = _build_workflow(args) + result = workflow.validate_export( + args.frame_dir, + coordinate_lines=args.coordinate_lines, + coordinate_tolerance=args.coord_tol, + expect_contiguous=args.expect_contiguous, + strict_source_duplicates=args.strict_source_duplicates, + max_issues=args.max_issues, + ) + print(_format_assertion_result(result)) + return 0 if result.passed else 1 + + def _format_selection_result(selection: MDTrajectorySelectionResult) -> str: preview = selection.preview lines = [ @@ -334,6 +431,8 @@ def _format_selection_result(selection: MDTrajectorySelectionResult) -> str: f"Stop: {preview.stop}", f"Stride: {preview.stride}", f"Time-tagged frames: {preview.time_metadata_frames}", + "Restart duplicate frames: " + f"{'included' if selection.include_restart_duplicates else 'skipped'}", ] if selection.applied_cutoff_fs is not None: lines.append(f"Applied cutoff: {selection.applied_cutoff_fs:.3f} fs") @@ -363,3 +462,62 @@ def _format_export_result(result: MDTrajectoryExportResult) -> str: lines.append(f"First file: {result.written_files[0]}") lines.append(f"Last file: {result.written_files[-1]}") return "\n".join(lines) + + +def _format_assertion_result(result: MDTrajectoryAssertionResult) -> str: + status = "passed" if result.passed else "failed" + lines = [ + f"Export validation {status}.", + f"Trajectory file: {result.trajectory_file}", + f"Frame directory: {result.frame_dir}", + f"Coordinate lines checked: {result.coordinate_lines}", + f"Coordinate tolerance: {result.coordinate_tolerance:g}", + f"Source raw frames: {result.source_raw_frames}", + f"Source unique indices: {result.source_unique_indices}", + f"Source 
frames without i index: {result.source_missing_indices}", + f"Source duplicate i indices: {result.source_duplicate_indices}", + "Source duplicate coordinate conflicts: " + f"{result.source_duplicate_conflicts}", + f"Exported XYZ files: {result.exported_files}", + f"Validated XYZ files: {result.validated_files}", + ] + if result.filename_index_min is not None: + lines.append( + "Filename index range: " + f"{result.filename_index_min} to {result.filename_index_max}" + ) + if result.header_index_min is not None: + lines.append( + "Header index range: " + f"{result.header_index_min} to {result.header_index_max}" + ) + if result.filename_header_offsets: + offsets = ", ".join( + f"{offset}: {count}" + for offset, count in result.filename_header_offsets.items() + ) + lines.append(f"Filename-header offsets: {offsets}") + else: + lines.append("Filename-header offsets: none") + + if result.issue_counts: + lines.append("Assertion failures:") + lines.extend( + f"- {kind}: {count}" for kind, count in result.issue_counts.items() + ) + else: + lines.append("Assertion failures: none") + + if result.strict_source_duplicates and result.source_duplicate_indices: + lines.append( + "Strict source duplicate check failed: " + f"{result.source_duplicate_indices} duplicate source frame(s)." 
+ ) + + if result.issues: + lines.append("Issue examples:") + for issue in result.issues: + location = "" if issue.path is None else f" [{issue.path}]" + lines.append(f"- {issue.kind}{location}: {issue.message}") + + return "\n".join(lines) diff --git a/src/saxshell/mdtrajectory/frame/assertions.py b/src/saxshell/mdtrajectory/frame/assertions.py new file mode 100644 index 0000000..f094c37 --- /dev/null +++ b/src/saxshell/mdtrajectory/frame/assertions.py @@ -0,0 +1,569 @@ +from __future__ import annotations + +import math +import re +from collections import Counter +from dataclasses import dataclass +from pathlib import Path + +FRAME_FILENAME_PATTERN = re.compile(r"^frame_(\d+)\.xyz$") +FRAME_INDEX_PATTERN = re.compile( + r"(?:^|[\s,;])i\s*=\s*(\d+)(?:\b|[\s,;])", + re.IGNORECASE, +) + + +@dataclass(frozen=True, slots=True) +class CoordinateLineSignature: + label: str + x: float + y: float + z: float + + +@dataclass(frozen=True, slots=True) +class XYZFrameSignature: + frame_index: int | None + atom_count: int + coordinates: tuple[CoordinateLineSignature, ...] 
+ + +@dataclass(frozen=True, slots=True) +class MDTrajectoryAssertionIssue: + kind: str + message: str + path: str | None = None + frame_index: int | None = None + + +@dataclass(slots=True) +class MDTrajectoryAssertionResult: + trajectory_file: Path + frame_dir: Path + coordinate_lines: int + coordinate_tolerance: float + source_raw_frames: int + source_unique_indices: int + source_missing_indices: int + source_duplicate_indices: int + source_duplicate_conflicts: int + exported_files: int + validated_files: int + filename_index_min: int | None + filename_index_max: int | None + header_index_min: int | None + header_index_max: int | None + filename_header_offsets: dict[int, int] + issue_counts: dict[str, int] + issues: list[MDTrajectoryAssertionIssue] + strict_source_duplicates: bool + + @property + def passed(self) -> bool: + if any(count > 0 for count in self.issue_counts.values()): + return False + if self.strict_source_duplicates and self.source_duplicate_indices: + return False + return True + + @property + def failure_count(self) -> int: + failures = sum(self.issue_counts.values()) + if self.strict_source_duplicates: + failures += self.source_duplicate_indices + return failures + + def to_dict(self) -> dict[str, object]: + return { + "passed": self.passed, + "failure_count": self.failure_count, + "trajectory_file": str(self.trajectory_file), + "frame_dir": str(self.frame_dir), + "coordinate_lines": self.coordinate_lines, + "coordinate_tolerance": self.coordinate_tolerance, + "source_raw_frames": self.source_raw_frames, + "source_unique_indices": self.source_unique_indices, + "source_missing_indices": self.source_missing_indices, + "source_duplicate_indices": self.source_duplicate_indices, + "source_duplicate_conflicts": self.source_duplicate_conflicts, + "exported_files": self.exported_files, + "validated_files": self.validated_files, + "filename_index_min": self.filename_index_min, + "filename_index_max": self.filename_index_max, + "header_index_min": 
self.header_index_min, + "header_index_max": self.header_index_max, + "filename_header_offsets": dict(self.filename_header_offsets), + "issue_counts": dict(self.issue_counts), + "issues": [ + { + "kind": issue.kind, + "message": issue.message, + "path": issue.path, + "frame_index": issue.frame_index, + } + for issue in self.issues + ], + "strict_source_duplicates": self.strict_source_duplicates, + } + + +def validate_xyz_export_against_source( + trajectory_file: str | Path, + frame_dir: str | Path, + *, + coordinate_lines: int = 3, + coordinate_tolerance: float = 1.0e-9, + expect_contiguous: bool = False, + strict_source_duplicates: bool = False, + max_issues: int = 20, +) -> MDTrajectoryAssertionResult: + """Validate exported XYZ frames against their source trajectory. + + The export filename index must match the CP2K ``i =`` header index, + output header indices must be unique, and the first coordinate lines + must match the source trajectory frame with the same index. + """ + if coordinate_lines <= 0: + raise ValueError("coordinate_lines must be a positive integer.") + if coordinate_tolerance < 0: + raise ValueError("coordinate_tolerance must be non-negative.") + if max_issues < 0: + raise ValueError("max_issues must be non-negative.") + + trajectory_path = Path(trajectory_file) + export_path = Path(frame_dir) + if not trajectory_path.exists(): + raise FileNotFoundError( + f"Trajectory file not found: {trajectory_path}" + ) + if not export_path.is_dir(): + raise NotADirectoryError(f"Frame directory not found: {export_path}") + + source_by_index: dict[int, XYZFrameSignature] = {} + source_raw_frames = 0 + source_missing_indices = 0 + source_duplicate_indices = 0 + source_duplicate_conflicts = 0 + issue_counts: Counter[str] = Counter() + issues: list[MDTrajectoryAssertionIssue] = [] + + def add_issue( + kind: str, + message: str, + *, + path: Path | None = None, + frame_index: int | None = None, + count: int = 1, + ) -> None: + issue_counts[kind] += count + if 
len(issues) >= max_issues: + return + issues.append( + MDTrajectoryAssertionIssue( + kind=kind, + message=message, + path=None if path is None else str(path), + frame_index=frame_index, + ) + ) + + for source_frame in _iter_xyz_frame_signatures( + trajectory_path, + coordinate_lines=coordinate_lines, + ): + source_raw_frames += 1 + if source_frame.frame_index is None: + source_missing_indices += 1 + continue + existing = source_by_index.get(source_frame.frame_index) + if existing is not None: + source_duplicate_indices += 1 + if not _signatures_match( + existing, + source_frame, + coordinate_tolerance=coordinate_tolerance, + ): + source_duplicate_conflicts += 1 + source_by_index[source_frame.frame_index] = source_frame + continue + source_by_index[source_frame.frame_index] = source_frame + + exported_files = 0 + validated_files = 0 + filename_indices: list[int] = [] + header_indices: list[int] = [] + filename_index_counts: Counter[int] = Counter() + header_index_counts: Counter[int] = Counter() + offset_counts: Counter[int] = Counter() + + frame_paths = sorted( + export_path.glob("*.xyz"), + key=lambda path: _frame_file_sort_key(path), + ) + if not frame_paths: + add_issue( + "no_exported_xyz_files", + f"No exported XYZ files were found in {export_path}.", + path=export_path, + ) + for frame_path in frame_paths: + exported_files += 1 + name_match = FRAME_FILENAME_PATTERN.match(frame_path.name) + if name_match is None: + add_issue( + "invalid_export_filename", + f"Expected frame_.xyz filename, got {frame_path.name}.", + path=frame_path, + ) + continue + + filename_index = int(name_match.group(1)) + filename_indices.append(filename_index) + filename_index_counts[filename_index] += 1 + + try: + export_frame = _read_single_xyz_frame_signature( + frame_path, + coordinate_lines=coordinate_lines, + ) + except ValueError as exc: + add_issue( + "invalid_export_xyz", + str(exc), + path=frame_path, + frame_index=filename_index, + ) + continue + + validated_files += 1 + if 
export_frame.frame_index is None: + add_issue( + "missing_export_header_index", + f"{frame_path.name} does not include a CP2K i = index.", + path=frame_path, + frame_index=filename_index, + ) + else: + header_index = export_frame.frame_index + header_indices.append(header_index) + header_index_counts[header_index] += 1 + offset = filename_index - header_index + offset_counts[offset] += 1 + if offset != 0: + add_issue( + "filename_header_offset", + ( + f"{frame_path.name} filename index is " + f"{filename_index}, but header reports i = " + f"{header_index}." + ), + path=frame_path, + frame_index=filename_index, + ) + + source_frame = source_by_index.get(filename_index) + if source_frame is None: + add_issue( + "missing_source_index", + ( + f"{frame_path.name} references index {filename_index}, " + "which is not present in the source trajectory." + ), + path=frame_path, + frame_index=filename_index, + ) + continue + + if export_frame.atom_count != source_frame.atom_count: + add_issue( + "atom_count_mismatch", + ( + f"{frame_path.name} atom count is " + f"{export_frame.atom_count}, but source frame " + f"{filename_index} has {source_frame.atom_count}." + ), + path=frame_path, + frame_index=filename_index, + ) + + if not _coordinates_match( + source_frame.coordinates, + export_frame.coordinates, + coordinate_tolerance=coordinate_tolerance, + ): + add_issue( + "coordinate_mismatch", + ( + f"{frame_path.name} first {coordinate_lines} coordinate " + f"line(s) do not match source frame {filename_index}." + ), + path=frame_path, + frame_index=filename_index, + ) + + for duplicate_index, count in sorted(filename_index_counts.items()): + if count <= 1: + continue + add_issue( + "duplicate_export_filename_index", + ( + f"Export contains {count} filenames resolving to frame " + f"index {duplicate_index}." 
+ ), + frame_index=duplicate_index, + count=count - 1, + ) + + for duplicate_index, count in sorted(header_index_counts.items()): + if count <= 1: + continue + add_issue( + "duplicate_export_header_index", + ( + f"Export contains {count} frames whose headers report " + f"i = {duplicate_index}." + ), + frame_index=duplicate_index, + count=count - 1, + ) + + if expect_contiguous and filename_indices: + filename_index_set = set(filename_indices) + start = min(filename_index_set) + stop = max(filename_index_set) + missing_indices = [ + index + for index in range(start, stop + 1) + if index not in filename_index_set + ] + if missing_indices: + add_issue( + "missing_contiguous_export_index", + ( + f"Export filename range {start}-{stop} is missing " + f"{len(missing_indices)} index value(s); first missing " + f"index is {missing_indices[0]}." + ), + frame_index=missing_indices[0], + count=len(missing_indices), + ) + + return MDTrajectoryAssertionResult( + trajectory_file=trajectory_path, + frame_dir=export_path, + coordinate_lines=coordinate_lines, + coordinate_tolerance=coordinate_tolerance, + source_raw_frames=source_raw_frames, + source_unique_indices=len(source_by_index), + source_missing_indices=source_missing_indices, + source_duplicate_indices=source_duplicate_indices, + source_duplicate_conflicts=source_duplicate_conflicts, + exported_files=exported_files, + validated_files=validated_files, + filename_index_min=min(filename_indices) if filename_indices else None, + filename_index_max=max(filename_indices) if filename_indices else None, + header_index_min=min(header_indices) if header_indices else None, + header_index_max=max(header_indices) if header_indices else None, + filename_header_offsets=dict(sorted(offset_counts.items())), + issue_counts=dict(sorted(issue_counts.items())), + issues=issues, + strict_source_duplicates=strict_source_duplicates, + ) + + +def _iter_xyz_frame_signatures( + path: Path, + *, + coordinate_lines: int, +): + with path.open("r", 
encoding="utf-8", errors="replace") as handle: + while True: + line = handle.readline() + if not line: + break + stripped = line.strip() + if not stripped: + continue + + if stripped.startswith("frame") or stripped.startswith("NSTEP="): + header = line + atom_count_line = handle.readline() + if not atom_count_line: + break + atom_count_text = atom_count_line.strip() + if not atom_count_text.isdigit(): + continue + atom_count = int(atom_count_text) + atom_lines = _read_atom_lines(handle, atom_count) + if atom_lines is None: + break + yield _frame_signature( + header, + atom_count, + atom_lines, + coordinate_lines=coordinate_lines, + ) + continue + + if not stripped.isdigit(): + continue + + atom_count = int(stripped) + header = handle.readline() + if not header: + break + atom_lines = _read_atom_lines(handle, atom_count) + if atom_lines is None: + break + yield _frame_signature( + header, + atom_count, + atom_lines, + coordinate_lines=coordinate_lines, + ) + + +def _read_single_xyz_frame_signature( + path: Path, + *, + coordinate_lines: int, +) -> XYZFrameSignature: + with path.open("r", encoding="utf-8", errors="replace") as handle: + atom_count_line = handle.readline() + if not atom_count_line: + raise ValueError(f"{path.name} is empty.") + atom_count_text = atom_count_line.strip() + if not atom_count_text.isdigit(): + raise ValueError( + f"{path.name} does not start with an XYZ atom count." + ) + atom_count = int(atom_count_text) + header = handle.readline() + if not header: + raise ValueError(f"{path.name} is missing its XYZ comment line.") + atom_lines = _read_atom_lines(handle, atom_count) + if atom_lines is None: + raise ValueError( + f"{path.name} ended before {atom_count} atom line(s)." 
+ ) + return _frame_signature( + header, + atom_count, + atom_lines, + coordinate_lines=coordinate_lines, + ) + + +def _read_atom_lines(handle, atom_count: int) -> list[str] | None: + atom_lines: list[str] = [] + for _ in range(atom_count): + atom_line = handle.readline() + if not atom_line: + return None + atom_lines.append(atom_line) + return atom_lines + + +def _frame_signature( + header: str, + atom_count: int, + atom_lines: list[str], + *, + coordinate_lines: int, +) -> XYZFrameSignature: + return XYZFrameSignature( + frame_index=_parse_frame_index(header), + atom_count=atom_count, + coordinates=tuple( + coordinate + for coordinate in ( + _coordinate_signature(line) + for line in atom_lines[:coordinate_lines] + ) + if coordinate is not None + ), + ) + + +def _parse_frame_index(header: str) -> int | None: + match = FRAME_INDEX_PATTERN.search(header.strip()) + if match is None: + return None + try: + return int(match.group(1)) + except ValueError: + return None + + +def _coordinate_signature(line: str) -> CoordinateLineSignature | None: + parts = line.split() + if len(parts) < 4: + return None + try: + return CoordinateLineSignature( + label=_normalize_atom_label(parts[0]), + x=float(parts[1]), + y=float(parts[2]), + z=float(parts[3]), + ) + except ValueError: + return None + + +def _normalize_atom_label(label: str) -> str: + return "".join(char for char in label if not char.isdigit()).capitalize() + + +def _signatures_match( + left: XYZFrameSignature, + right: XYZFrameSignature, + *, + coordinate_tolerance: float, +) -> bool: + return left.atom_count == right.atom_count and _coordinates_match( + left.coordinates, + right.coordinates, + coordinate_tolerance=coordinate_tolerance, + ) + + +def _coordinates_match( + left: tuple[CoordinateLineSignature, ...], + right: tuple[CoordinateLineSignature, ...], + *, + coordinate_tolerance: float, +) -> bool: + if len(left) != len(right): + return False + for left_line, right_line in zip(left, right, strict=True): + if 
left_line.label != right_line.label: + return False + if not ( + math.isclose( + left_line.x, + right_line.x, + rel_tol=0.0, + abs_tol=coordinate_tolerance, + ) + and math.isclose( + left_line.y, + right_line.y, + rel_tol=0.0, + abs_tol=coordinate_tolerance, + ) + and math.isclose( + left_line.z, + right_line.z, + rel_tol=0.0, + abs_tol=coordinate_tolerance, + ) + ): + return False + return True + + +def _frame_file_sort_key(path: Path) -> tuple[int, int | str]: + match = FRAME_FILENAME_PATTERN.match(path.name) + if match is None: + return (1, path.name) + return (0, int(match.group(1))) diff --git a/src/saxshell/mdtrajectory/frame/cp2k_backend.py b/src/saxshell/mdtrajectory/frame/cp2k_backend.py index 8824176..d7e944f 100644 --- a/src/saxshell/mdtrajectory/frame/cp2k_backend.py +++ b/src/saxshell/mdtrajectory/frame/cp2k_backend.py @@ -16,6 +16,10 @@ re.IGNORECASE, ), ) +FRAME_INDEX_PATTERN = re.compile( + r"(?:^|[\s,;])i\s*=\s*(\d+)(?:\b|[\s,;])", + re.IGNORECASE, +) class CP2KTrajectoryBackend(TrajectoryBackend): @@ -32,6 +36,8 @@ def __init__( self, input_file: str | Path, topology_file: str | Path | None = None, + *, + include_restart_duplicates: bool = False, ) -> None: super().__init__(input_file=input_file, topology_file=topology_file) suffix = self.input_file.suffix.lower() @@ -40,14 +46,27 @@ def __init__( "CP2KTrajectoryBackend supports only .xyz and .pdb files." 
) self.file_type = suffix.lstrip(".") + self.include_restart_duplicates = bool(include_restart_duplicates) + self._raw_frame_count: int | None = None + self._duplicate_source_frame_count: int = 0 def inspect(self) -> dict[str, object]: frame_metadata = self.load_frame_metadata() - return { + summary: dict[str, object] = { "input_file": str(self.input_file), "file_type": self.file_type, "n_frames": len(frame_metadata), + "include_restart_duplicates": self.include_restart_duplicates, } + if self._raw_frame_count is not None and ( + self._raw_frame_count != len(frame_metadata) + or self._duplicate_source_frame_count + ): + summary["raw_frames"] = self._raw_frame_count + summary["duplicate_source_frames"] = ( + self._duplicate_source_frame_count + ) + return summary def iter_frame_metadata(self) -> list[FrameMetadata]: if self.file_type == "xyz": @@ -107,8 +126,11 @@ def _estimate_frame_count_xyz(self) -> int: return count def _parse_xyz_frame_metadata(self) -> list[FrameMetadata]: - frames: list[FrameMetadata] = [] - frame_idx = 0 + frames_by_index: dict[int, FrameMetadata] = {} + frames_with_duplicates: list[FrameMetadata] = [] + seen_source_indices: set[int] = set() + raw_frame_count = 0 + duplicate_source_frame_count = 0 with self.input_file.open("r") as handle: while True: @@ -129,13 +151,32 @@ def _parse_xyz_frame_metadata(self) -> list[FrameMetadata]: atom_count_text = atom_count_line.strip() if not atom_count_text.isdigit(): continue - frames.append( - FrameMetadata( - frame_index=frame_idx, - time_fs=self._parse_time_from_header(line), - ) + source_index = self._parse_frame_index_from_metadata(line) + fallback_index = ( + raw_frame_count + if self.include_restart_duplicates + else len(frames_by_index) ) - frame_idx += 1 + frame_index = self._resolve_frame_index( + line, + fallback_index=fallback_index, + ) + raw_frame_count += 1 + if ( + source_index is not None + and source_index in seen_source_indices + ): + duplicate_source_frame_count += 1 + if 
source_index is not None: + seen_source_indices.add(source_index) + frame = FrameMetadata( + frame_index=frame_index, + time_fs=self._parse_time_from_header(line), + ) + if self.include_restart_duplicates: + frames_with_duplicates.append(frame) + else: + frames_by_index[frame_index] = frame for _ in range(int(atom_count_text)): if not handle.readline(): break @@ -146,106 +187,197 @@ def _parse_xyz_frame_metadata(self) -> list[FrameMetadata]: atom_count = int(stripped) comment = handle.readline() + if not comment: + break time_val = ( None if not comment else self._parse_time_from_metadata(comment) ) - frames.append( - FrameMetadata( - frame_index=frame_idx, - time_fs=time_val, - ) + source_index = self._parse_frame_index_from_metadata(comment) + fallback_index = ( + raw_frame_count + if self.include_restart_duplicates + else len(frames_by_index) ) - frame_idx += 1 - if not comment: - break + frame_index = self._resolve_frame_index( + comment, + fallback_index=fallback_index, + ) + raw_frame_count += 1 + if ( + source_index is not None + and source_index in seen_source_indices + ): + duplicate_source_frame_count += 1 + if source_index is not None: + seen_source_indices.add(source_index) + frame = FrameMetadata( + frame_index=frame_index, + time_fs=time_val, + ) + if self.include_restart_duplicates: + frames_with_duplicates.append(frame) + else: + frames_by_index[frame_index] = frame for _ in range(atom_count): if not handle.readline(): break - return frames + self._raw_frame_count = raw_frame_count + self._duplicate_source_frame_count = duplicate_source_frame_count + if self.include_restart_duplicates: + return frames_with_duplicates + return [ + frames_by_index[frame_index] + for frame_index in sorted(frames_by_index) + ] def _parse_xyz_frames(self) -> list[FrameRecord]: - lines = self.input_file.read_text().splitlines(keepends=True) - frames: list[FrameRecord] = [] - frame_idx = 0 - atom_count: int | None = None - buffer: list[str] = [] - time_val: float | None = 
None + frames_by_index: dict[int, FrameRecord] = {} + frames_with_duplicates: list[FrameRecord] = [] + seen_source_indices: set[int] = set() + raw_frame_count = 0 + duplicate_source_frame_count = 0 - i = 0 - while i < len(lines): - s = lines[i].strip() + with self.input_file.open("r") as handle: + while True: + line = handle.readline() + if not line: + break - is_metadata_style = ( - s.isdigit() - and i + 1 < len(lines) - and lines[i + 1].strip().startswith("i =") - ) + stripped = line.strip() + if not stripped: + continue - if is_metadata_style: - if buffer and atom_count is not None: - frames.append( - FrameRecord( - frame_index=frame_idx, - file_type="xyz", - atom_count=atom_count, - lines=buffer.copy(), - time_fs=time_val, - ) + if stripped.startswith("frame") or stripped.startswith( + "NSTEP=" + ): + atom_count_line = handle.readline() + if not atom_count_line: + break + atom_count_text = atom_count_line.strip() + if not atom_count_text.isdigit(): + continue + atom_count = int(atom_count_text) + atom_lines = self._read_xyz_atom_lines( + handle, + atom_count, ) - frame_idx += 1 - - atom_count = int(s) - comment = lines[i + 1] - time_val = self._parse_time_from_metadata(comment) - buffer = [comment] - i += 2 - continue - - if s.startswith("frame") or s.startswith("NSTEP="): - if buffer and atom_count is not None: - frames.append( - FrameRecord( - frame_index=frame_idx, - file_type="xyz", - atom_count=atom_count, - lines=buffer.copy(), - time_fs=time_val, - ) + if atom_lines is None: + break + source_index = self._parse_frame_index_from_metadata(line) + fallback_index = ( + raw_frame_count + if self.include_restart_duplicates + else len(frames_by_index) ) - frame_idx += 1 - atom_count = None - time_val = self._parse_time_from_header(lines[i]) - buffer = [lines[i]] - i += 1 - continue - - if s.isdigit(): - atom_count = int(s) - if not buffer: - time_val = None - i += 1 - continue - - if atom_count is not None: - buffer.append(lines[i]) - - i += 1 - - if buffer and 
atom_count is not None: - frames.append( - FrameRecord( - frame_index=frame_idx, + frame_index = self._resolve_frame_index( + line, + fallback_index=fallback_index, + ) + raw_frame_count += 1 + if ( + source_index is not None + and source_index in seen_source_indices + ): + duplicate_source_frame_count += 1 + if source_index is not None: + seen_source_indices.add(source_index) + frame = FrameRecord( + frame_index=frame_index, + file_type="xyz", + atom_count=atom_count, + lines=[line, *atom_lines], + time_fs=self._parse_time_from_header(line), + ) + if self.include_restart_duplicates: + frames_with_duplicates.append(frame) + else: + frames_by_index[frame_index] = frame + continue + + if not stripped.isdigit(): + continue + + atom_count = int(stripped) + comment = handle.readline() + if not comment: + break + atom_lines = self._read_xyz_atom_lines(handle, atom_count) + if atom_lines is None: + break + source_index = self._parse_frame_index_from_metadata(comment) + fallback_index = ( + raw_frame_count + if self.include_restart_duplicates + else len(frames_by_index) + ) + frame_index = self._resolve_frame_index( + comment, + fallback_index=fallback_index, + ) + raw_frame_count += 1 + if ( + source_index is not None + and source_index in seen_source_indices + ): + duplicate_source_frame_count += 1 + if source_index is not None: + seen_source_indices.add(source_index) + frame = FrameRecord( + frame_index=frame_index, file_type="xyz", atom_count=atom_count, - lines=buffer.copy(), - time_fs=time_val, + lines=[comment, *atom_lines], + time_fs=self._parse_time_from_metadata(comment), ) - ) - - return frames + if self.include_restart_duplicates: + frames_with_duplicates.append(frame) + else: + frames_by_index[frame_index] = frame + + self._raw_frame_count = raw_frame_count + self._duplicate_source_frame_count = duplicate_source_frame_count + if self.include_restart_duplicates: + return frames_with_duplicates + return [ + frames_by_index[frame_index] + for frame_index in 
sorted(frames_by_index) + ] + + def _read_xyz_atom_lines( + self, + handle, + atom_count: int, + ) -> list[str] | None: + atom_lines: list[str] = [] + for _ in range(atom_count): + atom_line = handle.readline() + if not atom_line: + return None + atom_lines.append(atom_line) + return atom_lines + + def _resolve_frame_index( + self, + header: str, + *, + fallback_index: int, + ) -> int: + source_index = self._parse_frame_index_from_metadata(header) + if source_index is None: + return fallback_index + return source_index + + def _parse_frame_index_from_metadata(self, line: str) -> int | None: + match = FRAME_INDEX_PATTERN.search(line.strip()) + if match is None: + return None + try: + return int(match.group(1)) + except ValueError: + return None def _parse_time_from_metadata(self, line: str) -> float | None: return self._parse_time_from_header(line) diff --git a/src/saxshell/mdtrajectory/frame/cutoff_analysis.py b/src/saxshell/mdtrajectory/frame/cutoff_analysis.py index 4048f64..296df23 100644 --- a/src/saxshell/mdtrajectory/frame/cutoff_analysis.py +++ b/src/saxshell/mdtrajectory/frame/cutoff_analysis.py @@ -27,7 +27,7 @@ def suggest_steady_state_cutoff( self, temp_target_k: float, temp_tol_k: float = 1.0, - window: int = 10, + window: int = 2, kinetic_rel_std_max: float = 1.0e-3, potential_rel_std_max: float = 1.0e-3, ) -> SteadyStateResult: diff --git a/src/saxshell/mdtrajectory/frame/exporters.py b/src/saxshell/mdtrajectory/frame/exporters.py index 3c991c1..5ce9782 100644 --- a/src/saxshell/mdtrajectory/frame/exporters.py +++ b/src/saxshell/mdtrajectory/frame/exporters.py @@ -1,17 +1,24 @@ from __future__ import annotations +import re +from collections import Counter from pathlib import Path from typing import Callable from .base import FrameRecord ExportProgressCallback = Callable[[int, int, str], None] +XYZ_FRAME_INDEX_PATTERN = re.compile( + r"(?:^|[\s,;])i\s*=\s*(\d+)(?:\b|[\s,;])", + re.IGNORECASE, +) def export_xyz_frames( frames: list[FrameRecord], 
output_dir: str | Path, *, + allow_duplicate_frame_indices: bool = False, progress_callback: ExportProgressCallback | None = None, ) -> list[Path]: """Write frame records as XYZ files.""" @@ -20,9 +27,26 @@ def export_xyz_frames( written_files: list[Path] = [] xyz_frames = [frame for frame in frames if frame.file_type == "xyz"] total_frames = len(xyz_frames) + output_paths: set[Path] = set() + frame_index_counts = Counter(frame.frame_index for frame in xyz_frames) + seen_frame_indices: Counter[int] = Counter() for index, frame in enumerate(xyz_frames, start=1): - file_path = output_path / f"frame_{frame.frame_index:04d}.xyz" + _validate_xyz_frame_identity(frame) + seen_frame_indices[frame.frame_index] += 1 + file_name = _xyz_output_filename( + frame.frame_index, + occurrence=seen_frame_indices[frame.frame_index], + total=frame_index_counts[frame.frame_index], + allow_duplicate_frame_indices=allow_duplicate_frame_indices, + ) + file_path = output_path / file_name + if file_path in output_paths: + raise ValueError( + "Multiple XYZ frames resolve to the same output file: " + f"{file_path.name}" + ) + output_paths.add(file_path) with file_path.open("w") as handle: handle.write(f"{frame.atom_count}\n") handle.write(frame.lines[0]) @@ -45,6 +69,22 @@ def export_xyz_frames( return written_files +def _xyz_output_filename( + frame_index: int, + *, + occurrence: int, + total: int, + allow_duplicate_frame_indices: bool, +) -> str: + if total <= 1: + return f"frame_{frame_index:04d}.xyz" + if not allow_duplicate_frame_indices: + return f"frame_{frame_index:04d}.xyz" + if occurrence == total: + return f"frame_{frame_index:04d}.xyz" + return f"frame_{frame_index:04d}_duplicate{occurrence:04d}.xyz" + + def export_pdb_frames( frames: list[FrameRecord], output_dir: str | Path, @@ -57,9 +97,16 @@ def export_pdb_frames( written_files: list[Path] = [] pdb_frames = [frame for frame in frames if frame.file_type == "pdb"] total_frames = len(pdb_frames) + output_paths: set[Path] = set() 
for index, frame in enumerate(pdb_frames, start=1): file_path = output_path / f"frame_{frame.frame_index:04d}.pdb" + if file_path in output_paths: + raise ValueError( + "Multiple PDB frames resolve to the same output file: " + f"{file_path.name}" + ) + output_paths.add(file_path) with file_path.open("w") as handle: handle.writelines(frame.lines) written_files.append(file_path) @@ -71,3 +118,28 @@ def export_pdb_frames( ) return written_files + + +def _validate_xyz_frame_identity(frame: FrameRecord) -> None: + """Reject CP2K XYZ records whose header index and output index + differ.""" + if not frame.lines: + return + source_index = _parse_xyz_source_frame_index(frame.lines[0]) + if source_index is None or source_index == frame.frame_index: + return + raise ValueError( + "XYZ frame identity mismatch: header reports " + f"i = {source_index}, but the export frame index is " + f"{frame.frame_index}." + ) + + +def _parse_xyz_source_frame_index(header: str) -> int | None: + match = XYZ_FRAME_INDEX_PATTERN.search(header.strip()) + if match is None: + return None + try: + return int(match.group(1)) + except ValueError: + return None diff --git a/src/saxshell/mdtrajectory/frame/manager.py b/src/saxshell/mdtrajectory/frame/manager.py index 58feba8..2611745 100644 --- a/src/saxshell/mdtrajectory/frame/manager.py +++ b/src/saxshell/mdtrajectory/frame/manager.py @@ -42,12 +42,15 @@ def __init__( input_file: str | Path, topology_file: str | Path | None = None, backend: str = "auto", + *, + include_restart_duplicates: bool = False, ) -> None: self.input_file = Path(input_file) self.topology_file = ( Path(topology_file) if topology_file is not None else None ) self.backend_name = backend + self.include_restart_duplicates = bool(include_restart_duplicates) self.backend = self._build_backend() self.frames: list[FrameRecord] | None = None @@ -58,6 +61,7 @@ def _build_backend(self): return CP2KTrajectoryBackend( input_file=self.input_file, topology_file=self.topology_file, + 
def set_include_restart_duplicates(
    self,
    include_restart_duplicates: bool,
) -> None:
    """Switch restart-duplicate handling on or off.

    When the (bool-coerced) value actually changes, the backend is
    rebuilt and the cached ``frames`` are cleared; otherwise nothing
    is touched.
    """
    requested = bool(include_restart_duplicates)
    if self.include_restart_duplicates != requested:
        self.include_restart_duplicates = requested
        self.backend = self._build_backend()
        self.frames = None
__future__ import annotations + +import threading +import uuid +from dataclasses import dataclass, replace +from pathlib import Path + +from PySide6.QtCore import QObject, Qt, QThread, Signal, Slot +from PySide6.QtWidgets import ( + QAbstractItemView, + QApplication, + QCheckBox, + QDoubleSpinBox, + QFileDialog, + QFormLayout, + QFrame, + QHBoxLayout, + QLabel, + QLineEdit, + QListView, + QListWidget, + QListWidgetItem, + QMainWindow, + QMessageBox, + QProgressBar, + QPushButton, + QSizePolicy, + QTextEdit, + QToolButton, + QTreeView, + QVBoxLayout, + QWidget, +) + +from saxshell.mdtrajectory.workflow import MDTrajectoryWorkflow +from saxshell.saxs.project_manager import ( + SAXSProjectManager, + build_project_paths, +) +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) + +DEFAULT_TIME_CUTOFF_FS = 1000.0 + + +def _new_item_id() -> str: + return uuid.uuid4().hex + + +def _optional_path(text: str) -> Path | None: + stripped = text.strip() + if not stripped: + return None + return Path(stripped).expanduser().resolve() + + +def _required_path(text: str, field_name: str) -> Path: + path = _optional_path(text) + if path is None: + raise ValueError(f"{field_name} is required.") + return path + + +def _required_existing_file(text: str, field_name: str) -> Path: + path = _required_path(text, field_name) + if not path.is_file(): + raise ValueError(f"{field_name} does not exist: {path}") + return path + + +def _required_project_dir(text: str) -> Path: + project_dir = _required_path(text, "Project folder") + project_file = build_project_paths(project_dir).project_file + if not project_file.is_file(): + raise ValueError(f"Project file does not exist: {project_file}") + return project_dir + + +def _project_reference_text(project_dir: Path | None) -> str: + if project_dir is None: + return "Project reference: choose a SAXSShell project folder." 
+ project_file = build_project_paths(project_dir).project_file + if project_file.is_file(): + return f"Project reference: {project_file}" + return f"Project reference: no project file found at {project_file}" + + +def _dialog_start_dir(*candidates: str | Path | None) -> str: + for candidate in candidates: + if candidate is None: + continue + path = Path(candidate).expanduser() + if path.is_file(): + return str(path.parent) + if path.is_dir(): + return str(path) + return str(Path.home()) + + +def _choose_existing_directories( + parent: QWidget, + *, + title: str, + start_dir: str | Path, +) -> tuple[Path, ...]: + dialog = QFileDialog(parent, title, str(start_dir)) + dialog.setFileMode(QFileDialog.FileMode.Directory) + dialog.setOption(QFileDialog.Option.ShowDirsOnly, True) + dialog.setOption(QFileDialog.Option.DontUseNativeDialog, True) + for view in dialog.findChildren(QListView) + dialog.findChildren( + QTreeView + ): + view.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + if dialog.exec() != int(QFileDialog.DialogCode.Accepted): + return () + return tuple( + Path(path).expanduser().resolve() for path in dialog.selectedFiles() + ) + + +@dataclass(slots=True, frozen=True) +class MDTrajectoryBatchJob: + project_dir: Path + trajectory_file: Path + topology_file: Path | None + energy_file: Path + output_dir: Path | None = None + cutoff_fs: float = DEFAULT_TIME_CUTOFF_FS + include_restart_duplicates: bool = False + + +@dataclass(slots=True) +class MDTrajectoryBatchResult: + project_dir: Path + output_dir: Path + written_count: int + selected_frames: int + cutoff_fs: float + metadata_file: Path | None = None + include_restart_duplicates: bool = False + + +@dataclass(slots=True) +class MDTrajectoryBatchItem: + item_id: str + project_dir: Path | None = None + trajectory_file: Path | None = None + topology_file: Path | None = None + energy_file: Path | None = None + output_dir: Path | None = None + cutoff_fs: float = DEFAULT_TIME_CUTOFF_FS + 
def to_job(self) -> MDTrajectoryBatchJob:
    """Validate this queue item and convert it into a runnable job.

    Checks run in a fixed order: project folder, trajectory file, optional
    topology file, energy file, output folder, then time cutoff. The first
    failing check raises.

    Returns:
        An ``MDTrajectoryBatchJob`` holding the validated values.

    Raises:
        ValueError: If a required entry is unset or fails its validator, the
            output folder path exists but is not a directory, or the time
            cutoff is negative.
    """

    def as_text(value: Path | None) -> str:
        # The ``_required_*`` validators take text, with "" meaning "unset".
        return "" if value is None else str(value)

    validated_project = _required_project_dir(as_text(self.project_dir))
    validated_trajectory = _required_existing_file(
        as_text(self.trajectory_file),
        "Trajectory file",
    )
    # Topology is the only optional input file: validate it only when given.
    validated_topology = (
        None
        if self.topology_file is None
        else _required_existing_file(str(self.topology_file), "Topology file")
    )
    validated_energy = _required_existing_file(
        as_text(self.energy_file),
        "Energy file",
    )

    target_dir = self.output_dir
    if target_dir is not None:
        target_dir = target_dir.expanduser().resolve()
        # A folder that does not exist yet is accepted here; only an existing
        # non-directory is rejected.
        if target_dir.exists() and not target_dir.is_dir():
            raise ValueError(
                f"Output folder exists but is not a directory: {target_dir}"
            )

    cutoff = float(self.cutoff_fs)
    if cutoff < 0.0:
        raise ValueError("Time cutoff must be zero or greater.")

    return MDTrajectoryBatchJob(
        project_dir=validated_project,
        trajectory_file=validated_trajectory,
        topology_file=validated_topology,
        energy_file=validated_energy,
        output_dir=target_dir,
        cutoff_fs=cutoff,
        include_restart_duplicates=self.include_restart_duplicates,
    )
def set_progress(self, processed: int, total: int) -> None:
    """Show ``processed`` of ``total`` on this item's progress bar.

    Args:
        processed: Units completed so far. Values below zero are shown as
            zero; values above the bar's maximum are shown as the maximum.
        total: Units expected. Anything below one is treated as one, so the
            bar always has a non-empty range.
    """
    maximum = max(int(total), 1)
    # QProgressBar.setValue() ignores a value outside the current range,
    # which would leave the previous reading on screen. Clamp so every call
    # is reflected.
    value = min(max(int(processed), 0), maximum)
    self.progress_bar.setRange(0, maximum)
    self.progress_bar.setValue(value)
def preview_selection(self) -> None:
    """Preview the frame selection for this item and display a summary.

    Builds a workflow from the current editor values and asks it for a
    preview with the time cutoff applied. When the output-folder field is
    blank, or still holds the folder suggested by an earlier preview, the
    workflow chooses the folder and the field is updated to show it. The
    summary label, progress bar and status text are then refreshed.

    Raises:
        Whatever ``self.job()`` or the workflow raises; nothing is caught
        here.
    """
    job = self.job()
    workflow = MDTrajectoryWorkflow(
        trajectory_file=job.trajectory_file,
        topology_file=job.topology_file,
        energy_file=job.energy_file,
        include_restart_duplicates=job.include_restart_duplicates,
    )

    typed_output_dir = _optional_path(self.output_dir_edit.text())
    # The user has "chosen" a folder only if the field holds something other
    # than the folder this widget last suggested itself.
    user_chose_dir = (
        typed_output_dir is not None
        and typed_output_dir != self._last_suggested_output_dir
    )
    selection = workflow.preview_selection(
        use_cutoff=True,
        cutoff_fs=job.cutoff_fs,
        output_dir=typed_output_dir if user_chose_dir else None,
    )

    if not user_chose_dir:
        suggested = selection.output_dir.resolve()
        # Remember the suggestion so the next preview can tell it apart from
        # a folder the user typed.
        self._last_suggested_output_dir = suggested
        self.output_dir_edit.setText(str(selection.output_dir))
        job = replace(job, output_dir=suggested)

    preview = selection.preview
    duplicates_state = (
        "included" if job.include_restart_duplicates else "skipped"
    )
    summary = [
        f"Frames selected: {preview.selected_frames} / {preview.total_frames}",
        f"Output folder: {selection.output_dir}",
        f"Applied cutoff: {job.cutoff_fs:g} fs",
        f"Restart duplicate frames: {duplicates_state}",
    ]
    if preview.first_frame_index is not None:
        summary.append(
            f"Frame index range: {preview.first_frame_index} "
            f"to {preview.last_frame_index}"
        )
    if preview.first_time_fs is not None:
        summary.append(
            f"Time range: {preview.first_time_fs:.3f} fs "
            f"to {preview.last_time_fs:.3f} fs"
        )

    self.preview_summary_label.setText("\n".join(summary))
    self.set_progress(0, max(preview.selected_frames, 1))
    self.set_status("Preview ready")
self.setFrameShape(QFrame.Shape.StyledPanel) + self.setSizePolicy( + QSizePolicy.Policy.Expanding, + QSizePolicy.Policy.Fixed, + ) + root = QVBoxLayout(self) + root.setContentsMargins(8, 8, 8, 8) + root.setSpacing(8) + + self.header_frame = QFrame() + self.header_frame.setObjectName("MDTrajectoryBatchItemHeader") + header = QHBoxLayout(self.header_frame) + header.setContentsMargins(8, 6, 8, 6) + header.setSpacing(8) + self.toggle_button = QToolButton() + self.toggle_button.setCheckable(True) + self.toggle_button.toggled.connect(self._set_settings_visible) + header.addWidget(self.toggle_button) + self.title_label = QLabel("New MD trajectory extraction") + self.title_label.setStyleSheet("font-weight: 600;") + header.addWidget(self.title_label, stretch=1) + self.status_label = QLabel("Ready") + self.status_label.setMinimumWidth(180) + header.addWidget(self.status_label) + self.preview_button = QPushButton("Preview") + self.preview_button.clicked.connect(self._preview_from_button) + header.addWidget(self.preview_button) + self.duplicate_button = QPushButton("Duplicate") + self.duplicate_button.clicked.connect( + lambda: self.duplicate_requested.emit(self.item_id) + ) + header.addWidget(self.duplicate_button) + self.remove_button = QPushButton("Remove") + self.remove_button.clicked.connect( + lambda: self.remove_requested.emit(self.item_id) + ) + header.addWidget(self.remove_button) + root.addWidget(self.header_frame) + self.set_selected(False) + + self.progress_bar = QProgressBar() + self.progress_bar.setRange(0, 1) + self.progress_bar.setValue(0) + self.progress_bar.setFormat("%v / %m frames") + root.addWidget(self.progress_bar) + + self.settings_group = QFrame() + self.settings_group.setFrameShape(QFrame.Shape.StyledPanel) + root.addWidget(self.settings_group) + form = QFormLayout(self.settings_group) + + project_row = QWidget() + project_layout = QHBoxLayout(project_row) + project_layout.setContentsMargins(0, 0, 0, 0) + self.project_dir_edit = QLineEdit() + 
self.project_dir_edit.editingFinished.connect(self._on_editor_changed) + project_layout.addWidget(self.project_dir_edit, stretch=1) + project_button = QPushButton("Browse...") + project_button.clicked.connect(self._choose_project_dir) + project_layout.addWidget(project_button) + form.addRow("Project folder", project_row) + + self.project_reference_label = QLabel() + self.project_reference_label.setWordWrap(True) + self.project_reference_label.setFrameShape(QFrame.Shape.StyledPanel) + form.addRow("", self.project_reference_label) + + trajectory_row = QWidget() + trajectory_layout = QHBoxLayout(trajectory_row) + trajectory_layout.setContentsMargins(0, 0, 0, 0) + self.trajectory_file_edit = QLineEdit() + self.trajectory_file_edit.editingFinished.connect( + self._on_editor_changed + ) + trajectory_layout.addWidget(self.trajectory_file_edit, stretch=1) + trajectory_button = QPushButton("Browse...") + trajectory_button.clicked.connect(self._choose_trajectory_file) + trajectory_layout.addWidget(trajectory_button) + form.addRow("Trajectory file", trajectory_row) + + topology_row = QWidget() + topology_layout = QHBoxLayout(topology_row) + topology_layout.setContentsMargins(0, 0, 0, 0) + self.topology_file_edit = QLineEdit() + self.topology_file_edit.editingFinished.connect( + self._on_editor_changed + ) + topology_layout.addWidget(self.topology_file_edit, stretch=1) + topology_button = QPushButton("Browse...") + topology_button.clicked.connect(self._choose_topology_file) + topology_layout.addWidget(topology_button) + form.addRow("Topology file", topology_row) + + energy_row = QWidget() + energy_layout = QHBoxLayout(energy_row) + energy_layout.setContentsMargins(0, 0, 0, 0) + self.energy_file_edit = QLineEdit() + self.energy_file_edit.editingFinished.connect(self._on_editor_changed) + energy_layout.addWidget(self.energy_file_edit, stretch=1) + energy_button = QPushButton("Browse...") + energy_button.clicked.connect(self._choose_energy_file) + 
energy_layout.addWidget(energy_button) + form.addRow("Energy file", energy_row) + + output_row = QWidget() + output_layout = QHBoxLayout(output_row) + output_layout.setContentsMargins(0, 0, 0, 0) + self.output_dir_edit = QLineEdit() + self.output_dir_edit.setToolTip( + "Target folder for this project's extracted XYZ frames. Leave " + "blank to use the preview-generated default." + ) + self.output_dir_edit.editingFinished.connect(self._on_editor_changed) + output_layout.addWidget(self.output_dir_edit, stretch=1) + output_button = QPushButton("Browse...") + output_button.clicked.connect(self._choose_output_dir) + output_layout.addWidget(output_button) + form.addRow("Output folder", output_row) + + self.cutoff_spin = QDoubleSpinBox() + self.cutoff_spin.setRange(0.0, 1.0e12) + self.cutoff_spin.setDecimals(3) + self.cutoff_spin.setSingleStep(100.0) + self.cutoff_spin.setSuffix(" fs") + self.cutoff_spin.setValue(DEFAULT_TIME_CUTOFF_FS) + self.cutoff_spin.valueChanged.connect(self._on_editor_changed) + form.addRow("Time cutoff", self.cutoff_spin) + + self.include_restart_duplicates_box = QCheckBox( + "Include duplicate restart frames" + ) + self.include_restart_duplicates_box.setToolTip( + "Export duplicate frames from overlapping simulation restarts. " + "Leave this off for the cleaned continuation trajectory." + ) + self.include_restart_duplicates_box.toggled.connect( + self._on_editor_changed + ) + form.addRow("", self.include_restart_duplicates_box) + + self.preview_summary_label = QLabel( + "Preview the trajectory to verify the generated output folder." 
def _load_item(self, item: MDTrajectoryBatchItem) -> None:
    """Copy ``item`` into the editor controls without emitting change events.

    ``_loading`` is raised while the controls are filled so that
    ``_on_editor_changed`` returns early for the signals the setters fire.
    The flag is always lowered again, even if a setter raises; otherwise
    every later editor change would be ignored.

    Args:
        item: Source of the values to display. ``None`` paths are shown as
            empty fields.
    """

    def as_text(path: Path | None) -> str:
        return "" if path is None else str(path)

    self._loading = True
    try:
        self.project_dir_edit.setText(as_text(item.project_dir))
        self.trajectory_file_edit.setText(as_text(item.trajectory_file))
        self.topology_file_edit.setText(as_text(item.topology_file))
        self.energy_file_edit.setText(as_text(item.energy_file))
        self.output_dir_edit.setText(as_text(item.output_dir))
        # Track the folder now shown as the current suggestion. Clear the
        # marker when the field is emptied, so a suggestion left over from a
        # previously loaded item is not mistaken for an auto-generated
        # folder if the user later types that same path.
        self._last_suggested_output_dir = (
            None if item.output_dir is None else item.output_dir.resolve()
        )
        self.cutoff_spin.setValue(float(item.cutoff_fs))
        self.include_restart_duplicates_box.setChecked(
            item.include_restart_duplicates
        )
    finally:
        self._loading = False
    self._refresh_header()
    self._refresh_project_reference()
def _choose_output_dir(self) -> None:
    """Let the user browse for the output folder of the extracted frames.

    The dialog's starting folder comes from ``_dialog_start_dir``, which is
    offered the output, trajectory and project fields in that order.
    Cancelling the dialog leaves the editor untouched; otherwise the chosen
    folder is written to the field and the change handler runs.
    """
    start_dir = _dialog_start_dir(
        self.output_dir_edit.text(),
        self.trajectory_file_edit.text(),
        self.project_dir_edit.text(),
    )
    chosen = QFileDialog.getExistingDirectory(
        self,
        "Select output folder for extracted XYZ frames",
        start_dir,
    )
    if chosen:
        self.output_dir_edit.setText(chosen)
        self._on_editor_changed()
@Slot()
def run(self) -> None:
    """Process the queued jobs one after another, in queue order.

    Before each job the cancel flag is checked. Once it is set, no further
    job is started and the queue ends through the normal ``finished`` path
    with the results gathered so far. The first job that raises ends the
    run: ``item_failed`` and ``failed`` are emitted with the error text and
    ``finished`` is not emitted.
    """
    completed: list[MDTrajectoryBatchResult] = []
    queue_length = len(self.queue_entries)
    for position, (item_id, job) in enumerate(self.queue_entries, 1):
        if self._cancel_requested.is_set():
            self.log.emit("Batch queue stopped before the next project.")
            break

        label = f"{position}/{queue_length}"
        self.item_started.emit(item_id, position, queue_length)
        self.status.emit(f"Running {label}: {job.project_dir.name}")
        self.log.emit(f"Starting {label}: {job.project_dir}")

        try:
            outcome = self._run_job(item_id, job)
        except Exception as exc:
            # Any error aborts the whole queue rather than only this item.
            reason = str(exc)
            self.item_failed.emit(item_id, reason)
            self.failed.emit(item_id, reason)
            return
        else:
            completed.append(outcome)
            self.item_finished.emit(item_id, outcome)

    self.status.emit("MD trajectory batch queue finished")
    self.finished.emit(completed)
+ ) + export_result = workflow.export_frames( + use_cutoff=True, + cutoff_fs=job.cutoff_fs, + output_dir=job.output_dir, + progress_callback=( + lambda processed, total, message: self.item_progress.emit( + item_id, + processed, + total, + message, + ) + ), + ) + settings.trajectory_file = str(job.trajectory_file) + settings.topology_file = ( + None if job.topology_file is None else str(job.topology_file) + ) + settings.energy_file = str(job.energy_file) + settings.frames_dir = str(export_result.output_dir) + self._project_manager.save_project(settings) + self.log.emit( + f"[{job.project_dir.name}] Registered XYZ frames folder: " + f"{export_result.output_dir}" + ) + return MDTrajectoryBatchResult( + project_dir=job.project_dir, + output_dir=export_result.output_dir, + written_count=len(export_result.written_files), + selected_frames=export_result.selection.preview.selected_frames, + cutoff_fs=job.cutoff_fs, + include_restart_duplicates=job.include_restart_duplicates, + metadata_file=export_result.metadata_file, + ) + + +class MDTrajectoryBatchQueueWindow(QMainWindow): + """Queue MD trajectory frame extraction for multiple projects.""" + + project_paths_registered = Signal(object) + + def __init__( + self, + initial_project_dir: str | Path | None = None, + *, + initial_trajectory_file: str | Path | None = None, + initial_topology_file: str | Path | None = None, + initial_energy_file: str | Path | None = None, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._widgets_by_id: dict[str, MDTrajectoryBatchItemWidget] = {} + self._run_thread: QThread | None = None + self._run_worker: MDTrajectoryBatchWorker | None = None + self._initial_project_dir = ( + None + if initial_project_dir is None + else Path(initial_project_dir).expanduser().resolve() + ) + self._initial_trajectory_file = ( + None + if initial_trajectory_file is None + else Path(initial_trajectory_file).expanduser().resolve() + ) + self._initial_topology_file = ( + None + if 
def closeEvent(self, event) -> None:
    """Close the window, first winding down a running queue if there is one.

    With no worker thread running, the base-class handler decides. While a
    queue is running, a stop is requested and the window is hidden. The
    method then keeps pumping the event loop until the thread has finished,
    and only then accepts the close.
    """
    thread = self._run_thread
    if thread is None or not thread.isRunning():
        super().closeEvent(event)
        return

    self._request_cancel()
    self.hide()
    # ``self._run_thread`` is re-read on every pass rather than cached: the
    # events processed below can run ``_cleanup_run_thread``, which resets
    # the attribute to ``None``.
    while self._run_thread is not None and self._run_thread.isRunning():
        QApplication.processEvents()
        if self._run_thread is not None:
            self._run_thread.wait(50)
    event.accept()
_build_ui(self) -> None: + self.setWindowTitle("SAXSShell MD Trajectory Batch Queue") + self.setWindowIcon(load_saxshell_icon()) + self.resize(1080, 820) + + central = QWidget() + root = QVBoxLayout(central) + root.setContentsMargins(8, 8, 8, 8) + root.setSpacing(8) + + controls = QHBoxLayout() + self.add_current_button = QPushButton("Add Current Project") + self.add_current_button.clicked.connect(self._add_current_project) + controls.addWidget(self.add_current_button) + self.add_project_button = QPushButton("Add Projects...") + self.add_project_button.clicked.connect(self._choose_projects_to_add) + controls.addWidget(self.add_project_button) + controls.addStretch(1) + root.addLayout(controls) + + self.queue_list = QListWidget() + self.queue_list.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + self.queue_list.setDragDropMode( + QAbstractItemView.DragDropMode.InternalMove + ) + self.queue_list.setDefaultDropAction(Qt.DropAction.MoveAction) + self.queue_list.setAlternatingRowColors(True) + self.queue_list.setStyleSheet( + "QListWidget::item:selected { background: transparent; }" + "QListWidget::item:hover { background: transparent; }" + "QListWidget::item { margin: 3px; }" + ) + self.queue_list.model().rowsMoved.connect(self._refresh_order_labels) + self.queue_list.itemSelectionChanged.connect( + self._refresh_item_selection_styles + ) + root.addWidget(self.queue_list, stretch=1) + + run_group = QFrame() + run_group.setFrameShape(QFrame.Shape.StyledPanel) + run_layout = QVBoxLayout(run_group) + run_buttons = QHBoxLayout() + self.run_button = QPushButton("Run Complete Queue") + self.run_button.clicked.connect(self._start_queue) + run_buttons.addWidget(self.run_button) + self.cancel_button = QPushButton("Stop Queue") + self.cancel_button.setEnabled(False) + self.cancel_button.clicked.connect(self._request_cancel) + run_buttons.addWidget(self.cancel_button) + run_buttons.addStretch(1) + run_layout.addLayout(run_buttons) + 
def _add_current_project(self) -> None:
    """Queue an item for the project the window was opened with.

    If the window received neither a project folder, a trajectory file nor
    an energy file, an information box is shown and nothing is added.
    Otherwise the item starts from the project's stored defaults (or from an
    empty item when there is no project folder). Any file passed to the
    window takes precedence over the corresponding default.
    """
    project_dir = self._initial_project_dir
    trajectory_file = self._initial_trajectory_file
    energy_file = self._initial_energy_file

    if project_dir is None and trajectory_file is None and energy_file is None:
        QMessageBox.information(
            self,
            "No active project",
            "The main UI did not provide an active project reference.",
        )
        return

    if project_dir is None:
        base = MDTrajectoryBatchItem(item_id=_new_item_id())
    else:
        base = _queue_item_from_project_defaults(project_dir)

    overrides = {
        "trajectory_file": trajectory_file or base.trajectory_file,
        "topology_file": self._initial_topology_file or base.topology_file,
        "energy_file": energy_file or base.energy_file,
    }
    self.add_queue_item(replace(base, **overrides))
def _set_running(self, running: bool) -> None:
    """Switch the window between its idle and queue-running states.

    While running, the add and run buttons, drag-and-drop reordering and
    every item editor are disabled, and only the stop button is enabled.
    When idle, the reverse applies.
    """
    idle = not running
    for button in (
        self.add_current_button,
        self.add_project_button,
        self.run_button,
    ):
        button.setEnabled(idle)
    self.cancel_button.setEnabled(running)

    # Reordering is frozen during a run.
    queue = self.queue_list
    queue.setDragEnabled(idle)
    queue.setAcceptDrops(idle)

    for item_widget in self._widgets_by_id.values():
        item_widget.set_locked(running)
extraction(s)" + ) + for widget in self._widgets_by_id.values(): + widget.set_progress(0, 1) + widget.set_status("Queued") + + self._run_thread = QThread(self) + self._run_worker = MDTrajectoryBatchWorker(entries) + self._run_worker.moveToThread(self._run_thread) + self._run_thread.started.connect(self._run_worker.run) + self._run_worker.item_started.connect(self._on_item_started) + self._run_worker.item_progress.connect(self._on_item_progress) + self._run_worker.item_finished.connect(self._on_item_finished) + self._run_worker.item_failed.connect(self._on_item_failed) + self._run_worker.log.connect(self._append_log) + self._run_worker.status.connect(self._on_status) + self._run_worker.finished.connect(self._on_queue_finished) + self._run_worker.failed.connect(self._on_queue_failed) + self._run_worker.finished.connect(self._run_thread.quit) + self._run_worker.failed.connect(self._run_thread.quit) + self._run_thread.finished.connect(self._cleanup_run_thread) + self._run_thread.finished.connect(self._run_thread.deleteLater) + self._run_thread.start() + + def _request_cancel(self) -> None: + self.cancel_button.setEnabled(False) + self.queue_status_label.setText( + "Stopping queue after the active project finishes" + ) + self._append_log( + "Stop requested; the current project will finish before the " + "queue exits." 
def _on_queue_finished(self, results: object) -> None:
    """Return the window to idle and report how many extractions were saved.

    Args:
        results: Payload of the worker's ``finished`` signal. Its length is
            reported when it is a ``list``; any other payload counts as
            zero.
    """
    self._set_running(False)

    saved = 0
    if isinstance(results, list):
        saved = len(results)

    self.queue_status_label.setText(
        f"Queue finished: {saved} extraction(s) saved"
    )
    self.statusBar().showMessage("MD trajectory batch queue finished")
queue failed", + 5000, + ) + QMessageBox.warning( + self, + "MD trajectory batch queue failed", + f"Queue item {item_id} failed:\n{message}", + ) + + def _cleanup_run_thread(self) -> None: + self._run_thread = None + self._run_worker = None + + +def launch_mdtrajectory_batch_queue_ui( + initial_project_dir: str | Path | None = None, + *, + initial_trajectory_file: str | Path | None = None, + initial_topology_file: str | Path | None = None, + initial_energy_file: str | Path | None = None, +) -> int: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication([]) + configure_saxshell_application(app) + window = MDTrajectoryBatchQueueWindow( + initial_project_dir=initial_project_dir, + initial_trajectory_file=initial_trajectory_file, + initial_topology_file=initial_topology_file, + initial_energy_file=initial_energy_file, + ) + window.show() + return int(app.exec()) + + +__all__ = [ + "DEFAULT_TIME_CUTOFF_FS", + "MDTrajectoryBatchItem", + "MDTrajectoryBatchItemWidget", + "MDTrajectoryBatchJob", + "MDTrajectoryBatchQueueWindow", + "MDTrajectoryBatchResult", + "MDTrajectoryBatchWorker", + "launch_mdtrajectory_batch_queue_ui", +] diff --git a/src/saxshell/mdtrajectory/ui/cutoff_panel.py b/src/saxshell/mdtrajectory/ui/cutoff_panel.py index beed70b..4ce03b3 100644 --- a/src/saxshell/mdtrajectory/ui/cutoff_panel.py +++ b/src/saxshell/mdtrajectory/ui/cutoff_panel.py @@ -227,7 +227,7 @@ def _build_ui(self) -> None: self.window_spin = QSpinBox() self.window_spin.setRange(1, 10**6) - self.window_spin.setValue(3) + self.window_spin.setValue(2) self.window_spin.setToolTip( "Number of consecutive energy samples that must satisfy the " "steady-state criteria." 
diff --git a/src/saxshell/mdtrajectory/ui/export_panel.py b/src/saxshell/mdtrajectory/ui/export_panel.py index e85fe8c..42a7610 100644 --- a/src/saxshell/mdtrajectory/ui/export_panel.py +++ b/src/saxshell/mdtrajectory/ui/export_panel.py @@ -60,6 +60,18 @@ def _build_ui(self) -> None: self.use_cutoff_box.toggled.connect(self._handle_use_cutoff_toggled) form.addRow("", self.use_cutoff_box) + self.include_restart_duplicates_box = QCheckBox( + "Include duplicate restart frames" + ) + self.include_restart_duplicates_box.setToolTip( + "Export duplicate frames from overlapping simulation restarts. " + "Leave this off for the cleaned continuation trajectory." + ) + self.include_restart_duplicates_box.toggled.connect( + lambda _checked: self.settings_changed.emit() + ) + form.addRow("", self.include_restart_duplicates_box) + self.post_cutoff_stride_box = QCheckBox( "After cutoff, keep every Nth frame" ) @@ -170,6 +182,9 @@ def use_post_cutoff_stride(self) -> bool: def get_post_cutoff_stride(self) -> int: return self.post_cutoff_stride_spin.value() + def include_restart_duplicates(self) -> bool: + return self.include_restart_duplicates_box.isChecked() + def set_selection_summary(self, text: str) -> None: self.selection_box.setPlainText(text) @@ -188,6 +203,7 @@ def set_controls_enabled(self, enabled: bool) -> None: if self.output_dir_button is not None: self.output_dir_button.setEnabled(enabled) self.use_cutoff_box.setEnabled(enabled) + self.include_restart_duplicates_box.setEnabled(enabled) self.post_cutoff_stride_box.setEnabled( enabled and self.use_cutoff_box.isChecked() ) diff --git a/src/saxshell/mdtrajectory/ui/main_window.py b/src/saxshell/mdtrajectory/ui/main_window.py index bab8aa7..3fc2cd9 100644 --- a/src/saxshell/mdtrajectory/ui/main_window.py +++ b/src/saxshell/mdtrajectory/ui/main_window.py @@ -74,6 +74,7 @@ def __init__( manager: TrajectoryManager | None = None, summary: dict[str, object] | None = None, reload_trajectory: bool = True, + 
include_restart_duplicates: bool = False, ) -> None: super().__init__() self.trajectory_file = trajectory_file @@ -82,6 +83,7 @@ def __init__( self.manager = manager self.summary = summary self.reload_trajectory = reload_trajectory + self.include_restart_duplicates = bool(include_restart_duplicates) @Slot() def run(self) -> None: @@ -104,6 +106,9 @@ def run(self) -> None: input_file=self.trajectory_file, topology_file=self.topology_file, backend="auto", + include_restart_duplicates=( + self.include_restart_duplicates + ), ) summary = manager.inspect() completed_steps += 1 @@ -474,6 +479,9 @@ def inspect_trajectory(self) -> None: previous_trajectory = self.state.trajectory_file previous_topology = self.state.topology_file previous_energy = self.state.energy_file + previous_include_restart_duplicates = ( + self.state.include_restart_duplicates + ) if trajectory_file is None: raise ValueError("No trajectory file selected.") @@ -482,6 +490,8 @@ def inspect_trajectory(self) -> None: self.manager is None or trajectory_file != previous_trajectory or topology_file != previous_topology + or self.export_panel.include_restart_duplicates() + != previous_include_restart_duplicates ) energy_changed = energy_file != previous_energy @@ -491,6 +501,9 @@ def inspect_trajectory(self) -> None: self.state.start = self.trajectory_panel.get_start() self.state.stop = self.trajectory_panel.get_stop() self.state.stride = self.trajectory_panel.get_stride() + self.state.include_restart_duplicates = ( + self.export_panel.include_restart_duplicates() + ) self._update_suggested_output_dir() registration_message = self._register_project_file_inputs() if registration_message is not None: @@ -525,6 +538,9 @@ def inspect_trajectory(self) -> None: topology_file=topology_file, energy_file=energy_file, reload_trajectory=True, + include_restart_duplicates=( + self.state.include_restart_duplicates + ), ) return @@ -549,6 +565,9 @@ def inspect_trajectory(self) -> None: manager=self.manager, 
summary=self._last_summary, reload_trajectory=False, + include_restart_duplicates=( + self.state.include_restart_duplicates + ), ) return @@ -601,6 +620,7 @@ def export_frames(self) -> None: self._sync_state_from_controls() self.state.output_dir = output_dir min_time_fs = self._resolved_export_cutoff() + self._apply_restart_duplicate_mode_to_manager() preview = self.manager.preview_selection( start=self.state.start, stop=self.state.stop, @@ -650,6 +670,9 @@ def _sync_state_from_controls(self) -> None: self.state.post_cutoff_stride = ( self.export_panel.get_post_cutoff_stride() ) + self.state.include_restart_duplicates = ( + self.export_panel.include_restart_duplicates() + ) self.state.selected_cutoff_fs = self.cutoff_panel.get_selected_cutoff() self.state.suggested_cutoff_fs = ( self.cutoff_panel.get_suggested_cutoff() @@ -684,6 +707,7 @@ def _refresh_selection_preview(self) -> None: if self.state.use_cutoff_for_export: min_time_fs = self._resolved_export_cutoff() + self._apply_restart_duplicate_mode_to_manager() preview = self.manager.preview_selection( start=self.state.start, stop=self.state.stop, @@ -719,6 +743,7 @@ def _start_inspection_worker( manager: TrajectoryManager | None = None, summary: dict[str, object] | None = None, reload_trajectory: bool = True, + include_restart_duplicates: bool = False, ) -> None: self._set_operation_busy(True, "Inspecting trajectory...") self.export_panel.set_busy_progress("Inspection progress: starting...") @@ -731,6 +756,7 @@ def _start_inspection_worker( manager=manager, summary=summary, reload_trajectory=reload_trajectory, + include_restart_duplicates=include_restart_duplicates, ) self._inspect_worker.moveToThread(self._inspect_thread) self._inspect_thread.started.connect(self._inspect_worker.run) @@ -907,6 +933,8 @@ def _handle_export_finished(self, result: ExportResult) -> None: f"Start: {self.state.start}", f"Stop: {self.state.stop}", f"Stride: {self.state.stride}", + "Restart duplicate frames: " + f"{'included' if 
self.state.include_restart_duplicates else 'skipped'}", ] if result.applied_cutoff_fs is not None: lines.append(f"Applied cutoff: {result.applied_cutoff_fs:.3f} fs") @@ -984,6 +1012,8 @@ def _format_selection_summary(self, preview) -> str: f"Stop: {preview.stop}", f"Stride: {preview.stride}", f"Time-tagged frames: {preview.time_metadata_frames}", + "Restart duplicate frames: " + f"{'included' if self.state.include_restart_duplicates else 'skipped'}", ] ) if preview.min_time_fs is not None: @@ -1007,6 +1037,17 @@ def _format_selection_summary(self, preview) -> str: ) return "\n".join(lines) + def _apply_restart_duplicate_mode_to_manager(self) -> None: + if self.manager is None: + return + setter = getattr( + self.manager, + "set_include_restart_duplicates", + None, + ) + if callable(setter): + setter(self.state.include_restart_duplicates) + def _update_suggested_output_dir( self, *, diff --git a/src/saxshell/mdtrajectory/ui/state.py b/src/saxshell/mdtrajectory/ui/state.py index bbf3cfe..7bac0c3 100644 --- a/src/saxshell/mdtrajectory/ui/state.py +++ b/src/saxshell/mdtrajectory/ui/state.py @@ -19,7 +19,7 @@ class MDTrajectoryAppState: temp_target_k: float = 300.0 temp_tol_k: float = 1.0 - window: int = 3 + window: int = 2 suggested_cutoff_fs: float | None = None selected_cutoff_fs: float | None = None @@ -28,3 +28,4 @@ class MDTrajectoryAppState: use_cutoff_for_export: bool = True use_post_cutoff_stride: bool = False post_cutoff_stride: int = 1 + include_restart_duplicates: bool = False diff --git a/src/saxshell/mdtrajectory/workflow.py b/src/saxshell/mdtrajectory/workflow.py index 93511ce..f23a4fb 100644 --- a/src/saxshell/mdtrajectory/workflow.py +++ b/src/saxshell/mdtrajectory/workflow.py @@ -5,12 +5,17 @@ from pathlib import Path from re import sub +from saxshell.mdtrajectory.frame.assertions import ( + MDTrajectoryAssertionResult, + validate_xyz_export_against_source, +) from saxshell.mdtrajectory.frame.cp2k_ener import CP2KEnergyData from 
saxshell.mdtrajectory.frame.cutoff_analysis import ( CP2KEnergyAnalyzer, SteadyStateResult, ) from saxshell.mdtrajectory.frame.manager import ( + ExportProgressCallback, FrameSelectionPreview, TrajectoryManager, ) @@ -97,6 +102,7 @@ class MDTrajectorySelectionResult: preview: FrameSelectionPreview output_dir: Path applied_cutoff_fs: float | None + include_restart_duplicates: bool = False def to_dict(self) -> dict[str, object]: preview = self.preview @@ -114,6 +120,7 @@ def to_dict(self) -> dict[str, object]: "last_frame_index": preview.last_frame_index, "first_time_fs": preview.first_time_fs, "last_time_fs": preview.last_time_fs, + "include_restart_duplicates": self.include_restart_duplicates, } @@ -148,6 +155,7 @@ def __init__( topology_file: str | Path | None = None, energy_file: str | Path | None = None, backend: str = "auto", + include_restart_duplicates: bool = False, ) -> None: self.trajectory_file = Path(trajectory_file) self.topology_file = ( @@ -157,10 +165,12 @@ def __init__( Path(energy_file) if energy_file is not None else None ) self.backend = backend + self.include_restart_duplicates = bool(include_restart_duplicates) self.manager = TrajectoryManager( input_file=self.trajectory_file, topology_file=self.topology_file, backend=backend, + include_restart_duplicates=self.include_restart_duplicates, ) self.summary: dict[str, object] | None = None self.energy_data: CP2KEnergyData | None = None @@ -173,6 +183,19 @@ def inspect(self) -> dict[str, object]: self.summary = self.manager.inspect() return dict(self.summary) + def set_include_restart_duplicates( + self, + include_restart_duplicates: bool, + ) -> None: + """Choose whether restart-overlap duplicate frames are + exposed.""" + include_restart_duplicates = bool(include_restart_duplicates) + if self.include_restart_duplicates == include_restart_duplicates: + return + self.include_restart_duplicates = include_restart_duplicates + self.manager.set_include_restart_duplicates(include_restart_duplicates) + 
self.summary = None + def load_energy(self) -> CP2KEnergyData: """Load the configured CP2K energy file.""" if self.energy_file is None: @@ -190,7 +213,7 @@ def suggest_cutoff( *, temp_target_k: float, temp_tol_k: float = 1.0, - window: int = 3, + window: int = 2, ) -> SteadyStateResult: """Suggest a steady-state cutoff from the loaded energy data.""" analyzer = CP2KEnergyAnalyzer(self.load_energy()) @@ -232,8 +255,11 @@ def preview_selection( use_cutoff: bool = False, cutoff_fs: float | None = None, output_dir: str | Path | None = None, + include_restart_duplicates: bool | None = None, ) -> MDTrajectorySelectionResult: """Preview the selected frames and output target directory.""" + if include_restart_duplicates is not None: + self.set_include_restart_duplicates(include_restart_duplicates) self.inspect() applied_cutoff_fs = self.resolve_cutoff( use_cutoff=use_cutoff, @@ -259,6 +285,7 @@ def preview_selection( preview=preview, output_dir=resolved_output_dir, applied_cutoff_fs=applied_cutoff_fs, + include_restart_duplicates=self.include_restart_duplicates, ) def export_frames( @@ -271,6 +298,8 @@ def export_frames( post_cutoff_stride: int = 1, use_cutoff: bool = False, cutoff_fs: float | None = None, + include_restart_duplicates: bool | None = None, + progress_callback: ExportProgressCallback | None = None, ) -> MDTrajectoryExportResult: """Write the current frame selection to disk.""" selection = self.preview_selection( @@ -281,6 +310,7 @@ def export_frames( use_cutoff=use_cutoff, cutoff_fs=cutoff_fs, output_dir=output_dir, + include_restart_duplicates=include_restart_duplicates, ) if selection.preview.selected_frames == 0: raise ValueError("No frames match the current selection settings.") @@ -299,6 +329,7 @@ def export_frames( stride=stride, min_time_fs=selection.applied_cutoff_fs, post_cutoff_stride=post_cutoff_stride, + progress_callback=progress_callback, ) metadata_file = self._write_export_metadata( selection=selection, @@ -312,6 +343,28 @@ def export_frames( 
metadata_file=metadata_file, ) + def validate_export( + self, + frame_dir: str | Path, + *, + coordinate_lines: int = 3, + coordinate_tolerance: float = 1.0e-9, + expect_contiguous: bool = False, + strict_source_duplicates: bool = False, + max_issues: int = 20, + ) -> MDTrajectoryAssertionResult: + """Run export mapping assertions against the source + trajectory.""" + return validate_xyz_export_against_source( + self.trajectory_file, + frame_dir, + coordinate_lines=coordinate_lines, + coordinate_tolerance=coordinate_tolerance, + expect_contiguous=expect_contiguous, + strict_source_duplicates=strict_source_duplicates, + max_issues=max_issues, + ) + def _write_export_metadata( self, *, diff --git a/src/saxshell/pdf/debyer/__init__.py b/src/saxshell/pdf/debyer/__init__.py index f5ab489..e777628 100644 --- a/src/saxshell/pdf/debyer/__init__.py +++ b/src/saxshell/pdf/debyer/__init__.py @@ -7,6 +7,7 @@ SUPPORTED_DEBYER_MODES, SUPPORTED_PLOT_REPRESENTATIONS, TOTAL_SCATTERING_PAPER_URL, + DebyerFitMetrics, DebyerFrameInspection, DebyerPDFCalculation, DebyerPDFCalculationSummary, @@ -21,9 +22,12 @@ calculate_number_density, check_debyer_runtime, classify_partial_pair, + compute_experimental_fit_metrics, convert_distribution_values, + default_parallel_debyer_jobs, estimate_partial_peak_markers, find_partial_peak_markers, + infer_default_solute_elements, inspect_frames_dir, list_saved_debyer_calculations, load_debyer_calculation, @@ -37,6 +41,7 @@ "TOTAL_SCATTERING_PAPER_URL", "SUPPORTED_DEBYER_MODES", "SUPPORTED_PLOT_REPRESENTATIONS", + "DebyerFitMetrics", "DebyerFrameInspection", "DebyerPeakFinderSettings", "DebyerPeakMarker", @@ -51,9 +56,12 @@ "calculate_number_density", "check_debyer_runtime", "classify_partial_pair", + "compute_experimental_fit_metrics", "convert_distribution_values", + "default_parallel_debyer_jobs", "estimate_partial_peak_markers", "find_partial_peak_markers", + "infer_default_solute_elements", "inspect_frames_dir", "list_saved_debyer_calculations", 
"load_debyer_calculation", diff --git a/src/saxshell/pdf/debyer/ui/batch_queue_window.py b/src/saxshell/pdf/debyer/ui/batch_queue_window.py new file mode 100644 index 0000000..8c31fd6 --- /dev/null +++ b/src/saxshell/pdf/debyer/ui/batch_queue_window.py @@ -0,0 +1,1793 @@ +from __future__ import annotations + +import json +import threading +import uuid +from dataclasses import dataclass, replace +from pathlib import Path + +from PySide6.QtCore import QObject, Qt, QThread, Signal, Slot +from PySide6.QtWidgets import ( + QAbstractItemView, + QApplication, + QCheckBox, + QComboBox, + QFileDialog, + QFormLayout, + QFrame, + QGridLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QListView, + QListWidget, + QListWidgetItem, + QMainWindow, + QMessageBox, + QProgressBar, + QPushButton, + QSizePolicy, + QSpinBox, + QTextEdit, + QToolButton, + QTreeView, + QVBoxLayout, + QWidget, +) + +from saxshell.pdf.debyer.workflow import ( + SUPPORTED_DEBYER_MODES, + DebyerPDFCalculation, + DebyerPDFSettings, + DebyerPDFWorkflow, + calculate_number_density, + check_debyer_runtime, + default_parallel_debyer_jobs, + infer_default_solute_elements, + inspect_frames_dir, + list_saved_debyer_calculations, + load_debyer_calculation, + rewrite_debyer_calculation_output, + write_debyer_calculation_metadata, +) +from saxshell.saxs.project_manager import build_project_paths +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) + + +def _new_item_id() -> str: + return uuid.uuid4().hex + + +def _optional_path(text: str) -> Path | None: + stripped = text.strip() + if not stripped: + return None + return Path(stripped).expanduser().resolve() + + +def _required_path(text: str, field_name: str) -> Path: + path = _optional_path(text) + if path is None: + raise ValueError(f"{field_name} is required.") + return path + + +def _normalize_solute_text(raw: str) -> tuple[str, ...]: + values = [token.strip() for 
token in raw.replace(";", ",").split(",")] + normalized: list[str] = [] + seen: set[str] = set() + for value in values: + if not value: + continue + element = value[:1].upper() + value[1:].lower() + if element in seen: + continue + normalized.append(element) + seen.add(element) + return tuple(normalized) + + +def _solute_text(values: tuple[str, ...]) -> str: + return ", ".join(values) + + +def _suggest_project_dir(frames_dir: Path | None) -> Path: + if frames_dir is not None: + return frames_dir.parent / f"{frames_dir.name}_pdfbatch" + return Path.home() / "saxshell_pdf_batch_project" + + +def _project_reference_text(project_dir: Path | None) -> str: + if project_dir is None: + return "Project reference: choose a SAXSShell project folder." + project_file = build_project_paths(project_dir).project_file + if project_file.is_file(): + return f"Project reference: {project_file}" + return ( + "Project reference: " + f"{project_file} will be used when the calculation is saved." + ) + + +def _box_text(box_dimensions: tuple[float, float, float]) -> str: + return " x ".join(f"{value:.3f}" for value in box_dimensions) + " A" + + +def _project_path(value: object, project_dir: Path) -> Path | None: + if value is None: + return None + text = str(value).strip() + if not text: + return None + path = Path(text).expanduser() + if not path.is_absolute(): + path = project_dir / path + return path.resolve() + + +def _coerce_optional_float(value: object) -> float | None: + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def _coerce_optional_int(value: object) -> int | None: + numeric = _coerce_optional_float(value) + if numeric is None: + return None + return int(numeric) + + +def _coerce_optional_bool(value: object) -> bool | None: + if isinstance(value, bool): + return value + if value is None: + return None + text = str(value).strip().lower() + if text in {"1", "true", "yes", "on"}: + return True + if text in {"0", 
"false", "no", "off"}: + return False + return None + + +def _coerce_box_dimensions( + value: object, +) -> tuple[float, float, float] | None: + if isinstance(value, dict): + values = [value.get(key) for key in ("a", "b", "c")] + elif isinstance(value, str): + normalized = ( + value.replace("x", ",") + .replace("X", ",") + .replace(";", ",") + .replace(" ", ",") + ) + values = [part for part in normalized.split(",") if part.strip()] + elif isinstance(value, (list, tuple)): + values = list(value) + else: + return None + if len(values) != 3: + return None + coerced = tuple(_coerce_optional_float(entry) for entry in values) + if any(entry is None for entry in coerced): + return None + return tuple(float(entry) for entry in coerced) # type: ignore[arg-type] + + +def _payload_sources(payload: dict[str, object]) -> list[dict[str, object]]: + sources: list[dict[str, object]] = [] + for key in ( + "debyer_pdf_settings", + "pdf_debyer_settings", + "debyer_settings", + "pdf_settings", + ): + value = payload.get(key) + if isinstance(value, dict): + sources.append(value) + sources.append(payload) + return sources + + +def _payload_value( + sources: list[dict[str, object]], + keys: tuple[str, ...], +) -> object | None: + for source in sources: + for key in keys: + value = source.get(key) + if value is not None and str(value).strip(): + return value + return None + + +def _load_project_payload(project_dir: Path) -> dict[str, object]: + project_file = build_project_paths(project_dir).project_file + if not project_file.is_file(): + return {} + try: + payload = json.loads(project_file.read_text(encoding="utf-8")) + except Exception: + return {} + return payload if isinstance(payload, dict) else {} + + +def _load_latest_project_debyer_defaults( + project_dir: Path, +) -> DebyerPDFBatchItem | None: + for summary in list_saved_debyer_calculations(project_dir): + try: + calculation = load_debyer_calculation(summary.calculation_dir) + except Exception: + continue + return 
DebyerPDFBatchItem( + item_id=_new_item_id(), + project_dir=project_dir, + frames_dir=calculation.frames_dir, + filename_prefix=calculation.filename_prefix, + mode=calculation.mode, + from_value=calculation.from_value, + to_value=calculation.to_value, + step_value=calculation.step_value, + box_dimensions=calculation.box_dimensions, + atom_count=calculation.atom_count, + store_frame_outputs=calculation.store_frame_outputs, + solute_elements=calculation.solute_elements, + max_parallel_jobs=calculation.parallel_jobs, + ) + return None + + +def _queue_item_from_project_defaults( + project_dir: Path, + *, + item_id: str | None = None, + frames_dir_override: Path | None = None, +) -> DebyerPDFBatchItem: + resolved_project_dir = Path(project_dir).expanduser().resolve() + item = _load_latest_project_debyer_defaults(resolved_project_dir) + if item is None: + item = DebyerPDFBatchItem( + item_id=item_id or _new_item_id(), + project_dir=resolved_project_dir, + ) + else: + item = replace(item, item_id=item_id or _new_item_id()) + + payload = _load_project_payload(resolved_project_dir) + sources = _payload_sources(payload) + + frames_value = _payload_value( + sources, + ( + "frames_dir", + "xyz_frames_dir", + "xyz_file_path", + "xyz_path", + "debyer_frames_dir", + "pdf_frames_dir", + ), + ) + frames_dir = _project_path(frames_value, resolved_project_dir) + if frames_dir_override is not None: + frames_dir = Path(frames_dir_override).expanduser().resolve() + + filename_prefix = str( + _payload_value( + sources, + ("filename_prefix", "pdf_filename_prefix", "debyer_prefix"), + ) + or item.filename_prefix + ) + mode = str( + _payload_value(sources, ("mode", "pdf_mode", "debyer_mode")) + or item.mode + ) + from_value = _coerce_optional_float( + _payload_value( + sources, + ( + "from_value", + "r_min", + "r_range_min", + "pdf_from_value", + "debyer_from_value", + ), + ) + ) + to_value = _coerce_optional_float( + _payload_value( + sources, + ( + "to_value", + "r_max", + "r_range_max", 
+ "pdf_to_value", + "debyer_to_value", + ), + ) + ) + step_value = _coerce_optional_float( + _payload_value( + sources, + ( + "step_value", + "r_step", + "r_range_step", + "pdf_step_value", + "debyer_step_value", + ), + ) + ) + box_dimensions = _coerce_box_dimensions( + _payload_value( + sources, + ( + "box_dimensions", + "bounding_box", + "pdf_box_dimensions", + "debyer_box_dimensions", + ), + ) + ) + atom_count = _coerce_optional_int( + _payload_value( + sources, + ("atom_count", "pdf_atom_count", "debyer_atom_count"), + ) + ) + solute_value = _payload_value( + sources, + ( + "solute_elements", + "pdf_solute_elements", + "debyer_solute_elements", + ), + ) + if isinstance(solute_value, str): + solute_elements = _normalize_solute_text(solute_value) + elif isinstance(solute_value, (list, tuple, set)): + solute_elements = _normalize_solute_text( + ",".join(str(value) for value in solute_value) + ) + else: + solute_elements = item.solute_elements + store_frame_outputs = _coerce_optional_bool( + _payload_value( + sources, + ("store_frame_outputs", "pdf_store_frame_outputs"), + ) + ) + parallel_jobs = _coerce_optional_int( + _payload_value( + sources, + ("max_parallel_jobs", "parallel_jobs", "pdf_parallel_jobs"), + ) + ) + + return replace( + item, + project_dir=resolved_project_dir, + frames_dir=frames_dir or item.frames_dir, + filename_prefix=filename_prefix.strip() or item.filename_prefix, + mode=mode if mode in SUPPORTED_DEBYER_MODES else item.mode, + from_value=item.from_value if from_value is None else from_value, + to_value=item.to_value if to_value is None else to_value, + step_value=item.step_value if step_value is None else step_value, + box_dimensions=box_dimensions or item.box_dimensions, + atom_count=item.atom_count if atom_count is None else atom_count, + store_frame_outputs=( + item.store_frame_outputs + if store_frame_outputs is None + else store_frame_outputs + ), + solute_elements=solute_elements, + max_parallel_jobs=( + item.max_parallel_jobs + if 
parallel_jobs is None + else max(int(parallel_jobs), 1) + ), + ) + + +def _coerce_r_range_maximum_for_box( + r_max: float, + box_dimensions: tuple[float, float, float], +) -> tuple[float, bool]: + if any(component <= 0.0 for component in box_dimensions): + return r_max, False + allowed_r_max = min(box_dimensions) * 0.5 + if r_max <= allowed_r_max: + return r_max, False + return allowed_r_max, True + + +def _choose_existing_directories( + parent: QWidget, + *, + title: str, + start_dir: str | Path, +) -> tuple[Path, ...]: + dialog = QFileDialog(parent, title, str(start_dir)) + dialog.setFileMode(QFileDialog.FileMode.Directory) + dialog.setOption(QFileDialog.Option.ShowDirsOnly, True) + dialog.setOption(QFileDialog.Option.DontUseNativeDialog, True) + for view in dialog.findChildren(QListView) + dialog.findChildren( + QTreeView + ): + view.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + if dialog.exec() != int(QFileDialog.DialogCode.Accepted): + return () + return tuple( + Path(path).expanduser().resolve() for path in dialog.selectedFiles() + ) + + +@dataclass(slots=True) +class DebyerPDFBatchItem: + item_id: str + project_dir: Path | None = None + frames_dir: Path | None = None + filename_prefix: str = "debyer_pdf" + mode: str = "PDF" + from_value: float = 0.5 + to_value: float = 15.0 + step_value: float = 0.01 + box_dimensions: tuple[float, float, float] = (0.0, 0.0, 0.0) + atom_count: int = 0 + store_frame_outputs: bool = False + solute_elements: tuple[str, ...] 
= () + max_parallel_jobs: int = default_parallel_debyer_jobs() + + def display_name(self) -> str: + if self.project_dir is not None: + return self.project_dir.name + if self.frames_dir is not None: + return self.frames_dir.name + return "New PDF calculation" + + def to_settings(self) -> DebyerPDFSettings: + frames_dir = self.frames_dir + if frames_dir is None: + raise ValueError("Select an XYZ frames folder.") + project_dir = self.project_dir or _suggest_project_dir(frames_dir) + if self.atom_count <= 0: + raise ValueError("Atom count must be positive.") + if any(component <= 0.0 for component in self.box_dimensions): + raise ValueError("All bounding-box dimensions must be positive.") + return DebyerPDFSettings( + project_dir=project_dir, + frames_dir=frames_dir, + filename_prefix=self.filename_prefix.strip() or "debyer_pdf", + mode=self.mode, + from_value=float(self.from_value), + to_value=float(self.to_value), + step_value=float(self.step_value), + box_dimensions=tuple( + float(component) for component in self.box_dimensions + ), + atom_count=int(self.atom_count), + store_frame_outputs=bool(self.store_frame_outputs), + solute_elements=tuple(self.solute_elements), + max_parallel_jobs=int(self.max_parallel_jobs), + ) + + +@dataclass(slots=True, frozen=True) +class DebyerPDFExistingPartialsJob: + project_dir: Path + solute_elements: tuple[str, ...] 
= () + + +class DebyerPDFBatchItemWidget(QFrame): + settings_changed = Signal(str) + remove_requested = Signal(str) + duplicate_requested = Signal(str) + + def __init__( + self, + item: DebyerPDFBatchItem, + *, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._item = item + self._loading = False + self._locked = False + self._selected = False + self._append_grouped_mode = False + self._build_ui() + self._load_item(item) + self._set_settings_visible(False) + + @property + def item_id(self) -> str: + return self._item.item_id + + def item(self) -> DebyerPDFBatchItem: + return self._item + + def set_locked(self, locked: bool) -> None: + self._locked = bool(locked) + self.settings_group.setEnabled(not locked) + self.remove_button.setEnabled(not locked) + self.duplicate_button.setEnabled(not locked) + self._refresh_setting_widget_states() + + def set_append_grouped_mode(self, enabled: bool) -> None: + self._append_grouped_mode = bool(enabled) + self._refresh_setting_widget_states() + if enabled: + self.inspection_summary_label.setText( + "Append mode uses the project folder and solute elements to " + "update existing saved Debyer calculations. Full Debyer " + "settings are ignored." + ) + else: + self.inspection_summary_label.setText( + "Inspect the XYZ frames folder to detect atoms, solutes, and " + "box." 
+ ) + + def set_status(self, message: str) -> None: + self.status_label.setText(message) + + def set_progress(self, processed: int, total: int) -> None: + self.progress_bar.setRange(0, max(int(total), 1)) + self.progress_bar.setValue(max(int(processed), 0)) + + def set_selected(self, selected: bool) -> None: + self._selected = bool(selected) + self.header_frame.setProperty("selected", self._selected) + self.header_frame.setStyleSheet( + "QFrame#DebyerBatchItemHeader {" + + ( + "background-color: #dce8f7; " "border: 1px solid #8fb0d7;" + if self._selected + else "background-color: #f6f8fb; " "border: 1px solid #cfd7e3;" + ) + + "border-radius: 5px;}" + ) + + def collect_item(self) -> DebyerPDFBatchItem: + frames_dir = _optional_path(self.frames_dir_edit.text()) + project_dir = _optional_path(self.project_dir_edit.text()) + if project_dir is None and frames_dir is not None: + project_dir = _suggest_project_dir(frames_dir) + self.project_dir_edit.setText(str(project_dir)) + box_dimensions = ( + float(self.box_a_edit.text().strip()), + float(self.box_b_edit.text().strip()), + float(self.box_c_edit.text().strip()), + ) + to_value, changed = _coerce_r_range_maximum_for_box( + float(self.to_edit.text().strip()), + box_dimensions, + ) + if changed: + self.to_edit.setText(f"{to_value:g}") + self.status_label.setText( + "r max adjusted to half of the minimum box dimension." 
+ ) + self._item = DebyerPDFBatchItem( + item_id=self._item.item_id, + project_dir=project_dir, + frames_dir=frames_dir, + filename_prefix=self.filename_prefix_edit.text().strip() + or "debyer_pdf", + mode=self.mode_combo.currentText(), + from_value=float(self.from_edit.text().strip()), + to_value=to_value, + step_value=float(self.step_edit.text().strip()), + box_dimensions=box_dimensions, + atom_count=int(float(self.atom_count_edit.text().strip())), + store_frame_outputs=self.store_frame_outputs_checkbox.isChecked(), + solute_elements=_normalize_solute_text( + self.solute_elements_edit.text() + ), + max_parallel_jobs=int(self.parallel_jobs_spin.value()), + ) + self._refresh_header() + self._refresh_project_reference() + self._refresh_rho0_label() + return self._item + + def settings(self) -> DebyerPDFSettings: + return self.collect_item().to_settings() + + def existing_partials_job(self) -> DebyerPDFExistingPartialsJob: + frames_dir = _optional_path(self.frames_dir_edit.text()) + project_dir = _optional_path(self.project_dir_edit.text()) + if project_dir is None and frames_dir is not None: + project_dir = _suggest_project_dir(frames_dir) + self.project_dir_edit.setText(str(project_dir)) + if project_dir is None: + raise ValueError( + "Select a project folder before appending grouped partials." 
def inspect_frames(self) -> None:
    """Inspect the frames folder and pre-fill settings that are still unset.

    The folder named in ``frames_dir_edit`` is passed to
    ``inspect_frames_dir``. The result is used to:

    * normalise the frames folder text to the inspected path;
    * replace the output prefix when it is blank or still ``debyer_pdf``;
    * suggest a project folder when none is entered;
    * fill the atom count when the field is blank, unparsable, or <= 0;
    * fill box fields through ``_set_box_if_blank_or_zero`` and pass r max
      through ``_coerce_r_range_maximum_for_box`` when a box is known;
    * fill the solute elements when the field is blank and defaults were
      inferred;
    * write a human-readable summary into ``inspection_summary_label``.

    Text that does not parse as a number is treated like an empty field
    instead of aborting the inspection, which matches how
    ``_set_box_if_blank_or_zero`` treats the box fields.

    Exceptions raised by ``_required_path`` and ``inspect_frames_dir``
    propagate to the caller.
    """
    frames_dir = _required_path(
        self.frames_dir_edit.text(),
        "XYZ frames folder",
    )
    inspection = inspect_frames_dir(frames_dir)
    self.frames_dir_edit.setText(str(inspection.frames_dir))
    if self.filename_prefix_edit.text().strip() in {"", "debyer_pdf"}:
        self.filename_prefix_edit.setText(inspection.frames_dir.name)
    if not self.project_dir_edit.text().strip():
        self.project_dir_edit.setText(
            str(_suggest_project_dir(inspection.frames_dir))
        )

    # An atom count that cannot be parsed is as useless as a blank one, so
    # it is replaced by the detected value rather than raising.
    atom_text = self.atom_count_edit.text().strip()
    try:
        current_atom_count = int(float(atom_text)) if atom_text else 0
    except (ValueError, OverflowError):
        # OverflowError covers int(float("inf")).
        current_atom_count = 0
    if current_atom_count <= 0:
        self.atom_count_edit.setText(str(inspection.atom_count))

    # Prefer a box read from the frames; fall back to the estimate.
    detected_box = (
        inspection.detected_box_dimensions
        if inspection.detected_box_dimensions is not None
        else inspection.estimated_box_dimensions
    )
    if detected_box is not None:
        self._set_box_if_blank_or_zero(detected_box)
        to_text = self.to_edit.text().strip()
        try:
            current_to = float(to_text) if to_text else 0.0
        except ValueError:
            # Same fallback the original applied to a blank field.
            current_to = 0.0
        to_value, changed = _coerce_r_range_maximum_for_box(
            current_to,
            detected_box,
        )
        if changed:
            self.to_edit.setText(f"{to_value:g}")

    inferred_solutes = infer_default_solute_elements(
        inspection.element_counts
    )
    if inferred_solutes and not self.solute_elements_edit.text().strip():
        self.solute_elements_edit.setText(_solute_text(inferred_solutes))

    # A count of 1 is omitted from the element summary.
    element_summary = ", ".join(
        f"{element}{count if count != 1 else ''}"
        for element, count in sorted(inspection.element_counts.items())
    )
    solute_summary = (
        _solute_text(inferred_solutes)
        if inferred_solutes
        else "not inferred"
    )
    box_summary = "unknown"
    if inspection.detected_box_dimensions is not None:
        box_summary = _box_text(inspection.detected_box_dimensions)
        if inspection.detected_box_source is not None:
            box_summary += f" from {inspection.detected_box_source}"
    elif inspection.estimated_box_dimensions is not None:
        box_summary = (
            _box_text(inspection.estimated_box_dimensions)
            + " estimated from first frame"
        )
    self.inspection_summary_label.setText(
        f"Detected {inspection.frame_format.upper()} frames: "
        f"{len(inspection.frame_paths)} files\n"
        f"Elements in first frame: {element_summary or 'unknown'}\n"
        f"Default solutes: {solute_summary}\n"
        f"Bounding box: {box_summary}"
    )
    self.status_label.setText("Settings inspected")
    # Re-collect the item and notify listeners with the updated fields.
    self._on_editor_changed()
header.addWidget(self.remove_button) + root.addWidget(self.header_frame) + self.set_selected(False) + + self.progress_bar = QProgressBar() + self.progress_bar.setRange(0, 1) + self.progress_bar.setValue(0) + self.progress_bar.setFormat("%v / %m frames") + root.addWidget(self.progress_bar) + + self.settings_group = QGroupBox("Debyer Calculation Settings") + root.addWidget(self.settings_group) + form = QFormLayout(self.settings_group) + + project_row = QWidget() + project_layout = QHBoxLayout(project_row) + project_layout.setContentsMargins(0, 0, 0, 0) + self.project_dir_edit = QLineEdit() + self.project_dir_edit.editingFinished.connect(self._on_editor_changed) + project_layout.addWidget(self.project_dir_edit, stretch=1) + project_button = QPushButton("Browse...") + project_button.clicked.connect(self._choose_project_dir) + project_layout.addWidget(project_button) + form.addRow("Project folder", project_row) + + self.project_reference_label = QLabel() + self.project_reference_label.setWordWrap(True) + self.project_reference_label.setFrameShape(QFrame.Shape.StyledPanel) + form.addRow("", self.project_reference_label) + + frames_row = QWidget() + frames_layout = QHBoxLayout(frames_row) + frames_layout.setContentsMargins(0, 0, 0, 0) + self.frames_dir_edit = QLineEdit() + self.frames_dir_edit.editingFinished.connect(self._inspect_from_edit) + frames_layout.addWidget(self.frames_dir_edit, stretch=1) + self.frames_button = QPushButton("Browse...") + self.frames_button.clicked.connect(self._choose_frames_dir) + frames_layout.addWidget(self.frames_button) + self.inspect_button = QPushButton("Inspect") + self.inspect_button.clicked.connect(self._inspect_from_button) + frames_layout.addWidget(self.inspect_button) + form.addRow("XYZ frames folder", frames_row) + + self.inspection_summary_label = QLabel( + "Inspect the XYZ frames folder to detect atoms, solutes, and box." 
+ ) + self.inspection_summary_label.setWordWrap(True) + self.inspection_summary_label.setFrameShape(QFrame.Shape.StyledPanel) + form.addRow("", self.inspection_summary_label) + + self.filename_prefix_edit = QLineEdit("debyer_pdf") + self.filename_prefix_edit.editingFinished.connect( + self._on_editor_changed + ) + form.addRow("Output prefix", self.filename_prefix_edit) + + self.mode_combo = QComboBox() + for mode in SUPPORTED_DEBYER_MODES: + self.mode_combo.addItem(mode) + self.mode_combo.currentIndexChanged.connect(self._on_editor_changed) + form.addRow("Mode", self.mode_combo) + + range_widget = QWidget() + range_layout = QGridLayout(range_widget) + range_layout.setContentsMargins(0, 0, 0, 0) + self.from_edit = QLineEdit("0.5") + self.to_edit = QLineEdit("15") + self.step_edit = QLineEdit("0.01") + for widget in (self.from_edit, self.to_edit, self.step_edit): + widget.editingFinished.connect(self._on_editor_changed) + range_layout.addWidget(QLabel("from"), 0, 0) + range_layout.addWidget(self.from_edit, 0, 1) + range_layout.addWidget(QLabel("to"), 0, 2) + range_layout.addWidget(self.to_edit, 0, 3) + range_layout.addWidget(QLabel("step"), 0, 4) + range_layout.addWidget(self.step_edit, 0, 5) + form.addRow("r-range (A)", range_widget) + + box_widget = QWidget() + box_layout = QGridLayout(box_widget) + box_layout.setContentsMargins(0, 0, 0, 0) + self.box_a_edit = QLineEdit() + self.box_b_edit = QLineEdit() + self.box_c_edit = QLineEdit() + for widget in (self.box_a_edit, self.box_b_edit, self.box_c_edit): + widget.editingFinished.connect(self._on_editor_changed) + box_layout.addWidget(QLabel("a"), 0, 0) + box_layout.addWidget(self.box_a_edit, 0, 1) + box_layout.addWidget(QLabel("b"), 0, 2) + box_layout.addWidget(self.box_b_edit, 0, 3) + box_layout.addWidget(QLabel("c"), 0, 4) + box_layout.addWidget(self.box_c_edit, 0, 5) + form.addRow("Bounding box (A)", box_widget) + + self.atom_count_edit = QLineEdit() + 
def _load_item(self, item: DebyerPDFBatchItem) -> None:
    """Copy the fields of ``item`` into the editor widgets.

    ``_loading`` is set while the widgets are updated so that
    ``_on_editor_changed`` (which returns early while the flag is set)
    ignores the change signals caused by these programmatic updates.
    Box dimensions and atom counts that are <= 0 are shown as empty fields.
    The header, project reference and rho0 labels are refreshed afterwards.
    """
    self._loading = True
    try:
        self.project_dir_edit.setText(
            "" if item.project_dir is None else str(item.project_dir)
        )
        self.frames_dir_edit.setText(
            "" if item.frames_dir is None else str(item.frames_dir)
        )
        self.filename_prefix_edit.setText(item.filename_prefix)
        self.mode_combo.setCurrentText(item.mode)
        self.from_edit.setText(f"{item.from_value:g}")
        self.to_edit.setText(f"{item.to_value:g}")
        self.step_edit.setText(f"{item.step_value:g}")
        box_edits = (self.box_a_edit, self.box_b_edit, self.box_c_edit)
        for axis, box_edit in enumerate(box_edits):
            dimension = item.box_dimensions[axis]
            box_edit.setText("" if dimension <= 0.0 else f"{dimension:g}")
        self.atom_count_edit.setText(
            "" if item.atom_count <= 0 else str(item.atom_count)
        )
        self.solute_elements_edit.setText(
            _solute_text(item.solute_elements)
        )
        self.store_frame_outputs_checkbox.setChecked(
            item.store_frame_outputs
        )
        self.parallel_jobs_spin.setValue(item.max_parallel_jobs)
    finally:
        # Always clear the guard. If a setter or format call raised and the
        # flag stayed set, every later edit would be silently ignored.
        self._loading = False
    self._refresh_header()
    self._refresh_project_reference()
    self._refresh_rho0_label()
QFileDialog.getExistingDirectory( + self, + "Select XYZ frames folder", + self.frames_dir_edit.text().strip() + or self.project_dir_edit.text().strip() + or str(Path.home()), + ) + if not selected: + return + self.frames_dir_edit.setText(selected) + self._inspect_from_button() + + def _inspect_from_edit(self) -> None: + if not self.frames_dir_edit.text().strip(): + self._on_editor_changed() + return + try: + self.inspect_frames() + except Exception as exc: + self.inspection_summary_label.setText(str(exc)) + self.status_label.setText("Inspection failed") + self._on_editor_changed() + + def _inspect_from_button(self) -> None: + try: + self.inspect_frames() + except Exception as exc: + QMessageBox.warning(self, "Unable to inspect frames", str(exc)) + self.inspection_summary_label.setText(str(exc)) + self.status_label.setText("Inspection failed") + self._on_editor_changed() + + def _set_box_if_blank_or_zero( + self, + box_dimensions: tuple[float, float, float], + ) -> None: + for line_edit, value in zip( + (self.box_a_edit, self.box_b_edit, self.box_c_edit), + box_dimensions, + ): + text = line_edit.text().strip() + try: + current = float(text) if text else 0.0 + except ValueError: + current = 0.0 + if current <= 0.0: + line_edit.setText(f"{float(value):g}") + + def _on_editor_changed(self) -> None: + if self._loading: + return + try: + self.collect_item() + self.status_label.setText("Ready") + except Exception: + self._refresh_project_reference() + self._refresh_header() + self._refresh_rho0_label() + self.settings_changed.emit(self.item_id) + + def _refresh_header(self) -> None: + self.title_label.setText(self._item.display_name()) + + def _refresh_project_reference(self) -> None: + project_dir = _optional_path(self.project_dir_edit.text()) + self.project_reference_label.setText( + _project_reference_text(project_dir) + ) + + def _refresh_rho0_label(self) -> None: + try: + atom_count = int(float(self.atom_count_edit.text().strip())) + box = ( + 
class DebyerPDFBatchWorker(QObject):
    """Run queued Debyer PDF calculations one after another.

    Each queue entry is an ``(item_id, settings)`` pair. Progress, log and
    status updates are reported through signals. ``finished`` carries the
    list of results collected so far, and ``failed`` is emitted instead when
    an entry raises.
    NOTE(review): presumably driven from a QThread owned by the queue
    window — confirm against the caller.
    """

    item_started = Signal(str, int, int)
    item_progress = Signal(str, int, int, str)
    item_finished = Signal(str, object)
    item_failed = Signal(str, str)
    log = Signal(str)
    status = Signal(str)
    finished = Signal(object)
    failed = Signal(str, str)

    def __init__(
        self,
        queue_entries: list[tuple[str, DebyerPDFSettings]],
        *,
        debyer_executable: str | Path | None = None,
    ) -> None:
        """Snapshot the queue and remember the optional executable."""
        super().__init__()
        # Copy so later changes to the caller's list do not affect the run.
        self.queue_entries = list(queue_entries)
        self.debyer_executable = debyer_executable
        self._cancel_requested = threading.Event()

    def request_cancel(self) -> None:
        """Ask the run loop to stop; checked between and after entries."""
        self._cancel_requested.set()

    def _progress_reporter(self, entry_id: str):
        """Return a progress callback bound to ``entry_id``."""

        def report(processed, total, message):
            return self.item_progress.emit(
                entry_id,
                processed,
                total,
                message,
            )

        return report

    def _log_reporter(self, tag: str):
        """Return a log callback that prefixes messages with ``[tag]``."""

        def report(message):
            return self.log.emit(f"[{tag}] {message}")

        return report

    def _status_reporter(self, tag: str):
        """Return a status callback that prefixes messages with ``tag:``."""

        def report(message):
            return self.status.emit(f"{tag}: {message}")

        return report

    @Slot()
    def run(self) -> None:
        """Process the queue in order until done, cancelled, or failed.

        An exception from one entry emits ``item_failed`` and ``failed`` and
        ends the run without emitting ``finished``. A partial-average result
        or a cancel request stops the loop after the current entry.
        """
        collected: list[DebyerPDFCalculation] = []
        entry_count = len(self.queue_entries)
        for position, (entry_id, entry_settings) in enumerate(
            self.queue_entries,
            start=1,
        ):
            if self._cancel_requested.is_set():
                self.log.emit("Batch queue stopped before the next project.")
                break
            # Prefer the output prefix as the tag; fall back to folder name.
            tag = (
                entry_settings.filename_prefix
                or entry_settings.project_dir.name
            )
            self.item_started.emit(entry_id, position, entry_count)
            self.status.emit(
                f"Running {position}/{entry_count}: "
                f"{entry_settings.project_dir.name}"
            )
            self.log.emit(
                f"Starting {position}/{entry_count}: "
                f"{entry_settings.project_dir}"
            )
            try:
                outcome = DebyerPDFWorkflow(
                    entry_settings,
                    debyer_executable=self.debyer_executable,
                ).run(
                    progress_callback=self._progress_reporter(entry_id),
                    log_callback=self._log_reporter(tag),
                    status_callback=self._status_reporter(tag),
                    cancel_callback=self._cancel_requested.is_set,
                )
            except Exception as exc:
                reason = str(exc)
                self.item_failed.emit(entry_id, reason)
                self.failed.emit(entry_id, reason)
                return
            collected.append(outcome)
            self.item_finished.emit(entry_id, outcome)
            if outcome.is_partial_average or self._cancel_requested.is_set():
                self.log.emit("Batch queue stopped after saving current work.")
                break
        self.status.emit("PDF batch queue finished")
        self.finished.emit(collected)
def _update_project(
    self,
    item_id: str,
    job: DebyerPDFExistingPartialsJob,
) -> list[DebyerPDFCalculation]:
    """Rewrite every saved Debyer calculation of one project.

    Each saved calculation is loaded and given the job's solute elements,
    falling back to the ones stored with the calculation. Its peak markers
    are reset to an empty dict, then its output and metadata are written
    again. Progress is reported per calculation through ``item_progress``.

    Returns the calculations updated before completion or cancellation.

    Raises:
        ValueError: if the project folder does not exist, holds no saved
            calculations, or a calculation has no solute elements from
            either source.
    """
    root_dir = Path(job.project_dir).expanduser().resolve()
    if not root_dir.is_dir():
        raise ValueError(f"The project folder does not exist: {root_dir}")
    saved = list_saved_debyer_calculations(root_dir)
    if not saved:
        raise ValueError(
            f"No saved Debyer calculations were found in {root_dir}."
        )
    self.log.emit(
        f"Updating {len(saved)} saved Debyer calculation(s) in {root_dir}"
    )
    rewritten: list[DebyerPDFCalculation] = []
    saved_count = len(saved)
    for ordinal, entry in enumerate(saved, start=1):
        if self._cancel_requested.is_set():
            self.log.emit(
                f"Stopped before updating {entry.calculation_dir.name}."
            )
            break
        loaded = load_debyer_calculation(entry.calculation_dir)
        # The job's solutes override the ones stored with the calculation.
        solutes = job.solute_elements or loaded.solute_elements
        if not solutes:
            raise ValueError(
                "Solute elements are required to append grouped partial "
                f"columns for {loaded.calculation_dir}."
            )
        refreshed = replace(
            loaded,
            solute_elements=solutes,
            target_peak_markers={},
        )
        rewrite_debyer_calculation_output(refreshed)
        write_debyer_calculation_metadata(refreshed)
        rewritten.append(refreshed)
        progress_text = (
            f"Updated {ordinal}/{saved_count}: {entry.filename_prefix}"
        )
        self.item_progress.emit(item_id, ordinal, saved_count, progress_text)
        self.log.emit(
            f"Appended grouped columns to {refreshed.averaged_output_file}"
        )
    return rewritten
def add_queue_item(
    self,
    item: DebyerPDFBatchItem | None = None,
    *,
    auto_inspect: bool = False,
) -> DebyerPDFBatchItemWidget:
    """Append a queue row for ``item`` and return its editor widget.

    Args:
        item: Settings to show. A fresh ``DebyerPDFBatchItem`` with a new
            id is created when ``None`` is passed.
        auto_inspect: When true, run the widget's frame inspection right
            after the row is added. An inspection error is shown in the
            widget rather than raised.

    Returns:
        The widget embedded in the new list row, which also becomes the
        current row.
    """
    # Test for None explicitly so a supplied item is never replaced just
    # because it evaluates as falsy.
    resolved_item = (
        DebyerPDFBatchItem(item_id=_new_item_id()) if item is None else item
    )
    list_item = QListWidgetItem()
    # The item id stored on the row is how rows map back to widgets.
    list_item.setData(Qt.ItemDataRole.UserRole, resolved_item.item_id)
    self.queue_list.addItem(list_item)
    widget = DebyerPDFBatchItemWidget(
        resolved_item, parent=self.queue_list
    )
    widget.settings_changed.connect(self._on_item_settings_changed)
    widget.remove_requested.connect(self._remove_item)
    widget.duplicate_requested.connect(self._duplicate_item)
    # Register before selecting the row: selection triggers a style refresh
    # that looks the widget up by id.
    self._widgets_by_id[resolved_item.item_id] = widget
    widget.set_append_grouped_mode(self._is_append_grouped_mode())
    list_item.setSizeHint(widget.sizeHint())
    self.queue_list.setItemWidget(list_item, widget)
    self.queue_list.setCurrentItem(list_item)
    self._refresh_order_labels()
    if auto_inspect:
        try:
            widget.inspect_frames()
        except Exception as exc:
            widget.inspection_summary_label.setText(str(exc))
            widget.set_status("Inspection failed")
            # inspect_frames notifies listeners only as its last step, so a
            # failure leaves the row sized for the old summary text.
            # Re-measure it here.
            self._refresh_order_labels()
    return widget
def _build_ui(self) -> None:
    """Create the queue window's widgets and wire their signals.

    Layout from top to bottom: runtime status label, the row of "add"
    buttons, the queue-mode selector with its description, the reorderable
    queue list, and the "Execute Queue" group (run/stop buttons, status
    line, read-only console). Finishes by applying the current queue mode
    to the controls.
    """
    self.setWindowTitle("SAXSShell PDF Batch Queue")
    self.setWindowIcon(load_saxshell_icon())
    self.resize(1120, 860)

    container = QWidget()
    outer = QVBoxLayout(container)
    outer.setContentsMargins(8, 8, 8, 8)
    outer.setSpacing(8)

    # Runtime status banner.
    self.runtime_status_label = QLabel("Checking Debyer runtime...")
    self.runtime_status_label.setWordWrap(True)
    self.runtime_status_label.setFrameShape(QFrame.Shape.StyledPanel)
    outer.addWidget(self.runtime_status_label)

    # Row of buttons that add queue entries.
    add_row = QHBoxLayout()
    add_button_specs = (
        (
            "add_current_button",
            "Add Current Project",
            self._add_current_project,
        ),
        (
            "add_project_button",
            "Add Projects...",
            self._choose_project_to_add,
        ),
        (
            "add_frames_button",
            "Add XYZ Frame Folders...",
            self._choose_frames_to_add,
        ),
    )
    for attribute_name, caption, handler in add_button_specs:
        button = QPushButton(caption)
        setattr(self, attribute_name, button)
        button.clicked.connect(handler)
        add_row.addWidget(button)
    add_row.addStretch(1)
    outer.addLayout(add_row)

    # Queue mode selector. Items are added before the signal is connected,
    # so populating the combo does not trigger the handler.
    selector_row = QHBoxLayout()
    selector_row.addWidget(QLabel("Queue mode"))
    self.queue_mode_combo = QComboBox()
    for caption, mode_key in (
        ("Run full Debyer calculations", "calculate"),
        ("Append grouped partial columns only", "append_grouped"),
    ):
        self.queue_mode_combo.addItem(caption, mode_key)
    self.queue_mode_combo.currentIndexChanged.connect(
        self._on_queue_mode_changed
    )
    selector_row.addWidget(self.queue_mode_combo)
    self.queue_mode_status_label = QLabel(
        "Runs Debyer for each queue item in order."
    )
    self.queue_mode_status_label.setWordWrap(True)
    selector_row.addWidget(self.queue_mode_status_label, stretch=1)
    outer.addLayout(selector_row)

    # Queue list: single selection, rows reorderable by drag and drop.
    self.queue_list = QListWidget()
    self.queue_list.setSelectionMode(
        QAbstractItemView.SelectionMode.SingleSelection
    )
    self.queue_list.setDragDropMode(
        QAbstractItemView.DragDropMode.InternalMove
    )
    self.queue_list.setDefaultDropAction(Qt.DropAction.MoveAction)
    self.queue_list.setAlternatingRowColors(True)
    list_rules = (
        "QListWidget::item:selected { background: transparent; }",
        "QListWidget::item:hover { background: transparent; }",
        "QListWidget::item { margin: 3px; }",
    )
    self.queue_list.setStyleSheet("".join(list_rules))
    self.queue_list.model().rowsMoved.connect(self._refresh_order_labels)
    self.queue_list.itemSelectionChanged.connect(
        self._refresh_item_selection_styles
    )
    outer.addWidget(self.queue_list, stretch=1)

    # Execution controls, status line and console.
    execute_group = QGroupBox("Execute Queue")
    execute_layout = QVBoxLayout(execute_group)
    button_row = QHBoxLayout()
    self.run_button = QPushButton("Run Complete Queue")
    self.run_button.clicked.connect(self._start_queue)
    button_row.addWidget(self.run_button)
    self.cancel_button = QPushButton("Stop Queue")
    self.cancel_button.setEnabled(False)
    self.cancel_button.clicked.connect(self._request_cancel)
    button_row.addWidget(self.cancel_button)
    button_row.addStretch(1)
    execute_layout.addLayout(button_row)
    self.queue_status_label = QLabel("Queue idle")
    execute_layout.addWidget(self.queue_status_label)
    self.console = QTextEdit()
    self.console.setReadOnly(True)
    self.console.setMinimumHeight(160)
    execute_layout.addWidget(self.console)
    outer.addWidget(execute_group)

    self.setCentralWidget(container)
    self.statusBar().showMessage("Ready")
    # Bring button text and descriptions in line with the initial mode.
    self._on_queue_mode_changed()
def _on_queue_mode_changed(self, *_args) -> None: + append_mode = self._is_append_grouped_mode() + running = self._run_thread is not None and self._run_thread.isRunning() + self.run_button.setText( + "Append Grouped Partial Columns" + if append_mode + else "Run Complete Queue" + ) + self.queue_mode_status_label.setText( + "Updates existing saved Debyer calculations in each project. " + "Debyer is not launched; project folder and solute elements are " + "used." + if append_mode + else "Runs Debyer for each queue item in order." + ) + self.add_frames_button.setEnabled(not append_mode and not running) + for widget in self._widgets_by_id.values(): + widget.set_append_grouped_mode(append_mode) + + def _add_current_project(self) -> None: + if ( + self._initial_project_dir is None + and self._initial_frames_dir is None + ): + QMessageBox.information( + self, + "No active project", + "The main UI did not provide an active project reference.", + ) + return + self.add_queue_item( + ( + _queue_item_from_project_defaults( + self._initial_project_dir, + frames_dir_override=self._initial_frames_dir, + ) + if self._initial_project_dir is not None + else DebyerPDFBatchItem( + item_id=_new_item_id(), + frames_dir=self._initial_frames_dir, + filename_prefix=( + self._initial_frames_dir.name + if self._initial_frames_dir is not None + else "debyer_pdf" + ), + ) + ), + auto_inspect=self._initial_frames_dir is not None, + ) + + def _choose_project_to_add(self) -> None: + selected_dirs = _choose_existing_directories( + self, + title="Select SAXSShell project folders", + start_dir=self._initial_project_dir or Path.home(), + ) + if not selected_dirs: + return + for project_dir in selected_dirs: + item = _queue_item_from_project_defaults(project_dir) + self.add_queue_item(item, auto_inspect=item.frames_dir is not None) + + def _choose_frames_to_add(self) -> None: + selected_dirs = _choose_existing_directories( + self, + title="Select XYZ frames folders", + start_dir=( + 
self._initial_frames_dir + or self._initial_project_dir + or Path.home() + ), + ) + if not selected_dirs: + return + for frames_dir in selected_dirs: + self.add_queue_item( + DebyerPDFBatchItem( + item_id=_new_item_id(), + project_dir=self._initial_project_dir, + frames_dir=frames_dir, + filename_prefix=frames_dir.name, + ), + auto_inspect=True, + ) + + def _on_item_settings_changed(self, _item_id: str) -> None: + self._refresh_order_labels() + + def _refresh_order_labels(self, *_args) -> None: + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + item_id = str(list_item.data(Qt.ItemDataRole.UserRole)) + widget = self._widgets_by_id.get(item_id) + if widget is None: + continue + widget.title_label.setText( + f"{row + 1}. {widget.item().display_name()}" + ) + list_item.setSizeHint(widget.sizeHint()) + self._refresh_item_selection_styles() + + def _refresh_item_selection_styles(self) -> None: + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + item_id = str(list_item.data(Qt.ItemDataRole.UserRole)) + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_selected(list_item.isSelected()) + + def _remove_item(self, item_id: str) -> None: + if self._run_thread is not None and self._run_thread.isRunning(): + return + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + if str(list_item.data(Qt.ItemDataRole.UserRole)) == item_id: + self.queue_list.takeItem(row) + break + self._widgets_by_id.pop(item_id, None) + self._refresh_order_labels() + + def _duplicate_item(self, item_id: str) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is None: + return + try: + item = widget.collect_item() + except Exception: + item = widget.item() + self.add_queue_item( + replace( + item, + item_id=_new_item_id(), + filename_prefix=f"{item.filename_prefix}_copy", + ) + ) + + def _set_running(self, running: bool) -> None: + 
self.add_current_button.setEnabled(not running) + self.add_project_button.setEnabled(not running) + self.add_frames_button.setEnabled( + not running and not self._is_append_grouped_mode() + ) + self.queue_mode_combo.setEnabled(not running) + self.run_button.setEnabled(not running) + self.cancel_button.setEnabled(running) + self.queue_list.setDragEnabled(not running) + self.queue_list.setAcceptDrops(not running) + for widget in self._widgets_by_id.values(): + widget.set_locked(running) + + def _start_queue(self) -> None: + if self.queue_list.count() == 0: + QMessageBox.information( + self, + "PDF batch queue", + "Add at least one project before running the queue.", + ) + return + append_mode = self._is_append_grouped_mode() + try: + entries = ( + self.existing_partials_jobs_in_order() + if append_mode + else self.queue_settings_in_order() + ) + except Exception as exc: + QMessageBox.warning( + self, + "Invalid PDF batch settings", + str(exc), + ) + return + + if not append_mode: + for _item_id, settings in entries: + settings.project_dir.mkdir(parents=True, exist_ok=True) + self.console.clear() + self._set_running(True) + self.queue_status_label.setText( + ( + f"Updating 0/{len(entries)} queued project(s)" + if append_mode + else f"Running 0/{len(entries)} queued calculations" + ) + ) + for widget in self._widgets_by_id.values(): + widget.set_progress(0, 1) + widget.set_status("Queued") + + self._run_thread = QThread(self) + self._run_worker = ( + DebyerPDFExistingPartialsWorker(entries) + if append_mode + else DebyerPDFBatchWorker(entries) + ) + self._run_worker.moveToThread(self._run_thread) + self._run_thread.started.connect(self._run_worker.run) + self._run_worker.item_started.connect(self._on_item_started) + self._run_worker.item_progress.connect(self._on_item_progress) + self._run_worker.item_finished.connect(self._on_item_finished) + self._run_worker.item_failed.connect(self._on_item_failed) + self._run_worker.log.connect(self._append_log) + 
self._run_worker.status.connect(self._on_status) + self._run_worker.finished.connect(self._on_queue_finished) + self._run_worker.failed.connect(self._on_queue_failed) + self._run_worker.finished.connect(self._run_thread.quit) + self._run_worker.failed.connect(self._run_thread.quit) + self._run_thread.finished.connect(self._cleanup_run_thread) + self._run_thread.finished.connect(self._run_thread.deleteLater) + self._run_thread.start() + + def _request_cancel(self) -> None: + self.cancel_button.setEnabled(False) + self.queue_status_label.setText( + "Stopping queue after the active project finishes" + ) + self._append_log( + "Stop requested; the current project will finish before the " + "queue exits." + ) + if self._run_worker is not None: + self._run_worker.request_cancel() + + def _append_log(self, message: str) -> None: + self.console.append(message) + + def _on_status(self, message: str) -> None: + self.statusBar().showMessage(message) + self.queue_status_label.setText(message) + + def _on_item_started( + self, + item_id: str, + index: int, + total: int, + ) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_status( + f"Updating {index}/{total}" + if self._is_append_grouped_mode() + else f"Running {index}/{total}" + ) + widget.set_progress(0, 1) + self.queue_status_label.setText( + ( + f"Updating {index}/{total} queued project(s)" + if self._is_append_grouped_mode() + else f"Running {index}/{total} queued calculations" + ) + ) + + def _on_item_progress( + self, + item_id: str, + processed: int, + total: int, + message: str, + ) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_progress(processed, total) + widget.set_status(message) + + def _on_item_finished(self, item_id: str, result: object) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is None: + return + if isinstance(result, DebyerPDFCalculation): + processed = ( + result.frame_count + if 
result.processed_frame_count is None + else result.processed_frame_count + ) + widget.set_progress(processed, result.frame_count) + widget.set_status( + "Stopped early" if result.is_partial_average else "Complete" + ) + elif isinstance(result, list): + widget.set_progress(len(result), max(len(result), 1)) + widget.set_status(f"Updated {len(result)} calculation(s)") + else: + widget.set_status("Complete") + + def _on_item_failed(self, item_id: str, message: str) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_status("Failed") + self._append_log(message) + + def _on_queue_finished(self, results: object) -> None: + self._set_running(False) + result_count = len(results) if isinstance(results, list) else 0 + self.queue_status_label.setText( + ( + f"Queue finished: {result_count} calculation(s) updated" + if self._is_append_grouped_mode() + else f"Queue finished: {result_count} calculation(s) saved" + ) + ) + self.statusBar().showMessage("PDF batch queue finished") + + def _on_queue_failed(self, item_id: str, message: str) -> None: + self._set_running(False) + self.queue_status_label.setText("Queue stopped after a failure") + self.statusBar().showMessage("PDF batch queue failed", 5000) + QMessageBox.warning( + self, + "PDF batch queue failed", + f"Queue item {item_id} failed:\n{message}", + ) + + def _cleanup_run_thread(self) -> None: + self._run_thread = None + self._run_worker = None + + +def launch_debyer_pdf_batch_queue_ui( + initial_project_dir: str | Path | None = None, + *, + initial_frames_dir: str | Path | None = None, +) -> int: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication([]) + configure_saxshell_application(app) + window = DebyerPDFBatchQueueWindow( + initial_project_dir=initial_project_dir, + initial_frames_dir=initial_frames_dir, + ) + window.show() + return int(app.exec()) + + +__all__ = [ + "DebyerPDFExistingPartialsJob", + 
"DebyerPDFExistingPartialsWorker", + "DebyerPDFBatchItem", + "DebyerPDFBatchItemWidget", + "DebyerPDFBatchQueueWindow", + "DebyerPDFBatchWorker", + "launch_debyer_pdf_batch_queue_ui", +] diff --git a/src/saxshell/pdf/debyer/ui/main_window.py b/src/saxshell/pdf/debyer/ui/main_window.py index 6f72277..4843b35 100644 --- a/src/saxshell/pdf/debyer/ui/main_window.py +++ b/src/saxshell/pdf/debyer/ui/main_window.py @@ -2,6 +2,7 @@ import argparse import sys +import threading from dataclasses import replace from pathlib import Path @@ -63,21 +64,70 @@ build_display_traces, check_debyer_runtime, classify_partial_pair, + compute_experimental_fit_metrics, + convert_distribution_values, + default_parallel_debyer_jobs, estimate_partial_peak_markers, find_partial_peak_markers, + fit_coordination_peak_from_r, + infer_default_solute_elements, inspect_frames_dir, list_saved_debyer_calculations, load_debyer_calculation, + rewrite_debyer_calculation_output, write_debyer_calculation_metadata, ) +from saxshell.saxs.project_manager import ( + ExperimentalDataSummary, + load_experimental_data_file, +) from saxshell.saxs.ui.branding import ( configure_saxshell_application, load_saxshell_icon, prepare_saxshell_application_identity, ) +from saxshell.saxs.ui.experimental_data_loader import ( + ExperimentalDataHeaderDialog, +) _OPEN_WINDOWS: list["DebyerPDFMainWindow"] = [] +_GROUP_TRACE_DEFAULT_COLORS = { + "group:solute-solute": "#cc79a7", + "group:solute-solvent": "#e69f00", + "group:solvent-solvent": "#009e73", +} + +_SPLITTER_HANDLE_STYLE = """ +QSplitter::handle { + background-color: #c8d1de; + border: 1px solid #9aa8ba; + border-radius: 2px; +} +QSplitter::handle:hover { + background-color: #9fb2ca; + border-color: #6f83a0; +} +QSplitter::handle:pressed { + background-color: #8299b8; + border-color: #536b8c; +} +""" + + +def _configure_resize_splitter( + splitter: QSplitter, + *, + handle_width: int, + tooltip: str, +) -> None: + splitter.setChildrenCollapsible(False) + 
splitter.setHandleWidth(handle_width) + splitter.setOpaqueResize(True) + splitter.setStyleSheet(_SPLITTER_HANDLE_STYLE) + for index in range(1, splitter.count()): + splitter.handle(index).setToolTip(tooltip) + class DebyerPDFWorker(QObject): log = Signal(str) @@ -87,9 +137,30 @@ class DebyerPDFWorker(QObject): finished = Signal(object) failed = Signal(str) - def __init__(self, settings: DebyerPDFSettings) -> None: + def __init__( + self, + settings: DebyerPDFSettings, + *, + preview_enabled: bool = True, + ) -> None: super().__init__() self.settings = settings + self._cancel_requested = threading.Event() + self._preview_enabled = threading.Event() + self._preview_update_requested = threading.Event() + if preview_enabled: + self._preview_enabled.set() + + def request_cancel(self) -> None: + self._cancel_requested.set() + + def set_preview_enabled(self, enabled: bool) -> None: + if enabled: + self._preview_enabled.set() + self._preview_update_requested.set() + else: + self._preview_enabled.clear() + self._preview_update_requested.clear() @Slot() def run(self) -> None: @@ -99,7 +170,9 @@ def run(self) -> None: progress_callback=self._emit_progress, log_callback=self.log.emit, status_callback=self.status.emit, - preview_callback=self.preview.emit, + preview_callback=self._emit_preview, + preview_decision_callback=self._should_emit_preview, + cancel_callback=self._cancel_requested.is_set, ) except Exception as exc: self.failed.emit(str(exc)) @@ -114,6 +187,20 @@ def _emit_progress( ) -> None: self.progress.emit(processed, total, message) + def _should_emit_preview( + self, + _processed: int, + _total: int, + checkpoint_due: bool, + ) -> bool: + if self._preview_update_requested.is_set(): + return True + return bool(checkpoint_due and self._preview_enabled.is_set()) + + def _emit_preview(self, calculation: DebyerPDFCalculation) -> None: + self._preview_update_requested.clear() + self.preview.emit(calculation) + class DebyerPeakEditorDialog(QDialog): def __init__( @@ 
-295,16 +382,20 @@ def __init__( super().__init__(parent) self._run_thread: QThread | None = None self._run_worker: DebyerPDFWorker | None = None + self._latest_run_preview: DebyerPDFCalculation | None = None self._loaded_summaries: list[DebyerPDFCalculationSummary] = [] self._current_calculation: DebyerPDFCalculation | None = None self._current_traces: list[dict[str, object]] = [] self._trace_visibility: dict[str, bool] = {} self._trace_tag_visibility: dict[str, bool] = {} self._trace_colors: dict[str, str] = {} + self._close_requested_during_run = False self._tag_artist_records: list[dict[str, object]] = [] self._drag_state: dict[str, object] | None = None self._selected_tag: dict[str, object] | None = None + self._experimental_summary: ExperimentalDataSummary | None = None self._build_ui() + self._refresh_experimental_controls() self._delete_tag_shortcut = QShortcut( QKeySequence(Qt.Key.Key_Delete), self, @@ -331,13 +422,18 @@ def __init__( def closeEvent(self, event) -> None: if self._run_thread is not None and self._run_thread.isRunning(): - QMessageBox.warning( - self, - "Debyer PDF", - "Please wait for the current Debyer PDF calculation to " - "finish before closing this window.", + self._request_run_cancel( + "Closing window; stopping Debyer after active frame jobs " + "finish and saving the partial average." 
) - event.ignore() + self.hide() + while ( + self._run_thread is not None and self._run_thread.isRunning() + ): + QApplication.processEvents() + if self._run_thread is not None: + self._run_thread.wait(50) + event.accept() return super().closeEvent(event) @@ -350,11 +446,16 @@ def _build_ui(self) -> None: root = QHBoxLayout(central) root.setContentsMargins(8, 8, 8, 8) - splitter = QSplitter(Qt.Orientation.Horizontal) - splitter.addWidget(self._build_left_panel()) - splitter.addWidget(self._build_right_panel()) - splitter.setSizes([460, 980]) - root.addWidget(splitter) + self._main_splitter = QSplitter(Qt.Orientation.Horizontal) + self._main_splitter.addWidget(self._build_left_panel()) + self._main_splitter.addWidget(self._build_right_panel()) + _configure_resize_splitter( + self._main_splitter, + handle_width=14, + tooltip="Drag to resize the setup and results panes.", + ) + self._main_splitter.setSizes([460, 980]) + root.addWidget(self._main_splitter) self.setCentralWidget(central) self.statusBar().showMessage("Ready") @@ -366,6 +467,7 @@ def _build_left_panel(self) -> QWidget: layout.addWidget(self._build_runtime_group()) layout.addWidget(self._build_paths_group()) + layout.addWidget(self._build_experimental_group()) layout.addWidget(self._build_saved_calculations_group()) layout.addWidget(self._build_settings_group()) layout.addWidget(self._build_run_group()) @@ -383,10 +485,15 @@ def _build_right_panel(self) -> QWidget: layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(8) - tabs = QTabWidget() - tabs.addTab(self._build_results_tab(), "Results") - tabs.addTab(self._build_plot_settings_tab(), "Settings") - layout.addWidget(tabs, stretch=1) + self.result_tabs = QTabWidget() + self.result_tabs.addTab(self._build_results_tab(), "Results") + self.result_tabs.addTab( + self._build_shape_function_tab(), + "Shape Function Analysis", + ) + self.result_tabs.addTab(self._build_fit_tab(), "Fit") + self.result_tabs.addTab(self._build_plot_settings_tab(), "Settings") + 
layout.addWidget(self.result_tabs, stretch=1) return panel def _build_results_tab(self) -> QWidget: @@ -461,10 +568,128 @@ def _build_results_tab(self) -> QWidget: ) table_layout.addWidget(self.trace_table, stretch=1) right_splitter.addWidget(table_container) + _configure_resize_splitter( + right_splitter, + handle_width=12, + tooltip="Drag to resize the plot and trace table.", + ) right_splitter.setSizes([620, 260]) layout.addWidget(right_splitter, stretch=1) return tab + def _build_shape_function_tab(self) -> QWidget: + tab = QWidget() + layout = QVBoxLayout(tab) + layout.setContentsMargins(8, 8, 8, 8) + self.shape_function_status_label = QLabel( + "Shape-function analysis setup will be added here." + ) + self.shape_function_status_label.setWordWrap(True) + self.shape_function_status_label.setFrameShape( + QFrame.Shape.StyledPanel + ) + layout.addWidget(self.shape_function_status_label) + layout.addStretch(1) + return tab + + def _build_fit_tab(self) -> QWidget: + tab = QWidget() + layout = QVBoxLayout(tab) + layout.setContentsMargins(8, 8, 8, 8) + layout.setSpacing(8) + + form = QFormLayout() + self.coordination_fit_trace_combo = QComboBox() + self.coordination_fit_trace_combo.currentIndexChanged.connect( + self._suggest_coordination_fit_window + ) + form.addRow("R(r) trace", self.coordination_fit_trace_combo) + + window_widget = QWidget() + window_layout = QHBoxLayout(window_widget) + window_layout.setContentsMargins(0, 0, 0, 0) + self.coordination_fit_r_min_spin = QDoubleSpinBox() + self.coordination_fit_r_min_spin.setRange(0.0, 100000.0) + self.coordination_fit_r_min_spin.setDecimals(4) + self.coordination_fit_r_min_spin.setSingleStep(0.05) + self.coordination_fit_r_min_spin.setValue(1.0) + self.coordination_fit_r_max_spin = QDoubleSpinBox() + self.coordination_fit_r_max_spin.setRange(0.0, 100000.0) + self.coordination_fit_r_max_spin.setDecimals(4) + self.coordination_fit_r_max_spin.setSingleStep(0.05) + self.coordination_fit_r_max_spin.setValue(4.0) + 
window_layout.addWidget(QLabel("from")) + window_layout.addWidget(self.coordination_fit_r_min_spin) + window_layout.addWidget(QLabel("to")) + window_layout.addWidget(self.coordination_fit_r_max_spin) + window_layout.addStretch(1) + form.addRow("Fit window (A)", window_widget) + + seed_widget = QWidget() + seed_layout = QHBoxLayout(seed_widget) + seed_layout.setContentsMargins(0, 0, 0, 0) + self.coordination_fit_center_spin = QDoubleSpinBox() + self.coordination_fit_center_spin.setRange(0.0, 100000.0) + self.coordination_fit_center_spin.setDecimals(4) + self.coordination_fit_center_spin.setSingleStep(0.05) + self.coordination_fit_center_spin.setValue(2.5) + self.coordination_fit_sigma_spin = QDoubleSpinBox() + self.coordination_fit_sigma_spin.setRange(0.0001, 100000.0) + self.coordination_fit_sigma_spin.setDecimals(4) + self.coordination_fit_sigma_spin.setSingleStep(0.01) + self.coordination_fit_sigma_spin.setValue(0.2) + seed_layout.addWidget(QLabel("center")) + seed_layout.addWidget(self.coordination_fit_center_spin) + seed_layout.addWidget(QLabel("sigma")) + seed_layout.addWidget(self.coordination_fit_sigma_spin) + seed_layout.addStretch(1) + form.addRow("Initial peak", seed_widget) + layout.addLayout(form) + + button_row = QHBoxLayout() + self.coordination_fit_button = QPushButton("Fit R(r) Peak") + self.coordination_fit_button.clicked.connect( + self._fit_coordination_number + ) + button_row.addWidget(self.coordination_fit_button) + button_row.addStretch(1) + layout.addLayout(button_row) + + self.coordination_fit_status_label = QLabel( + "Load or calculate a Debyer result before fitting R(r)." 
+ ) + self.coordination_fit_status_label.setWordWrap(True) + self.coordination_fit_status_label.setFrameShape( + QFrame.Shape.StyledPanel + ) + layout.addWidget(self.coordination_fit_status_label) + + self.coordination_fit_results_table = QTableWidget(0, 9) + self.coordination_fit_results_table.setHorizontalHeaderLabels( + [ + "Trace", + "r min", + "r max", + "Center", + "Sigma", + "CN", + "Amplitude", + "R^2", + "RMSE", + ] + ) + self.coordination_fit_results_table.verticalHeader().setVisible(False) + self.coordination_fit_results_table.horizontalHeader().setSectionResizeMode( + 0, QHeaderView.ResizeMode.Stretch + ) + for column in range(1, 9): + self.coordination_fit_results_table.horizontalHeader().setSectionResizeMode( + column, + QHeaderView.ResizeMode.ResizeToContents, + ) + layout.addWidget(self.coordination_fit_results_table, stretch=1) + return tab + def _build_plot_settings_tab(self) -> QWidget: tab = QWidget() layout = QVBoxLayout(tab) @@ -612,13 +837,53 @@ def _build_paths_group(self) -> QGroupBox: layout.addRow("Frames folder", frames_row) self.frames_summary_label = QLabel( - "Select a trajectory frame folder containing only .xyz or only .pdb files." + "Select a trajectory frame folder containing .xyz files." 
) self.frames_summary_label.setWordWrap(True) self.frames_summary_label.setFrameShape(QFrame.Shape.StyledPanel) layout.addRow("", self.frames_summary_label) return group + def _build_experimental_group(self) -> QGroupBox: + group = QGroupBox("Experimental g(r)") + layout = QFormLayout(group) + + file_row = QWidget() + file_layout = QHBoxLayout(file_row) + file_layout.setContentsMargins(0, 0, 0, 0) + self.experimental_file_edit = QLineEdit() + self.experimental_file_edit.editingFinished.connect( + self._load_experimental_path_from_edit + ) + file_layout.addWidget(self.experimental_file_edit, stretch=1) + browse_button = QPushButton("Browse...") + browse_button.clicked.connect(self._choose_experimental_file) + file_layout.addWidget(browse_button) + layout.addRow("Data file", file_row) + + button_row = QHBoxLayout() + self.experimental_columns_button = QPushButton("Columns...") + self.experimental_columns_button.clicked.connect( + self._configure_experimental_columns + ) + button_row.addWidget(self.experimental_columns_button) + self.clear_experimental_button = QPushButton("Clear") + self.clear_experimental_button.clicked.connect( + self._clear_experimental_file + ) + button_row.addWidget(self.clear_experimental_button) + button_row.addStretch(1) + layout.addRow("", button_row) + + self.experimental_status_label = QLabel( + "Optional: load an experimental file with r(A) and g(r) columns." 
+ ) + self.experimental_status_label.setWordWrap(True) + self.experimental_status_label.setFrameShape(QFrame.Shape.StyledPanel) + layout.addRow("", self.experimental_status_label) + self._refresh_experimental_controls() + return group + def _build_saved_calculations_group(self) -> QGroupBox: group = QGroupBox("Saved Calculations") layout = QVBoxLayout(group) @@ -694,19 +959,42 @@ def _build_settings_group(self) -> QGroupBox: ) layout.addRow("Solute elements", self.solute_elements_edit) + self.apply_solute_groups_button = QPushButton("Apply Solute Groups") + self.apply_solute_groups_button.clicked.connect( + self._apply_solute_groups_from_ui + ) + self.apply_solute_groups_button.setToolTip( + "Update the loaded Debyer result with these solute elements and " + "rebuild grouped partial traces without rerunning Debyer." + ) + layout.addRow("", self.apply_solute_groups_button) + self.store_frame_outputs_checkbox = QCheckBox( "Store per-frame Debyer output files" ) self.store_frame_outputs_checkbox.setChecked(False) layout.addRow("", self.store_frame_outputs_checkbox) + self.parallel_jobs_spin = QSpinBox() + self.parallel_jobs_spin.setRange(1, 64) + self.parallel_jobs_spin.setValue(default_parallel_debyer_jobs()) + self.parallel_jobs_spin.setToolTip( + "Run multiple independent Debyer frame calculations at the same " + "time. Use 1 for the old serial behavior." + ) + layout.addRow("Parallel Debyer jobs", self.parallel_jobs_spin) + self.update_plot_during_run_checkbox = QCheckBox( "Update plot while averaging" ) self.update_plot_during_run_checkbox.setChecked(True) self.update_plot_during_run_checkbox.setToolTip( "If enabled, the average PDF plot refreshes during the Debyer " - "run as more frame outputs are included." + "run as more frame outputs are included. You can toggle this " + "while averaging; turning it back on requests the next average." 
+ ) + self.update_plot_during_run_checkbox.toggled.connect( + self._on_update_plot_during_run_toggled ) layout.addRow("", self.update_plot_during_run_checkbox) @@ -753,11 +1041,17 @@ def _build_console_group(self) -> QGroupBox: def _build_plot_controls(self) -> QWidget: widget = QWidget() - layout = QHBoxLayout(widget) + widget.setObjectName("pdfPlotControls") + layout = QVBoxLayout(widget) layout.setContentsMargins(0, 0, 0, 0) - layout.setSpacing(8) + layout.setSpacing(4) + + selector_row = QHBoxLayout() + selector_row.setContentsMargins(0, 0, 0, 0) + selector_row.setSpacing(8) + layout.addLayout(selector_row) - layout.addWidget(QLabel("Plot")) + selector_row.addWidget(QLabel("Plot")) self.representation_combo = QComboBox() for label in SUPPORTED_PLOT_REPRESENTATIONS: self.representation_combo.addItem(label) @@ -765,9 +1059,9 @@ def _build_plot_controls(self) -> QWidget: self.representation_combo.currentIndexChanged.connect( self._rebuild_traces_and_plot ) - layout.addWidget(self.representation_combo) + selector_row.addWidget(self.representation_combo) - layout.addWidget(QLabel("Partial colors")) + selector_row.addWidget(QLabel("Partial colors")) self.color_scheme_combo = QComboBox() for scheme in DEFAULT_COLOR_SCHEMES: self.color_scheme_combo.addItem(scheme) @@ -775,26 +1069,43 @@ def _build_plot_controls(self) -> QWidget: self.color_scheme_combo.currentIndexChanged.connect( self._apply_color_scheme ) - layout.addWidget(self.color_scheme_combo) + selector_row.addWidget(self.color_scheme_combo) + + self.legend_checkbox = QCheckBox("Legend") + self.legend_checkbox.setChecked(True) + self.legend_checkbox.toggled.connect(self._refresh_plot) + selector_row.addWidget(self.legend_checkbox) + + self.fit_box_checkbox = QCheckBox("Fit Coefficient") + self.fit_box_checkbox.setChecked(True) + self.fit_box_checkbox.toggled.connect(self._refresh_plot) + selector_row.addWidget(self.fit_box_checkbox) + selector_row.addStretch(1) + + trace_row = QHBoxLayout() + 
trace_row.setContentsMargins(0, 0, 0, 0) + trace_row.setSpacing(8) + layout.addLayout(trace_row) self.average_toggle_button = QPushButton("Hide Average") self.average_toggle_button.clicked.connect(self._toggle_average_trace) - layout.addWidget(self.average_toggle_button) + trace_row.addWidget(self.average_toggle_button) self.partials_toggle_button = QPushButton("Show Partial PDFs") self.partials_toggle_button.clicked.connect( self._toggle_partial_traces ) - layout.addWidget(self.partials_toggle_button) + trace_row.addWidget(self.partials_toggle_button) self.groups_toggle_button = QPushButton("Show Grouped Partials") self.groups_toggle_button.clicked.connect(self._toggle_group_traces) - layout.addWidget(self.groups_toggle_button) + trace_row.addWidget(self.groups_toggle_button) - self.legend_checkbox = QCheckBox("Legend") - self.legend_checkbox.setChecked(True) - self.legend_checkbox.toggled.connect(self._refresh_plot) - layout.addWidget(self.legend_checkbox) + self.experimental_toggle_button = QPushButton("Hide Experimental") + self.experimental_toggle_button.clicked.connect( + self._toggle_experimental_trace + ) + trace_row.addWidget(self.experimental_toggle_button) self.export_active_traces_button = QPushButton( "Export Active Traces..." 
@@ -802,8 +1113,8 @@ def _build_plot_controls(self) -> QWidget: self.export_active_traces_button.clicked.connect( self._export_active_traces ) - layout.addWidget(self.export_active_traces_button) - layout.addStretch(1) + trace_row.addWidget(self.export_active_traces_button) + trace_row.addStretch(1) return widget def set_project_dir(self, project_dir: str | Path | None) -> None: @@ -848,11 +1159,198 @@ def _choose_frames_dir(self) -> None: self.frames_dir_edit.setText(selected) self._inspect_frames_dir() + def _choose_experimental_file(self) -> None: + start_dir = ( + self.experimental_file_edit.text().strip() + or self.project_dir_edit.text().strip() + or str(Path.home()) + ) + selected_path, _selected_filter = QFileDialog.getOpenFileName( + self, + "Select experimental PDF g(r) data file", + start_dir, + "Data files (*.txt *.dat *.iq);;All files (*)", + ) + if not selected_path: + return + self._load_experimental_file(Path(selected_path).expanduser()) + + def _load_experimental_path_from_edit(self) -> None: + text = self.experimental_file_edit.text().strip() + if not text: + if self._experimental_summary is not None: + self._clear_experimental_file() + return + path = Path(text).expanduser() + if ( + self._experimental_summary is not None + and path.resolve() == self._experimental_summary.path + ): + return + self._load_experimental_file(path) + + def _load_experimental_file(self, file_path: Path) -> None: + resolved = file_path.expanduser().resolve() + if not resolved.is_file(): + QMessageBox.warning( + self, + "Experimental g(r)", + "The selected experimental data file does not exist: " + f"{resolved}", + ) + return + try: + summary = load_experimental_data_file(resolved, skiprows=0) + except Exception: + dialog = self._build_experimental_header_dialog(resolved) + if dialog.exec() != int(QDialog.DialogCode.Accepted): + return + summary = dialog.accepted_summary + if summary is None: + return + self._apply_experimental_file(summary) + + def 
_build_experimental_header_dialog( + self, + file_path: Path, + ) -> ExperimentalDataHeaderDialog: + initial_summary = ( + self._experimental_summary + if self._experimental_summary is not None + and self._experimental_summary.path == file_path + else None + ) + return ExperimentalDataHeaderDialog( + file_path, + self, + title="Check Experimental g(r) Data File", + independent_column_label="r(A) column", + dependent_column_label="g(r) column", + error_column_label="Error column (optional)", + intro_text=( + "Adjust the number of header rows to skip, confirm which " + "columns correspond to r(A) and g(r), and then load the file." + ), + initial_header_rows=( + initial_summary.header_rows + if initial_summary is not None + else None + ), + initial_q_column=( + initial_summary.q_column + if initial_summary is not None + else None + ), + initial_intensity_column=( + initial_summary.intensity_column + if initial_summary is not None + else None + ), + initial_error_column=( + initial_summary.error_column + if initial_summary is not None + else None + ), + ) + + def _configure_experimental_columns(self) -> None: + if self._experimental_summary is None: + text = self.experimental_file_edit.text().strip() + if not text: + self.experimental_status_label.setText( + "Select an experimental g(r) file before configuring " + "columns." 
+ ) + return + file_path = Path(text).expanduser().resolve() + else: + file_path = self._experimental_summary.path + dialog = self._build_experimental_header_dialog(file_path) + if dialog.exec() != int(QDialog.DialogCode.Accepted): + return + summary = dialog.accepted_summary + if summary is None: + return + self._apply_experimental_file(summary) + + def _apply_experimental_file( + self, + summary: ExperimentalDataSummary, + ) -> None: + self._experimental_summary = summary + self.experimental_file_edit.setText(str(summary.path)) + self._trace_visibility.setdefault("experimental", True) + self._trace_colors.setdefault("experimental", "#d62728") + self.experimental_status_label.setText( + self._experimental_summary_text(summary) + ) + self._append_log( + "Loaded experimental g(r) data from " f"{summary.path}" + ) + self._refresh_experimental_controls() + self._rebuild_traces_and_plot() + + def _clear_experimental_file(self) -> None: + self._experimental_summary = None + self.experimental_file_edit.clear() + self._trace_visibility.pop("experimental", None) + self._trace_tag_visibility.pop("experimental", None) + self.experimental_status_label.setText( + "Optional: load an experimental file with r(A) and g(r) columns." 
+ ) + self._refresh_experimental_controls() + self._rebuild_traces_and_plot() + + def _refresh_experimental_controls(self) -> None: + has_experimental = self._experimental_summary is not None + if hasattr(self, "experimental_columns_button"): + self.experimental_columns_button.setEnabled(has_experimental) + if hasattr(self, "clear_experimental_button"): + self.clear_experimental_button.setEnabled(has_experimental) + if hasattr(self, "experimental_toggle_button"): + self.experimental_toggle_button.setEnabled(has_experimental) + if hasattr(self, "fit_box_checkbox"): + self.fit_box_checkbox.setEnabled(has_experimental) + + def _experimental_summary_text( + self, + summary: ExperimentalDataSummary, + ) -> str: + r_values = np.asarray(summary.q_values, dtype=float) + if r_values.size: + r_range = ( + f"{float(np.nanmin(r_values)):.6g} to " + f"{float(np.nanmax(r_values)):.6g} A" + ) + else: + r_range = "unknown" + return ( + f"Loaded {summary.path.name}: {len(r_values)} points\n" + f"r range: {r_range}\n" + f"Columns: {self._experimental_column_text(summary)}" + ) + + @staticmethod + def _experimental_column_text( + summary: ExperimentalDataSummary, + ) -> str: + def _column_label(index: int | None, fallback: str) -> str: + if index is None: + return "None" + if 0 <= index < len(summary.column_names): + return summary.column_names[index] + return fallback + + return ( + f"r(A)={_column_label(summary.q_column, 'Column 1')}; " + f"g(r)={_column_label(summary.intensity_column, 'Column 2')}" + ) + def _inspect_frames_dir(self) -> None: text = self.frames_dir_edit.text().strip() if not text: self.frames_summary_label.setText( - "Select a trajectory frame folder containing only .xyz or only .pdb files." + "Select a trajectory frame folder containing .xyz files." 
) return try: @@ -880,10 +1378,18 @@ def _inspect_frames_dir(self) -> None: self.box_b_edit.setText(f"{detected_box[1]:g}") if not self.box_c_edit.text().strip(): self.box_c_edit.setText(f"{detected_box[2]:g}") + inferred_solutes = infer_default_solute_elements( + inspection.element_counts + ) + if inferred_solutes and not self.solute_elements_edit.text().strip(): + self.solute_elements_edit.setText(", ".join(inferred_solutes)) element_summary = ", ".join( f"{element}{count if count != 1 else ''}" for element, count in sorted(inspection.element_counts.items()) ) + solute_summary = ( + ", ".join(inferred_solutes) if inferred_solutes else "not inferred" + ) box_summary = "unknown" if inspection.detected_box_dimensions is not None: box_summary = ( @@ -907,6 +1413,7 @@ def _inspect_frames_dir(self) -> None: f"Detected {inspection.frame_format.upper()} frames: " f"{len(inspection.frame_paths)} files\n" f"Elements in first frame: {element_summary or 'unknown'}\n" + f"Default solutes: {solute_summary}\n" f"Bounding box: {box_summary}" ) self._update_rho0_label() @@ -1011,6 +1518,12 @@ def _calculation_summary_text( f"{calculation.to_value:g} A (step {calculation.step_value:g})" ), f"rho0: {calculation.rho0:.6g} atoms/A^3", + "Solute elements: " + + ( + ", ".join(calculation.solute_elements) + if calculation.solute_elements + else "None" + ), f"Frames folder: {calculation.frames_dir}", ] if calculation.elapsed_seconds is not None: @@ -1038,20 +1551,69 @@ def _parse_box_dimensions(self) -> tuple[float, float, float]: float(self.box_c_edit.text().strip()), ) + def _coerce_r_range_maximum_for_box( + self, + r_max: float, + box_dimensions: tuple[float, float, float], + ) -> float: + box_values = np.asarray(box_dimensions, dtype=float) + if ( + box_values.size != 3 + or not np.all(np.isfinite(box_values)) + or np.any(box_values <= 0.0) + ): + return r_max + allowed_r_max = float(np.min(box_values) * 0.5) + if r_max <= allowed_r_max: + return r_max + 
self.to_edit.setText(f"{allowed_r_max:g}") + self.statusBar().showMessage( + "Adjusted r-range maximum to half of the minimum box dimension.", + 5000, + ) + return allowed_r_max + def _parse_solute_elements(self) -> tuple[str, ...]: raw = self.solute_elements_edit.text().strip() if not raw: return () values = [token.strip() for token in raw.replace(";", ",").split(",")] - cleaned = sorted( - { - value[:1].upper() + value[1:].lower() - for value in values - if value - } - ) + cleaned: list[str] = [] + seen: set[str] = set() + for value in values: + if not value: + continue + element = value[:1].upper() + value[1:].lower() + if element in seen: + continue + cleaned.append(element) + seen.add(element) return tuple(cleaned) + def _apply_solute_groups_from_ui(self) -> None: + solute_elements = self._parse_solute_elements() + if self._current_calculation is None: + self.statusBar().showMessage( + "Solute elements will be used for the next Debyer run.", + 4000, + ) + return + self._current_calculation = replace( + self._current_calculation, + solute_elements=solute_elements, + target_peak_markers={}, + ) + rewrite_debyer_calculation_output(self._current_calculation) + self._persist_current_calculation() + self.calculation_info_label.setText( + self._calculation_summary_text(self._current_calculation) + ) + self._rebuild_traces_and_plot() + self.statusBar().showMessage( + "Updated grouped partial traces from the solute elements.", + 4000, + ) + def _suggest_project_dir(self) -> Path: frames_dir = self.frames_dir_edit.text().strip() if frames_dir: @@ -1069,6 +1631,12 @@ def _build_settings(self) -> DebyerPDFSettings: if not frames_text: raise ValueError("Select a frames folder before running Debyer.") + box_dimensions = self._parse_box_dimensions() + to_value = self._coerce_r_range_maximum_for_box( + float(self.to_edit.text().strip()), + box_dimensions, + ) + return DebyerPDFSettings( project_dir=Path(project_text).expanduser().resolve(), 
frames_dir=Path(frames_text).expanduser().resolve(), @@ -1076,14 +1644,15 @@ def _build_settings(self) -> DebyerPDFSettings: or "debyer_pdf", mode=self.mode_combo.currentText(), from_value=float(self.from_edit.text().strip()), - to_value=float(self.to_edit.text().strip()), + to_value=to_value, step_value=float(self.step_edit.text().strip()), - box_dimensions=self._parse_box_dimensions(), + box_dimensions=box_dimensions, atom_count=int(float(self.atom_count_edit.text().strip())), store_frame_outputs=bool( self.store_frame_outputs_checkbox.isChecked() ), solute_elements=self._parse_solute_elements(), + max_parallel_jobs=int(self.parallel_jobs_spin.value()), ) def _current_peak_finder_settings_from_ui( @@ -1492,6 +2061,49 @@ def _sanitize_trace_column_name(value: str) -> str: ).strip("_") return text or "trace" + def _trace_x_values(self, trace: dict[str, object]) -> np.ndarray: + if "x_values" in trace: + return np.asarray(trace["x_values"], dtype=float) + if self._current_calculation is None: + return np.asarray([], dtype=float) + return np.asarray(self._current_calculation.r_values, dtype=float) + + def _trace_values_on_export_grid( + self, + trace: dict[str, object], + ) -> np.ndarray: + if self._current_calculation is None: + return np.asarray([], dtype=float) + target_r = np.asarray(self._current_calculation.r_values, dtype=float) + source_r = self._trace_x_values(trace) + source_values = np.asarray(trace["values"], dtype=float) + if source_r.size == target_r.size and np.allclose( + source_r, + target_r, + ): + return source_values + valid = np.isfinite(source_r) & np.isfinite(source_values) + if valid.sum() < 2: + return np.full_like(target_r, np.nan, dtype=float) + source_r = source_r[valid] + source_values = source_values[valid] + order = np.argsort(source_r) + source_r = source_r[order] + source_values = source_values[order] + unique_r, unique_indices = np.unique(source_r, return_index=True) + source_r = unique_r + source_values = 
source_values[unique_indices] + if source_r.size < 2: + return np.full_like(target_r, np.nan, dtype=float) + interpolated = np.full_like(target_r, np.nan, dtype=float) + inside = (target_r >= source_r[0]) & (target_r <= source_r[-1]) + interpolated[inside] = np.interp( + target_r[inside], + source_r, + source_values, + ) + return interpolated + def _active_trace_export_columns( self, ) -> tuple[list[str], list[np.ndarray]] | None: @@ -1508,7 +2120,7 @@ def _active_trace_export_columns( column_names.append( self._sanitize_trace_column_name(str(trace["label"])) ) - column_arrays.append(np.asarray(trace["values"], dtype=float)) + column_arrays.append(self._trace_values_on_export_grid(trace)) if len(column_names) <= 1: return None return column_names, column_arrays @@ -1518,10 +2130,11 @@ def _default_active_trace_export_path(self) -> Path: self.representation_combo.currentText() ) if self._current_calculation is not None: - return ( - self._current_calculation.calculation_dir - / f"{self._current_calculation.filename_prefix}_{representation}_active_traces.txt" + filename = ( + f"{self._current_calculation.filename_prefix}_" + f"{representation}_active_traces.txt" ) + return self._current_calculation.calculation_dir / filename project_text = self.project_dir_edit.text().strip() root = ( Path(project_text).expanduser().resolve() @@ -1640,9 +2253,13 @@ def _start_calculation(self) -> None: "Estimated time remaining: collecting initial timing samples..." 
) self.console.clear() + self._latest_run_preview = None self._run_thread = QThread(self) - self._run_worker = DebyerPDFWorker(settings) + self._run_worker = DebyerPDFWorker( + settings, + preview_enabled=self.update_plot_during_run_checkbox.isChecked(), + ) self._run_worker.moveToThread(self._run_thread) self._run_thread.started.connect(self._run_worker.run) self._run_worker.log.connect(self._append_log) @@ -1657,6 +2274,18 @@ def _start_calculation(self) -> None: self._run_thread.finished.connect(self._run_thread.deleteLater) self._run_thread.start() + def _request_run_cancel(self, message: str) -> None: + self._close_requested_during_run = True + self.calculate_button.setEnabled(False) + self.progress_label.setText("Progress: stopping active Debyer jobs") + self.time_estimate_label.setText( + "Estimated time remaining: stopping active Debyer jobs..." + ) + self.statusBar().showMessage(message, 5000) + self._append_log(message) + if self._run_worker is not None: + self._run_worker.request_cancel() + def _append_log(self, message: str) -> None: self.console.append(message) @@ -1674,10 +2303,36 @@ def _on_progress( def _on_status(self, message: str) -> None: self.statusBar().showMessage(message) + def _on_update_plot_during_run_toggled(self, checked: bool) -> None: + if self._run_worker is not None: + self._run_worker.set_preview_enabled(checked) + if checked and self._latest_run_preview is not None: + self._apply_loaded_calculation(self._latest_run_preview) + self.statusBar().showMessage( + "Plot updates resumed; showing the latest available average.", + 3000, + ) + elif ( + checked + and self._run_thread is not None + and self._run_thread.isRunning() + ): + self.statusBar().showMessage( + "Plot updates resumed; the next completed frame will refresh " + "the average.", + 3000, + ) + elif self._run_thread is not None and self._run_thread.isRunning(): + self.statusBar().showMessage( + "Plot updates paused; averaging will continue.", + 3000, + ) + def 
_on_preview_update(self, result: object) -> None: calculation = result if not isinstance(calculation, DebyerPDFCalculation): return + self._latest_run_preview = calculation if not self.update_plot_during_run_checkbox.isChecked(): return self._apply_loaded_calculation(calculation) @@ -1692,10 +2347,25 @@ def _on_finished(self, result: object) -> None: "The Debyer worker finished without returning a valid calculation.", ) return - self._append_log("Debyer calculation completed successfully.") - self.time_estimate_label.setText("Estimated time remaining: 00:00") + if calculation.is_partial_average: + processed_frames = ( + calculation.frame_count + if calculation.processed_frame_count is None + else int(calculation.processed_frame_count) + ) + self._append_log( + "Debyer calculation stopped early; saved running average " + f"after {processed_frames}/{calculation.frame_count} frames." + ) + self.time_estimate_label.setText( + "Estimated time remaining: stopped early" + ) + else: + self._append_log("Debyer calculation completed successfully.") + self.time_estimate_label.setText("Estimated time remaining: 00:00") self._refresh_saved_calculations() self._apply_loaded_calculation(calculation) + self._latest_run_preview = None def _on_failed(self, message: str) -> None: self.calculate_button.setEnabled(True) @@ -1703,6 +2373,9 @@ def _on_failed(self, message: str) -> None: self.time_estimate_label.setText( "Estimated time remaining: unavailable" ) + self._latest_run_preview = None + if self._close_requested_during_run: + return QMessageBox.warning(self, "Debyer calculation failed", message) def _cleanup_run_thread(self) -> None: @@ -1710,11 +2383,156 @@ def _cleanup_run_thread(self) -> None: self._run_worker.deleteLater() self._run_worker = None self._run_thread = None + self._close_requested_during_run = False + + def _coordination_fit_trace_candidates(self) -> list[dict[str, object]]: + if self._current_calculation is None: + return [] + return [ + trace + for trace in 
build_display_traces( + self._current_calculation, + representation="R(r)", + include_grouped_partials=True, + ) + if str(trace.get("kind", "")) in {"average", "partial", "group"} + ] + + def _refresh_coordination_fit_trace_combo(self) -> None: + if not hasattr(self, "coordination_fit_trace_combo"): + return + previous_key = self.coordination_fit_trace_combo.currentData() + traces = self._coordination_fit_trace_candidates() + self.coordination_fit_trace_combo.blockSignals(True) + self.coordination_fit_trace_combo.clear() + selected_index = 0 + for index, trace in enumerate(traces): + key = str(trace["key"]) + if previous_key == key: + selected_index = index + self.coordination_fit_trace_combo.addItem( + str(trace["label"]), + key, + ) + if traces: + self.coordination_fit_trace_combo.setCurrentIndex(selected_index) + self.coordination_fit_button.setEnabled(True) + else: + self.coordination_fit_button.setEnabled(False) + self.coordination_fit_status_label.setText( + "Load or calculate a Debyer result before fitting R(r)." 
+ ) + self.coordination_fit_trace_combo.blockSignals(False) + if traces: + self._suggest_coordination_fit_window() + + def _selected_coordination_fit_trace( + self, + ) -> dict[str, object] | None: + selected_key = self.coordination_fit_trace_combo.currentData() + if not selected_key: + return None + for trace in self._coordination_fit_trace_candidates(): + if str(trace["key"]) == str(selected_key): + return trace + return None + + def _suggest_coordination_fit_window(self, *_args) -> None: + trace = self._selected_coordination_fit_trace() + if trace is None: + return + r_values = self._trace_x_values(trace) + values = np.asarray(trace["values"], dtype=float) + valid = np.isfinite(r_values) & np.isfinite(values) + if valid.sum() < 5: + return + fit_r = r_values[valid] + fit_values = values[valid] + order = np.argsort(fit_r) + fit_r = fit_r[order] + fit_values = fit_values[order] + peak_index = int(np.nanargmax(fit_values)) + peak_r = float(fit_r[peak_index]) + span = max(float(fit_r[-1] - fit_r[0]), 1.0e-6) + half_width = max(0.35, span * 0.08) + r_min = max(float(fit_r[0]), peak_r - half_width) + r_max = min(float(fit_r[-1]), peak_r + half_width) + if r_max <= r_min: + r_min = float(fit_r[0]) + r_max = float(fit_r[-1]) + self.coordination_fit_r_min_spin.setValue(r_min) + self.coordination_fit_r_max_spin.setValue(r_max) + self.coordination_fit_center_spin.setValue(peak_r) + self.coordination_fit_sigma_spin.setValue( + max((r_max - r_min) / 6.0, 0.0001) + ) + + def _fit_coordination_number(self) -> None: + trace = self._selected_coordination_fit_trace() + if self._current_calculation is None or trace is None: + self.coordination_fit_status_label.setText( + "Load or calculate a Debyer result before fitting R(r)." 
+ ) + return + try: + result = fit_coordination_peak_from_r( + r_values=self._trace_x_values(trace), + r_distribution_values=np.asarray(trace["values"], dtype=float), + r_min=float(self.coordination_fit_r_min_spin.value()), + r_max=float(self.coordination_fit_r_max_spin.value()), + initial_center=float( + self.coordination_fit_center_spin.value() + ), + initial_sigma=float(self.coordination_fit_sigma_spin.value()), + ) + except Exception as exc: + self.coordination_fit_status_label.setText(str(exc)) + return + + trace_key = str(trace["key"]) + trace_label = str(trace["label"]) + self.representation_combo.setCurrentText("R(r)") + self._trace_visibility[trace_key] = True + self._refresh_trace_table() + self._refresh_plot() + self._append_coordination_fit_result(trace_label, result) + self.coordination_fit_status_label.setText( + f"{trace_label}: CN = {result.coordination_number:.4g}; " + f"center = {result.center:.4g} A; sigma = {result.sigma:.4g} A" + ) + + def _append_coordination_fit_result( + self, + trace_label: str, + result, + ) -> None: + row = self.coordination_fit_results_table.rowCount() + self.coordination_fit_results_table.insertRow(row) + values = [ + trace_label, + f"{result.r_min:.5g}", + f"{result.r_max:.5g}", + f"{result.center:.5g}", + f"{result.sigma:.5g}", + f"{result.coordination_number:.6g}", + f"{result.amplitude:.6g}", + ( + "nan" + if not np.isfinite(result.r_squared) + else f"{result.r_squared:.6g}" + ), + f"{result.rmse:.6g}", + ] + for column, value in enumerate(values): + item = QTableWidgetItem(str(value)) + item.setFlags(item.flags() ^ Qt.ItemFlag.ItemIsEditable) + self.coordination_fit_results_table.setItem(row, column, item) def _rebuild_traces_and_plot(self) -> None: if self._current_calculation is None: self._current_traces = [] self._refresh_trace_table() + self._refresh_coordination_fit_trace_combo() self._refresh_plot() return previous_visibility = dict(self._trace_visibility) @@ -1725,6 +2543,31 @@ def 
_rebuild_traces_and_plot(self) -> None: representation=self.representation_combo.currentText(), include_grouped_partials=True, ) + if self._experimental_summary is not None: + experimental_r = np.asarray( + self._experimental_summary.q_values, + dtype=float, + ) + experimental_values = convert_distribution_values( + self._experimental_summary.intensities, + r_values=experimental_r, + rho0=self._current_calculation.rho0, + source_mode="PDF", + target_representation=self.representation_combo.currentText(), + is_component=False, + ) + self._current_traces.append( + { + "key": "experimental", + "label": ( + "Experimental g(r) " + f"({self._experimental_summary.path.name})" + ), + "kind": "experimental", + "x_values": experimental_r, + "values": experimental_values, + } + ) for trace in self._current_traces: key = str(trace["key"]) kind = str(trace["kind"]) @@ -1732,6 +2575,8 @@ def _rebuild_traces_and_plot(self) -> None: self._trace_visibility[key] = bool(previous_visibility[key]) elif kind == "average": self._trace_visibility[key] = True + elif kind == "experimental": + self._trace_visibility[key] = True else: self._trace_visibility[key] = False if key in previous_tag_visibility: @@ -1747,6 +2592,7 @@ def _rebuild_traces_and_plot(self) -> None: if key in previous_colors: self._trace_colors[key] = str(previous_colors[key]) self._apply_color_scheme(preserve_existing=True) + self._refresh_coordination_fit_trace_combo() def _apply_color_scheme( self, *_args, preserve_existing: bool = False @@ -1760,17 +2606,21 @@ def _apply_color_scheme( colored_traces = [ trace for trace in self._current_traces - if str(trace["kind"]) != "average" + if str(trace["kind"]) not in {"average", "experimental"} ] count = max(len(colored_traces), 1) for index, trace in enumerate(colored_traces): key = str(trace["key"]) if preserve_existing and key in self._trace_colors: continue + if key in _GROUP_TRACE_DEFAULT_COLORS: + self._trace_colors[key] = _GROUP_TRACE_DEFAULT_COLORS[key] + continue rgba 
= scheme(index / max(count - 1, 1)) color = QColor.fromRgbF(rgba[0], rgba[1], rgba[2], rgba[3]).name() self._trace_colors[key] = color self._trace_colors["average"] = "#000000" + self._trace_colors.setdefault("experimental", "#d62728") self._refresh_trace_table() self._refresh_plot() @@ -1798,6 +2648,14 @@ def _toggle_group_traces(self) -> None: ] self._toggle_trace_keys(keys) + def _toggle_experimental_trace(self) -> None: + keys = [ + str(trace["key"]) + for trace in self._current_traces + if str(trace["kind"]) == "experimental" + ] + self._toggle_trace_keys(keys) + def _toggle_trace_keys(self, keys: list[str]) -> None: if not keys: return @@ -1832,12 +2690,14 @@ def _refresh_plot(self, *_args) -> None: key = str(trace["key"]) if not self._trace_visibility.get(key, False): continue + kind = str(trace["kind"]) line = axis.plot( - self._current_calculation.r_values, + self._trace_x_values(trace), np.asarray(trace["values"], dtype=float), color=self._trace_colors.get(key, "#000000"), - linewidth=2.0 if str(trace["kind"]) == "average" else 1.4, - alpha=1.0 if str(trace["kind"]) == "average" else 0.9, + linewidth=2.0 if kind == "average" else 1.4, + alpha=1.0 if kind == "average" else 0.9, + linestyle="--" if kind == "experimental" else "-", label=str(trace["label"]), )[0] plotted.append(line) @@ -1861,6 +2721,7 @@ def _refresh_plot(self, *_args) -> None: labelsize=max(float(self.axis_label_size_spin.value()) - 1.0, 1.0) ) self._draw_peak_tags(axis) + self._draw_fit_metrics_box(axis) if plotted and self.legend_checkbox.isChecked(): axis.legend(loc="best", fontsize="small") elif not plotted: @@ -1877,6 +2738,65 @@ def _refresh_plot(self, *_args) -> None: self._update_toggle_button_labels() self.canvas.draw_idle() + def _experimental_fit_metrics_text(self) -> str | None: + if self._current_calculation is None: + return None + if self._experimental_summary is None: + return None + model_g = convert_distribution_values( + self._current_calculation.total_values, + 
r_values=self._current_calculation.r_values, + rho0=self._current_calculation.rho0, + source_mode=self._current_calculation.mode, + target_representation="g(r)", + is_component=False, + ) + metrics = compute_experimental_fit_metrics( + model_r_values=self._current_calculation.r_values, + model_g_values=model_g, + experimental_r_values=self._experimental_summary.q_values, + experimental_g_values=self._experimental_summary.intensities, + ) + if metrics is None: + return "Fit unavailable\nNo overlapping r range" + r_squared_text = ( + "nan" + if not np.isfinite(metrics.r_squared) + else f"{metrics.r_squared:.4f}" + ) + return ( + "AIMD g(r) vs experimental\n" + f"R^2 = {r_squared_text}\n" + f"RMSE = {metrics.rmse:.4g}\n" + f"MAE = {metrics.mae:.4g}\n" + f"n = {metrics.point_count}" + ) + + def _draw_fit_metrics_box(self, axis) -> None: + if not hasattr(self, "fit_box_checkbox"): + return + if not self.fit_box_checkbox.isChecked(): + return + metrics_text = self._experimental_fit_metrics_text() + if metrics_text is None: + return + axis.text( + 0.02, + 0.04, + metrics_text, + transform=axis.transAxes, + ha="left", + va="bottom", + fontsize="small", + bbox={ + "boxstyle": "round,pad=0.35", + "facecolor": "#ffffff", + "edgecolor": "#666666", + "alpha": 0.88, + }, + zorder=10, + ) + def _draw_peak_tags(self, axis) -> None: if self._current_calculation is None: return @@ -2261,6 +3181,14 @@ def _update_toggle_button_labels(self) -> None: group_visible = any( self._trace_visibility.get(key, False) for key in group_keys ) + experimental_keys = [ + str(trace["key"]) + for trace in self._current_traces + if str(trace["kind"]) == "experimental" + ] + experimental_visible = any( + self._trace_visibility.get(key, False) for key in experimental_keys + ) self.average_toggle_button.setText( "Hide Average" if average_visible else "Show Average" ) @@ -2273,6 +3201,13 @@ def _update_toggle_button_labels(self) -> None: if group_visible else "Show Grouped Partials" ) + if hasattr(self, 
"experimental_toggle_button"): + self.experimental_toggle_button.setEnabled(bool(experimental_keys)) + self.experimental_toggle_button.setText( + "Hide Experimental" + if experimental_visible + else "Show Experimental" + ) @staticmethod def _format_duration(seconds: float | None) -> str: diff --git a/src/saxshell/pdf/debyer/workflow.py b/src/saxshell/pdf/debyer/workflow.py index 4dcc7c9..e1ab7ca 100644 --- a/src/saxshell/pdf/debyer/workflow.py +++ b/src/saxshell/pdf/debyer/workflow.py @@ -1,17 +1,20 @@ from __future__ import annotations +import concurrent.futures import json import math +import os import shutil import subprocess import time from collections.abc import Callable -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from datetime import datetime, timezone from pathlib import Path from typing import Any import numpy as np +from scipy.optimize import curve_fit from saxshell.cluster.clusternetwork import ( detect_frame_folder_mode, @@ -29,8 +32,42 @@ SUPPORTED_DEBYER_MODES = ("PDF", "RDF", "rPDF") SUPPORTED_PLOT_REPRESENTATIONS = ("g(r)", "G(r)", "R(r)") DEFAULT_COLOR_SCHEMES = ("tab20", "tab10", "viridis", "plasma", "summer") +GROUPED_PARTIAL_COLUMN_LABELS = ( + "solute-solute", + "solute-solvent", + "solvent-solvent", +) _COLUMN_PREFIX = "# columns:" -_TIME_PREDICTION_UPDATE_INTERVAL_FRAMES = 5 +_DEFAULT_AVERAGE_CHECKPOINT_INTERVAL_FRAMES = 1000 +_MIN_AVERAGE_CHECKPOINT_INTERVAL_FRAMES = 100 +_RUNNING_AVERAGE_MEMORY_TARGET_BYTES = 256 * 1024 * 1024 +_MAX_PARALLEL_DEBYER_JOBS = 64 + + +def default_parallel_debyer_jobs(cpu_count: int | None = None) -> int: + available_cpus = os.cpu_count() if cpu_count is None else cpu_count + return max(1, min(4, int(available_cpus or 1))) + + +def _coerce_parallel_debyer_jobs(value: object) -> int: + try: + requested = int(value) + except (TypeError, ValueError): + requested = 1 + if requested <= 0: + requested = default_parallel_debyer_jobs() + return max(1, min(requested, 
_MAX_PARALLEL_DEBYER_JOBS)) + + +def _resolve_parallel_debyer_jobs( + value: object, + *, + total_frames: int, +) -> int: + return min( + _coerce_parallel_debyer_jobs(value), + max(int(total_frames), 1), + ) @dataclass(slots=True, frozen=True) @@ -68,6 +105,7 @@ class DebyerPDFSettings: atom_count: int = 0 store_frame_outputs: bool = False solute_elements: tuple[str, ...] = () + max_parallel_jobs: int = 1 @dataclass(slots=True, frozen=True) @@ -87,6 +125,32 @@ class DebyerPeakMarker: source: str = "auto" +@dataclass(slots=True, frozen=True) +class DebyerFitMetrics: + r_squared: float + rmse: float + mae: float + point_count: int + r_min: float + r_max: float + + +@dataclass(slots=True, frozen=True) +class DebyerCoordinationFitResult: + r_min: float + r_max: float + center: float + sigma: float + coordination_number: float + amplitude: float + baseline_intercept: float + baseline_slope: float + rmse: float + r_squared: float + point_count: int + fitted_values: np.ndarray + + @dataclass(slots=True, frozen=True) class DebyerPDFCalculationSummary: calculation_id: str @@ -121,6 +185,7 @@ class DebyerPDFCalculation: frame_output_dir: Path | None averaged_output_file: Path solute_elements: tuple[str, ...] + parallel_jobs: int r_values: np.ndarray total_values: np.ndarray partial_values: dict[str, np.ndarray] @@ -152,12 +217,30 @@ def _normalize_solute_elements( ) -> tuple[str, ...]: if not values: return () - normalized = { - _normalized_element(value) - for value in values - if _normalized_element(value) + normalized: list[str] = [] + seen: set[str] = set() + for value in values: + element = _normalized_element(value) + if not element or element in seen: + continue + normalized.append(element) + seen.add(element) + return tuple(normalized) + + +def infer_default_solute_elements( + element_counts: dict[str, int] | list[str] | tuple[str, ...] 
| set[str], +) -> tuple[str, ...]: + available = { + _normalized_element(element) + for element in element_counts + if _normalized_element(element) } - return tuple(sorted(normalized)) + if {"Cs", "Pb", "I"}.issubset(available): + return ("Cs", "Pb", "I") + if {"Pb", "I"}.issubset(available): + return ("Pb", "I") + return () def _sanitize_prefix(value: str) -> str: @@ -398,17 +481,21 @@ def check_debyer_runtime( def inspect_frames_dir(frames_dir: str | Path) -> DebyerFrameInspection: resolved_frames_dir = Path(frames_dir).expanduser().resolve() frame_format, frame_paths = detect_frame_folder_mode(resolved_frames_dir) + if frame_format != "xyz": + raise ValueError( + "Debyer PDF calculations require XYZ frame files. Convert PDB " + "frames to XYZ before using pdfsetup." + ) first_frame = frame_paths[0] detected_box_dimensions: tuple[float, float, float] | None = None detected_box_source: str | None = None detected_box_source_kind: str | None = None - if frame_format == "xyz": - detected = detect_source_box_dimensions(resolved_frames_dir) - if detected is not None: - detected_box_dimensions, source_path = detected - detected_box_source = source_path.name - detected_box_source_kind = "source_filename" + detected = detect_source_box_dimensions(resolved_frames_dir) + if detected is not None: + detected_box_dimensions, source_path = detected + detected_box_source = source_path.name + detected_box_source_kind = "source_filename" coordinates, elements = load_structure_file(first_frame) estimated_box_dimensions = estimate_box_dimensions_from_coordinates( @@ -515,12 +602,27 @@ def _format_duration(seconds: float | None) -> str: return f"{minutes:02d}:{secs:02d}" -def _time_prediction_interval(total_frames: int) -> int: - if total_frames <= 10: - return 1 - return min( - _TIME_PREDICTION_UPDATE_INTERVAL_FRAMES, max(total_frames // 10, 1) +def _average_checkpoint_interval( + *, + total_frames: int, + average_state_bytes: int = 0, +) -> int: + resolved_total = 
max(int(total_frames), 1) + if average_state_bytes <= _RUNNING_AVERAGE_MEMORY_TARGET_BYTES: + return min( + _DEFAULT_AVERAGE_CHECKPOINT_INTERVAL_FRAMES, + resolved_total, + ) + pressure = int( + math.ceil( + average_state_bytes / float(_RUNNING_AVERAGE_MEMORY_TARGET_BYTES) + ) + ) + memory_interval = max( + _MIN_AVERAGE_CHECKPOINT_INTERVAL_FRAMES, + _DEFAULT_AVERAGE_CHECKPOINT_INTERVAL_FRAMES // max(pressure, 1), ) + return min(memory_interval, resolved_total) def _estimate_runtime( @@ -550,6 +652,7 @@ def _build_averaged_output_metadata( elapsed_seconds: float | None, estimated_remaining_seconds: float | None, expected_total_seconds: float | None, + parallel_jobs: int | None = None, ) -> dict[str, object]: return { "calculation_id": calculation_id, @@ -572,6 +675,11 @@ def _build_averaged_output_metadata( "rho0": f"{rho0:.8g}", "store_frame_outputs": settings.store_frame_outputs, "solute_elements": ", ".join(settings.solute_elements) or "None", + "parallel_jobs": int( + settings.max_parallel_jobs + if parallel_jobs is None + else parallel_jobs + ), "elapsed_seconds": ( None if elapsed_seconds is None @@ -629,6 +737,70 @@ def _average_frame_outputs( return reference_r, union_columns, averaged +@dataclass(slots=True) +class _RunningDebyerAverage: + reference_r: np.ndarray | None = None + column_order: list[str] = field(default_factory=lambda: ["sum"]) + sums: dict[str, np.ndarray] = field(default_factory=dict) + processed_count: int = 0 + + def add_frame( + self, + r_values: np.ndarray, + columns: dict[str, np.ndarray], + ) -> None: + radial = np.asarray(r_values, dtype=float) + if self.reference_r is None: + self.reference_r = radial.copy() + self.sums = { + "sum": np.zeros_like(self.reference_r, dtype=float), + } + elif not np.allclose(self.reference_r, radial): + raise ValueError( + "Debyer frame outputs do not share the same radial grid." 
+ ) + + for key in columns: + if key not in self.sums: + self.column_order.append(key) + self.sums[key] = np.zeros_like(self.reference_r, dtype=float) + + for key, values in columns.items(): + self.sums[key] += np.asarray(values, dtype=float) + self.processed_count += 1 + + @property + def memory_bytes(self) -> int: + total = 0 + if self.reference_r is not None: + total += int(self.reference_r.nbytes) + total += sum(int(values.nbytes) for values in self.sums.values()) + return total + + def average( + self, + ) -> tuple[np.ndarray, list[str], dict[str, np.ndarray]]: + if self.reference_r is None or self.processed_count <= 0: + raise ValueError( + "No Debyer frame outputs were provided for averaging." + ) + averaged = { + key: np.asarray(self.sums[key], dtype=float) + / float(self.processed_count) + for key in self.column_order + } + return self.reference_r.copy(), list(self.column_order), averaged + + +@dataclass(slots=True, frozen=True) +class _DebyerFrameRunResult: + frame_index: int + frame_path: Path + output_path: Path + r_values: np.ndarray + values: dict[str, np.ndarray] + + def _candidate_peak_indices(values: np.ndarray) -> list[int]: array = np.asarray(values, dtype=float) count = int(array.size) @@ -768,6 +940,7 @@ def build_debyer_calculation_metadata( ), "averaged_output_file": str(calculation.averaged_output_file), "solute_elements": list(calculation.solute_elements), + "parallel_jobs": int(calculation.parallel_jobs), "elapsed_seconds": calculation.elapsed_seconds, "estimated_remaining_seconds": calculation.estimated_remaining_seconds, "expected_total_seconds": calculation.expected_total_seconds, @@ -850,6 +1023,229 @@ def convert_distribution_values( return prefactor_r * (canonical_g - 1.0) +def compute_experimental_fit_metrics( + *, + model_r_values: np.ndarray, + model_g_values: np.ndarray, + experimental_r_values: np.ndarray, + experimental_g_values: np.ndarray, +) -> DebyerFitMetrics | None: + model_r = np.asarray(model_r_values, dtype=float) + 
model_g = np.asarray(model_g_values, dtype=float) + experimental_r = np.asarray(experimental_r_values, dtype=float) + experimental_g = np.asarray(experimental_g_values, dtype=float) + model_mask = np.isfinite(model_r) & np.isfinite(model_g) + experimental_mask = np.isfinite(experimental_r) & np.isfinite( + experimental_g + ) + if model_mask.sum() < 2 or experimental_mask.sum() < 2: + return None + + model_r = model_r[model_mask] + model_g = model_g[model_mask] + order = np.argsort(model_r) + model_r = model_r[order] + model_g = model_g[order] + unique_r, unique_indices = np.unique(model_r, return_index=True) + model_r = unique_r + model_g = model_g[unique_indices] + if model_r.size < 2: + return None + + experimental_r = experimental_r[experimental_mask] + experimental_g = experimental_g[experimental_mask] + overlap_mask = (experimental_r >= model_r[0]) & ( + experimental_r <= model_r[-1] + ) + if overlap_mask.sum() < 2: + return None + + overlap_r = experimental_r[overlap_mask] + overlap_g = experimental_g[overlap_mask] + interpolated_model = np.interp(overlap_r, model_r, model_g) + residuals = interpolated_model - overlap_g + sse = float(np.sum(residuals**2)) + centered = overlap_g - float(np.mean(overlap_g)) + sst = float(np.sum(centered**2)) + r_squared = float("nan") if sst <= 0.0 else 1.0 - (sse / sst) + rmse = float(np.sqrt(np.mean(residuals**2))) + mae = float(np.mean(np.abs(residuals))) + return DebyerFitMetrics( + r_squared=r_squared, + rmse=rmse, + mae=mae, + point_count=int(overlap_r.size), + r_min=float(np.min(overlap_r)), + r_max=float(np.max(overlap_r)), + ) + + +def _coordination_gaussian_model( + radial: np.ndarray, + area: float, + center: float, + sigma: float, + baseline_intercept: float, + baseline_slope: float, + *, + baseline_pivot: float, +) -> np.ndarray: + radial_values = np.asarray(radial, dtype=float) + bounded_sigma = max(float(sigma), 1.0e-12) + gaussian = ( + float(area) + / (bounded_sigma * math.sqrt(2.0 * math.pi)) + * np.exp(-0.5 * 
((radial_values - center) / bounded_sigma) ** 2) + ) + baseline = float(baseline_intercept) + float(baseline_slope) * ( + radial_values - float(baseline_pivot) + ) + return baseline + gaussian + + +def fit_coordination_peak_from_r( + *, + r_values: np.ndarray, + r_distribution_values: np.ndarray, + r_min: float, + r_max: float, + initial_center: float | None = None, + initial_sigma: float | None = None, +) -> DebyerCoordinationFitResult: + radial = np.asarray(r_values, dtype=float) + values = np.asarray(r_distribution_values, dtype=float) + if radial.shape != values.shape: + raise ValueError("R(r) fit inputs must share the same shape.") + if float(r_min) >= float(r_max): + raise ValueError("The R(r) fit minimum must be below maximum.") + + mask = ( + np.isfinite(radial) + & np.isfinite(values) + & (radial >= float(r_min)) + & (radial <= float(r_max)) + ) + if mask.sum() < 5: + raise ValueError( + "At least five finite R(r) points are required inside the fit window." + ) + fit_r = radial[mask] + fit_values = values[mask] + order = np.argsort(fit_r) + fit_r = fit_r[order] + fit_values = fit_values[order] + + window_width = float(fit_r[-1] - fit_r[0]) + if window_width <= 0.0: + raise ValueError("The R(r) fit window has zero radial width.") + edge_count = max(1, min(3, fit_r.size // 4)) + left_r = float(np.mean(fit_r[:edge_count])) + right_r = float(np.mean(fit_r[-edge_count:])) + left_y = float(np.mean(fit_values[:edge_count])) + right_y = float(np.mean(fit_values[-edge_count:])) + baseline_slope = ( + 0.0 + if abs(right_r - left_r) < 1.0e-12 + else (right_y - left_y) / (right_r - left_r) + ) + baseline_guess = left_y + baseline_slope * (fit_r - left_r) + residual_guess = fit_values - baseline_guess + center_guess = ( + float(initial_center) + if initial_center is not None + else float(fit_r[int(np.nanargmax(residual_guess))]) + ) + center_guess = min(max(center_guess, float(fit_r[0])), float(fit_r[-1])) + sigma_guess = ( + float(initial_sigma) + if initial_sigma is 
not None and float(initial_sigma) > 0.0 + else max(window_width / 6.0, 1.0e-4) + ) + sigma_guess = min(max(sigma_guess, 1.0e-4), window_width) + intercept_guess = float(left_y + baseline_slope * (center_guess - left_r)) + positive_peak = np.maximum(residual_guess, 0.0) + if hasattr(np, "trapezoid"): + area_guess = float(np.trapezoid(positive_peak, fit_r)) + else: + area_guess = float( + np.sum( + 0.5 * (positive_peak[1:] + positive_peak[:-1]) * np.diff(fit_r) + ) + ) + if not np.isfinite(area_guess) or area_guess <= 0.0: + area_guess = ( + max(float(np.nanmax(fit_values) - np.nanmin(fit_values)), 1.0e-6) + * window_width + / 3.0 + ) + baseline_pivot = center_guess + + def model( + radial_values: np.ndarray, + area: float, + center: float, + sigma: float, + intercept: float, + slope: float, + ) -> np.ndarray: + return _coordination_gaussian_model( + radial_values, + area, + center, + sigma, + intercept, + slope, + baseline_pivot=baseline_pivot, + ) + + try: + params, _covariance = curve_fit( + model, + fit_r, + fit_values, + p0=[ + area_guess, + center_guess, + sigma_guess, + intercept_guess, + baseline_slope, + ], + bounds=( + [0.0, float(fit_r[0]), 1.0e-6, -np.inf, -np.inf], + [np.inf, float(fit_r[-1]), window_width * 2.0, np.inf, np.inf], + ), + maxfev=20000, + ) + except Exception as exc: + raise ValueError(f"R(r) coordination fit failed: {exc}") from exc + + fitted_values = model(fit_r, *params) + residual = fit_values - fitted_values + rmse = float(np.sqrt(np.mean(residual**2))) + total_variance = float(np.sum((fit_values - np.mean(fit_values)) ** 2)) + r_squared = ( + float("nan") + if total_variance <= 1.0e-20 + else 1.0 - float(np.sum(residual**2)) / total_variance + ) + area, center, sigma, intercept, slope = [float(value) for value in params] + amplitude = area / (sigma * math.sqrt(2.0 * math.pi)) + return DebyerCoordinationFitResult( + r_min=float(fit_r[0]), + r_max=float(fit_r[-1]), + center=center, + sigma=sigma, + coordination_number=area, + 
amplitude=float(amplitude), + baseline_intercept=intercept, + baseline_slope=slope, + rmse=rmse, + r_squared=r_squared, + point_count=int(fit_r.size), + fitted_values=np.asarray(fitted_values, dtype=float), + ) + + def classify_partial_pair( pair_label: str, *, @@ -869,6 +1265,20 @@ def classify_partial_pair( return "solute-solvent" +def _is_grouped_partial_column(column_name: str) -> bool: + return str(column_name) in GROUPED_PARTIAL_COLUMN_LABELS + + +def _raw_partial_values_from_output_values( + values: dict[str, np.ndarray], +) -> dict[str, np.ndarray]: + return { + key: np.asarray(value, dtype=float) + for key, value in values.items() + if not _is_grouped_partial_column(key) + } + + def build_grouped_partial_values( partial_values: dict[str, np.ndarray], *, @@ -879,6 +1289,8 @@ def build_grouped_partial_values( normalized_solutes = set(_normalize_solute_elements(solute_elements)) grouped: dict[str, np.ndarray] = {} for pair_label, values in partial_values.items(): + if _is_grouped_partial_column(pair_label): + continue family = classify_partial_pair( pair_label, solute_elements=normalized_solutes, @@ -893,6 +1305,137 @@ def build_grouped_partial_values( return grouped +def _output_values_with_grouped_partials( + *, + column_order: list[str], + values: dict[str, np.ndarray], + solute_elements: tuple[str, ...], +) -> tuple[list[str], dict[str, np.ndarray]]: + cleaned_order = [ + column + for column in column_order + if not _is_grouped_partial_column(column) + ] + cleaned_values = { + key: np.asarray(value, dtype=float) + for key, value in values.items() + if not _is_grouped_partial_column(key) + } + if not solute_elements: + return cleaned_order, cleaned_values + + raw_partials = { + key: value for key, value in cleaned_values.items() if key != "sum" + } + grouped = build_grouped_partial_values( + raw_partials, + solute_elements=solute_elements, + ) + output_values = dict(cleaned_values) + output_order = list(cleaned_order) + for label in 
GROUPED_PARTIAL_COLUMN_LABELS: + if label not in grouped: + continue + output_values[label] = np.asarray(grouped[label], dtype=float) + output_order.append(label) + return output_order, output_values + + +def _build_averaged_output_metadata_from_calculation( + calculation: DebyerPDFCalculation, +) -> dict[str, object]: + processed_frames = ( + calculation.frame_count + if calculation.processed_frame_count is None + else int(calculation.processed_frame_count) + ) + return { + "calculation_id": calculation.calculation_id, + "created_at": calculation.created_at, + "filename_prefix": calculation.filename_prefix, + "frames_dir": str(calculation.frames_dir), + "frame_format": calculation.frame_format, + "processed_frames": int(processed_frames), + "total_frames": int(calculation.frame_count), + "mode": calculation.mode, + "from_value": calculation.from_value, + "to_value": calculation.to_value, + "step_value": calculation.step_value, + "box_dimensions": ", ".join( + f"{component:.6g}" for component in calculation.box_dimensions + ), + "box_source": calculation.box_source or "estimated/manual", + "box_source_kind": calculation.box_source_kind or "estimate", + "atom_count": calculation.atom_count, + "rho0": f"{calculation.rho0:.8g}", + "store_frame_outputs": calculation.store_frame_outputs, + "solute_elements": (", ".join(calculation.solute_elements) or "None"), + "parallel_jobs": int(calculation.parallel_jobs), + "elapsed_seconds": ( + None + if calculation.elapsed_seconds is None + else f"{float(calculation.elapsed_seconds):.6f}" + ), + "estimated_remaining_seconds": ( + None + if calculation.estimated_remaining_seconds is None + else f"{float(calculation.estimated_remaining_seconds):.6f}" + ), + "expected_total_seconds": ( + None + if calculation.expected_total_seconds is None + else f"{float(calculation.expected_total_seconds):.6f}" + ), + "elapsed_hms": _format_duration(calculation.elapsed_seconds), + "remaining_hms": _format_duration( + 
calculation.estimated_remaining_seconds + ), + "expected_total_hms": _format_duration( + calculation.expected_total_seconds + ), + } + + +def rewrite_debyer_calculation_output( + calculation: DebyerPDFCalculation, +) -> None: + column_order = ["sum"] + if calculation.averaged_output_file.is_file(): + try: + parsed_order = _parse_columns_from_comments( + calculation.averaged_output_file + ) + except Exception: + parsed_order = [] + column_order = [ + column + for column in parsed_order + if column == "sum" or column in calculation.partial_values + ] or ["sum"] + for pair_label in sorted(calculation.partial_values): + if pair_label not in column_order: + column_order.append(pair_label) + values = { + "sum": np.asarray(calculation.total_values, dtype=float), + **{ + key: np.asarray(value, dtype=float) + for key, value in calculation.partial_values.items() + }, + } + output_order, output_values = _output_values_with_grouped_partials( + column_order=column_order, + values=values, + solute_elements=calculation.solute_elements, + ) + save_averaged_debyer_output( + calculation.averaged_output_file, + r_values=calculation.r_values, + column_order=output_order, + values=output_values, + metadata=_build_averaged_output_metadata_from_calculation(calculation), + ) + + def build_display_traces( calculation: DebyerPDFCalculation, *, @@ -967,10 +1510,7 @@ def load_debyer_calculation( averaged_output_file = Path(payload["averaged_output_file"]).resolve() r_values, raw_values = parse_debyer_output_file(averaged_output_file) total_values = np.asarray(raw_values.pop("sum"), dtype=float) - partial_values = { - key: np.asarray(value, dtype=float) - for key, value in raw_values.items() - } + partial_values = _raw_partial_values_from_output_values(raw_values) peak_finder_settings = _coerce_peak_finder_settings( payload.get("peak_finder_settings") ) @@ -1027,6 +1567,9 @@ def load_debyer_calculation( solute_elements=_normalize_solute_elements( payload.get("solute_elements", []) ), + 
parallel_jobs=_coerce_parallel_debyer_jobs( + payload.get("parallel_jobs", 1) + ), r_values=r_values, total_values=total_values, partial_values=partial_values, @@ -1120,6 +1663,9 @@ def __init__( solute_elements=_normalize_solute_elements( settings.solute_elements ), + max_parallel_jobs=_coerce_parallel_debyer_jobs( + settings.max_parallel_jobs + ), ) self.debyer_executable = ( None @@ -1143,6 +1689,42 @@ def inspect_frames(self) -> DebyerFrameInspection: ) return self._cached_inspection + def _run_debyer_frame( + self, + *, + frame_index: int, + frame_path: Path, + output_path: Path, + rho0: float, + executable_path: Path | None, + ) -> _DebyerFrameRunResult: + command = self._build_command( + input_file=frame_path, + output_file=output_path, + rho0=rho0, + executable_path=executable_path, + ) + completed = subprocess.run( + command, + check=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + if completed.returncode != 0: + raise RuntimeError( + f"Debyer failed on {frame_path.name}: " + + (completed.stderr.strip() or completed.stdout.strip()) + ) + frame_r_values, frame_values = parse_debyer_output_file(output_path) + return _DebyerFrameRunResult( + frame_index=frame_index, + frame_path=frame_path, + output_path=output_path, + r_values=frame_r_values, + values=frame_values, + ) + def run( self, *, @@ -1150,12 +1732,25 @@ def run( log_callback: Callable[[str], None] | None = None, status_callback: Callable[[str], None] | None = None, preview_callback: Callable[[DebyerPDFCalculation], None] | None = None, + preview_decision_callback: ( + Callable[[int, int, bool], bool] | None + ) = None, + cancel_callback: Callable[[], bool] | None = None, ) -> DebyerPDFCalculation: runtime_status = self.check_runtime() if not runtime_status.runnable: raise RuntimeError(runtime_status.message) inspection = self.inspect_frames() + if not self.settings.solute_elements: + inferred_solutes = infer_default_solute_elements( + inspection.element_counts + ) + if 
inferred_solutes: + self.settings = replace( + self.settings, + solute_elements=inferred_solutes, + ) calculation_id = _build_calculation_id(self.settings.filename_prefix) created_at = ( datetime.now(timezone.utc) @@ -1177,7 +1772,10 @@ def run( ) peak_finder_settings = DebyerPeakFinderSettings() total_frames = len(inspection.frame_paths) - prediction_interval = _time_prediction_interval(total_frames) + parallel_jobs = _resolve_parallel_debyer_jobs( + self.settings.max_parallel_jobs, + total_frames=total_frames, + ) if status_callback is not None: status_callback("Running Debyer over trajectory frames") if log_callback is not None: @@ -1185,6 +1783,10 @@ def run( "Starting Debyer " f"{self.settings.mode} calculation on {total_frames} frames" ) + log_callback( + f"Running up to {parallel_jobs} Debyer " + f"{'job' if parallel_jobs == 1 else 'jobs'} in parallel" + ) log_callback( "Bounding box: " + " x ".join( @@ -1194,81 +1796,127 @@ def run( + f" A; rho0={rho0:.6g} atoms/A^3" ) - averaged_inputs: list[tuple[np.ndarray, dict[str, np.ndarray]]] = [] + running_average = _RunningDebyerAverage() start_time = time.monotonic() last_verbose_log = time.monotonic() + checkpoint_interval = _average_checkpoint_interval( + total_frames=total_frames, + ) latest_preview: DebyerPDFCalculation | None = None - for index, frame_path in enumerate(inspection.frame_paths, start=1): + cancelled = False + frame_iterator = iter(enumerate(inspection.frame_paths, start=1)) + active_futures: dict[ + concurrent.futures.Future[_DebyerFrameRunResult], + int, + ] = {} + + def submit_next_frame( + executor: concurrent.futures.ThreadPoolExecutor, + ) -> bool: + try: + frame_index, frame_path = next(frame_iterator) + except StopIteration: + return False output_path = frame_output_dir / f"{frame_path.stem}.txt" - command = self._build_command( - input_file=frame_path, - output_file=output_path, + future = executor.submit( + self._run_debyer_frame, + frame_index=frame_index, + frame_path=frame_path, + 
output_path=output_path, rho0=rho0, executable_path=runtime_status.executable_path, ) - completed = subprocess.run( - command, - check=False, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, + active_futures[future] = frame_index + return True + + def process_frame_result( + frame_result: _DebyerFrameRunResult, + ) -> None: + nonlocal checkpoint_interval, latest_preview, last_verbose_log + running_average.add_frame( + frame_result.r_values, + frame_result.values, + ) + processed_frames = running_average.processed_count + checkpoint_interval = _average_checkpoint_interval( + total_frames=total_frames, + average_state_bytes=running_average.memory_bytes, ) - if completed.returncode != 0: - raise RuntimeError( - f"Debyer failed on {frame_path.name}: " - + (completed.stderr.strip() or completed.stdout.strip()) - ) - averaged_inputs.append(parse_debyer_output_file(output_path)) elapsed_seconds = time.monotonic() - start_time ( estimated_remaining_seconds, expected_total_seconds, ) = _estimate_runtime( - processed_frames=index, + processed_frames=processed_frames, total_frames=total_frames, elapsed_seconds=elapsed_seconds, ) - if not self.settings.store_frame_outputs and output_path.exists(): - output_path.unlink() + if ( + not self.settings.store_frame_outputs + and frame_result.output_path.exists() + ): + frame_result.output_path.unlink() if progress_callback is not None: progress_message = ( - f"Processed {index}/{total_frames} frames | " + f"Processed {processed_frames}/{total_frames} frames | " f"elapsed {_format_duration(elapsed_seconds)} | " f"remaining {_format_duration(estimated_remaining_seconds)}" ) progress_callback( - index, + processed_frames, total_frames, progress_message, ) - should_refresh_average = ( - index == 1 - or index == total_frames - or index % prediction_interval == 0 + checkpoint_due = ( + processed_frames == total_frames + or processed_frames % checkpoint_interval == 0 ) + should_refresh_average = checkpoint_due + if ( + 
preview_callback is not None + and preview_decision_callback is not None + ): + should_refresh_average = bool( + preview_decision_callback( + processed_frames, + total_frames, + checkpoint_due, + ) + ) if should_refresh_average: ( preview_r_values, preview_column_order, preview_values, - ) = _average_frame_outputs(averaged_inputs) + ) = running_average.average() + ( + preview_output_column_order, + preview_output_values, + ) = _output_values_with_grouped_partials( + column_order=preview_column_order, + values=preview_values, + solute_elements=self.settings.solute_elements, + ) save_averaged_debyer_output( averaged_output_file, r_values=preview_r_values, - column_order=preview_column_order, - values=preview_values, + column_order=preview_output_column_order, + values=preview_output_values, metadata=_build_averaged_output_metadata( calculation_id=calculation_id, created_at=created_at, settings=self.settings, inspection=inspection, rho0=rho0, - processed_frames=index, + processed_frames=processed_frames, total_frames=total_frames, elapsed_seconds=elapsed_seconds, - estimated_remaining_seconds=estimated_remaining_seconds, + estimated_remaining_seconds=( + estimated_remaining_seconds + ), expected_total_seconds=expected_total_seconds, + parallel_jobs=parallel_jobs, ), ) latest_preview = DebyerPDFCalculation( @@ -1293,6 +1941,7 @@ def run( frame_output_dir=frame_output_dir, averaged_output_file=averaged_output_file, solute_elements=self.settings.solute_elements, + parallel_jobs=parallel_jobs, r_values=preview_r_values, total_values=np.asarray( preview_values["sum"], @@ -1301,10 +1950,10 @@ def run( partial_values={ key: np.asarray(value, dtype=float) for key, value in preview_values.items() - if key != "sum" + if key != "sum" and not _is_grouped_partial_column(key) }, - processed_frame_count=index, - is_partial_average=index < total_frames, + processed_frame_count=processed_frames, + is_partial_average=processed_frames < total_frames, elapsed_seconds=elapsed_seconds, 
estimated_remaining_seconds=estimated_remaining_seconds, expected_total_seconds=expected_total_seconds, @@ -1316,19 +1965,64 @@ def run( preview_callback(latest_preview) if log_callback is not None: should_log = ( - index == 1 - or index == total_frames + processed_frames == 1 + or processed_frames == total_frames or (time.monotonic() - last_verbose_log) >= 5.0 ) if should_log: + checkpoint_text = ( + f"; checkpoint every {checkpoint_interval} frames" + if processed_frames == 1 + else "" + ) log_callback( - f"Processed {index}/{total_frames} frames " - f"({frame_path.name}) | elapsed " + f"Processed {processed_frames}/{total_frames} frames " + f"({frame_result.frame_path.name}) | elapsed " f"{_format_duration(elapsed_seconds)} | remaining " f"{_format_duration(estimated_remaining_seconds)}" + f"{checkpoint_text}" ) last_verbose_log = time.monotonic() + with concurrent.futures.ThreadPoolExecutor( + max_workers=parallel_jobs, + thread_name_prefix="debyer-pdf", + ) as executor: + for _worker_index in range(parallel_jobs): + if not submit_next_frame(executor): + break + while active_futures: + done_futures, _pending_futures = concurrent.futures.wait( + active_futures, + return_when=concurrent.futures.FIRST_COMPLETED, + ) + for future in done_futures: + active_futures.pop(future, None) + process_frame_result(future.result()) + if ( + cancel_callback is not None + and cancel_callback() + and running_average.processed_count < total_frames + ): + if not cancelled: + cancelled = True + if log_callback is not None: + active_count = len(active_futures) + suffix = ( + f" Waiting for {active_count} active " + "Debyer job(s) to finish." 
+ if active_count + else "" + ) + log_callback( + "Debyer calculation stop requested; saving " + "the current average after " + f"{running_average.processed_count}/" + f"{total_frames} frames.{suffix}" + ) + if not cancelled: + submit_next_frame(executor) + if ( not self.settings.store_frame_outputs and frame_output_dir.is_dir() @@ -1339,38 +2033,46 @@ def run( else: stored_frame_output_dir = frame_output_dir + processed_frame_count = running_average.processed_count if ( latest_preview is None - or latest_preview.processed_frame_count != total_frames + or latest_preview.processed_frame_count != processed_frame_count ): elapsed_seconds = time.monotonic() - start_time ( estimated_remaining_seconds, expected_total_seconds, ) = _estimate_runtime( - processed_frames=total_frames, + processed_frames=processed_frame_count, total_frames=total_frames, elapsed_seconds=elapsed_seconds, ) - r_values, column_order, averaged_values = _average_frame_outputs( - averaged_inputs + r_values, column_order, averaged_values = running_average.average() + ( + output_column_order, + output_values, + ) = _output_values_with_grouped_partials( + column_order=column_order, + values=averaged_values, + solute_elements=self.settings.solute_elements, ) save_averaged_debyer_output( averaged_output_file, r_values=r_values, - column_order=column_order, - values=averaged_values, + column_order=output_column_order, + values=output_values, metadata=_build_averaged_output_metadata( calculation_id=calculation_id, created_at=created_at, settings=self.settings, inspection=inspection, rho0=rho0, - processed_frames=total_frames, + processed_frames=processed_frame_count, total_frames=total_frames, elapsed_seconds=elapsed_seconds, - estimated_remaining_seconds=estimated_remaining_seconds, + estimated_remaining_seconds=(estimated_remaining_seconds), expected_total_seconds=expected_total_seconds, + parallel_jobs=parallel_jobs, ), ) else: @@ -1386,10 +2088,9 @@ def run( final_total_values = np.asarray( 
final_raw_values.pop("sum"), dtype=float ) - final_partial_values = { - key: np.asarray(value, dtype=float) - for key, value in final_raw_values.items() - } + final_partial_values = _raw_partial_values_from_output_values( + final_raw_values + ) final_calculation = DebyerPDFCalculation( calculation_id=calculation_id, calculation_dir=calculation_dir, @@ -1412,11 +2113,12 @@ def run( frame_output_dir=stored_frame_output_dir, averaged_output_file=averaged_output_file, solute_elements=self.settings.solute_elements, + parallel_jobs=parallel_jobs, r_values=final_r_values, total_values=final_total_values, partial_values=final_partial_values, - processed_frame_count=total_frames, - is_partial_average=False, + processed_frame_count=processed_frame_count, + is_partial_average=processed_frame_count < total_frames, elapsed_seconds=elapsed_seconds, estimated_remaining_seconds=estimated_remaining_seconds, expected_total_seconds=expected_total_seconds, @@ -1434,7 +2136,11 @@ def run( f"Saved averaged Debyer output to {averaged_output_file}" ) if status_callback is not None: - status_callback("Debyer calculation complete") + status_callback( + "Debyer calculation stopped early" + if cancelled + else "Debyer calculation complete" + ) return final_calculation def _build_command( @@ -1471,8 +2177,11 @@ def _build_command( __all__ = [ "DEBYER_DOCS_URL", "DEBYER_GITHUB_URL", - "TOTAL_SCATTERING_PAPER_URL", "DEFAULT_COLOR_SCHEMES", + "GROUPED_PARTIAL_COLUMN_LABELS", + "TOTAL_SCATTERING_PAPER_URL", + "DebyerCoordinationFitResult", + "DebyerFitMetrics", "DebyerFrameInspection", "DebyerPeakFinderSettings", "DebyerPeakMarker", @@ -1489,13 +2198,18 @@ def _build_command( "calculate_number_density", "check_debyer_runtime", "classify_partial_pair", + "compute_experimental_fit_metrics", "convert_distribution_values", + "default_parallel_debyer_jobs", "estimate_partial_peak_markers", "find_partial_peak_markers", + "fit_coordination_peak_from_r", + "infer_default_solute_elements", 
"inspect_frames_dir", "list_saved_debyer_calculations", "load_debyer_calculation", "parse_debyer_output_file", + "rewrite_debyer_calculation_output", "save_averaged_debyer_output", "write_debyer_calculation_metadata", ] diff --git a/src/saxshell/representativefinder/ui/__init__.py b/src/saxshell/representativefinder/ui/__init__.py index 5122948..b898c39 100644 --- a/src/saxshell/representativefinder/ui/__init__.py +++ b/src/saxshell/representativefinder/ui/__init__.py @@ -1,11 +1,17 @@ """Qt UI for representative-structure screening.""" +from .batch_queue_window import ( + RepresentativeFinderBatchQueueWindow, + launch_representativefinder_batch_queue_ui, +) from .main_window import ( RepresentativeStructureFinderMainWindow, launch_representativefinder_ui, ) __all__ = [ + "RepresentativeFinderBatchQueueWindow", "RepresentativeStructureFinderMainWindow", + "launch_representativefinder_batch_queue_ui", "launch_representativefinder_ui", ] diff --git a/src/saxshell/representativefinder/ui/batch_queue_window.py b/src/saxshell/representativefinder/ui/batch_queue_window.py new file mode 100644 index 0000000..c636724 --- /dev/null +++ b/src/saxshell/representativefinder/ui/batch_queue_window.py @@ -0,0 +1,1171 @@ +from __future__ import annotations + +import threading +import time +import uuid +from dataclasses import dataclass, replace +from pathlib import Path + +from PySide6.QtCore import QObject, Qt, QThread, Signal, Slot +from PySide6.QtWidgets import ( + QAbstractItemView, + QApplication, + QComboBox, + QFileDialog, + QFormLayout, + QFrame, + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QListView, + QListWidget, + QListWidgetItem, + QMainWindow, + QMessageBox, + QProgressBar, + QPushButton, + QSizePolicy, + QTextEdit, + QToolButton, + QTreeView, + QVBoxLayout, + QWidget, +) + +from saxshell.bondanalysis import ( + BondAnalysisPreset, + load_presets, + ordered_preset_names, +) +from saxshell.representativefinder.run_config import ( + 
RepresentativeFinderRunConfig, + build_representativefinder_run_config, + run_representativefinder_run_config, +) +from saxshell.representativefinder.workflow import ( + RepresentativeFinderSettings, + suggest_representativefinder_output_dir, +) +from saxshell.saxs.project_manager import ( + SAXSProjectManager, + build_project_paths, +) +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) + + +def _new_item_id() -> str: + return uuid.uuid4().hex + + +def _optional_path(text: str) -> Path | None: + stripped = text.strip() + if not stripped: + return None + return Path(stripped).expanduser().resolve() + + +def _required_path(text: str, field_name: str) -> Path: + path = _optional_path(text) + if path is None: + raise ValueError(f"{field_name} is required.") + return path + + +def _required_project_dir(text: str) -> Path: + project_dir = _required_path(text, "Project folder") + project_file = build_project_paths(project_dir).project_file + if not project_file.is_file(): + raise ValueError(f"Project file does not exist: {project_file}") + return project_dir + + +def _required_clusters_dir(text: str) -> Path: + clusters_dir = _required_path(text, "Project clusters folder") + if not clusters_dir.is_dir(): + raise ValueError( + f"Project clusters folder does not exist: {clusters_dir}" + ) + return clusters_dir + + +def _dialog_start_dir(*candidates: str | Path | None) -> str: + for candidate in candidates: + if candidate is None: + continue + path = Path(candidate).expanduser() + if path.is_file(): + return str(path.parent) + if path.is_dir(): + return str(path) + return str(Path.home()) + + +def _choose_existing_directories( + parent: QWidget, + *, + title: str, + start_dir: str | Path, +) -> tuple[Path, ...]: + dialog = QFileDialog(parent, title, str(start_dir)) + dialog.setFileMode(QFileDialog.FileMode.Directory) + dialog.setOption(QFileDialog.Option.ShowDirsOnly, True) + 
dialog.setOption(QFileDialog.Option.DontUseNativeDialog, True) + for view in dialog.findChildren(QListView) + dialog.findChildren( + QTreeView + ): + view.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + if dialog.exec() != int(QFileDialog.DialogCode.Accepted): + return () + return tuple( + Path(path).expanduser().resolve() for path in dialog.selectedFiles() + ) + + +def _project_reference_text(project_dir: Path | None) -> str: + if project_dir is None: + return "Project reference: choose a SAXSShell project folder." + project_file = build_project_paths(project_dir).project_file + if project_file.is_file(): + return f"Project reference: {project_file}" + return f"Project reference: no project file found at {project_file}" + + +def _suggest_batch_output_dir(project_dir: Path, clusters_dir: Path) -> Path: + return suggest_representativefinder_output_dir( + clusters_dir, + project_dir=project_dir, + batch=True, + ) + + +def _settings_from_preset( + preset: BondAnalysisPreset, +) -> RepresentativeFinderSettings: + return RepresentativeFinderSettings( + bond_pairs=tuple(preset.bond_pairs), + angle_triplets=tuple(preset.angle_triplets), + ) + + +@dataclass(slots=True, frozen=True) +class RepresentativeFinderBatchJob: + project_dir: Path + clusters_dir: Path + output_dir: Path + config: RepresentativeFinderRunConfig + + +@dataclass(slots=True) +class RepresentativeFinderBatchResult: + project_dir: Path + clusters_dir: Path + output_dir: Path + completed_count: int + failed_count: int + skipped_count: int + + +@dataclass(slots=True) +class RepresentativeFinderBatchItem: + item_id: str + project_dir: Path | None = None + clusters_dir: Path | None = None + output_dir: Path | None = None + + def display_name(self) -> str: + if self.project_dir is not None: + return self.project_dir.name + if self.clusters_dir is not None: + return self.clusters_dir.name + return "New representative analysis" + + +def _queue_item_from_project_defaults( + project_dir: 
str | Path, + *, + item_id: str | None = None, +) -> RepresentativeFinderBatchItem: + resolved_project_dir = Path(project_dir).expanduser().resolve() + item = RepresentativeFinderBatchItem( + item_id=item_id or _new_item_id(), + project_dir=resolved_project_dir, + ) + try: + settings = SAXSProjectManager().load_project(resolved_project_dir) + except Exception: + return item + clusters_dir = settings.resolved_clusters_dir + output_dir = None + if clusters_dir is not None and clusters_dir.is_dir(): + try: + output_dir = _suggest_batch_output_dir( + resolved_project_dir, + clusters_dir, + ) + except Exception: + output_dir = None + return replace(item, clusters_dir=clusters_dir, output_dir=output_dir) + + +class RepresentativeFinderBatchItemWidget(QFrame): + settings_changed = Signal(str) + remove_requested = Signal(str) + duplicate_requested = Signal(str) + + def __init__( + self, + item: RepresentativeFinderBatchItem, + *, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._item = item + self._loading = False + self._selected = False + self._last_suggested_output_dir: Path | None = None + self._build_ui() + self._load_item(item) + self._set_settings_visible(False) + + @property + def item_id(self) -> str: + return self._item.item_id + + def item(self) -> RepresentativeFinderBatchItem: + return self._item + + def collect_item(self) -> RepresentativeFinderBatchItem: + self._item = RepresentativeFinderBatchItem( + item_id=self._item.item_id, + project_dir=_optional_path(self.project_dir_edit.text()), + clusters_dir=_optional_path(self.clusters_dir_edit.text()), + output_dir=_optional_path(self.output_dir_edit.text()), + ) + self._refresh_header() + self._refresh_project_reference() + return self._item + + def job( + self, + *, + settings: RepresentativeFinderSettings, + ) -> RepresentativeFinderBatchJob: + self.collect_item() + project_dir = _required_project_dir(self.project_dir_edit.text()) + clusters_dir = 
_required_clusters_dir(self.clusters_dir_edit.text()) + output_dir = _optional_path( + self.output_dir_edit.text() + ) or _suggest_batch_output_dir(project_dir, clusters_dir) + config = build_representativefinder_run_config( + project_dir=project_dir, + input_dir=clusters_dir, + output_dir=output_dir, + analysis_mode="all", + settings=settings, + overwrite_existing=False, + ) + return RepresentativeFinderBatchJob( + project_dir=project_dir, + clusters_dir=clusters_dir, + output_dir=output_dir, + config=config, + ) + + def set_locked(self, locked: bool) -> None: + self.settings_group.setEnabled(not locked) + self.validate_button.setEnabled(not locked) + self.duplicate_button.setEnabled(not locked) + self.remove_button.setEnabled(not locked) + + def set_status(self, message: str) -> None: + self.status_label.setText(message) + + def set_progress(self, processed: int, total: int) -> None: + self.progress_bar.setRange(0, max(int(total), 1)) + self.progress_bar.setValue(max(int(processed), 0)) + + def set_selected(self, selected: bool) -> None: + self._selected = bool(selected) + self.header_frame.setProperty("selected", self._selected) + self.header_frame.setStyleSheet( + "QFrame#RepresentativeFinderBatchItemHeader {" + + ( + "background-color: #dce8f7; " "border: 1px solid #8fb0d7;" + if self._selected + else "background-color: #f6f8fb; " "border: 1px solid #cfd7e3;" + ) + + "border-radius: 5px;}" + ) + + def validate_paths(self) -> None: + project_dir = _required_project_dir(self.project_dir_edit.text()) + clusters_dir = _required_clusters_dir(self.clusters_dir_edit.text()) + suggested = _suggest_batch_output_dir(project_dir, clusters_dir) + current = _optional_path(self.output_dir_edit.text()) + if current is None or current == self._last_suggested_output_dir: + self.output_dir_edit.setText(str(suggested)) + self._last_suggested_output_dir = suggested + self.set_status("Ready") + + def _build_ui(self) -> None: + self.setFrameShape(QFrame.Shape.StyledPanel) + 
self.setSizePolicy( + QSizePolicy.Policy.Expanding, + QSizePolicy.Policy.Fixed, + ) + root = QVBoxLayout(self) + root.setContentsMargins(8, 8, 8, 8) + root.setSpacing(8) + + self.header_frame = QFrame() + self.header_frame.setObjectName("RepresentativeFinderBatchItemHeader") + header = QHBoxLayout(self.header_frame) + header.setContentsMargins(8, 6, 8, 6) + header.setSpacing(8) + self.toggle_button = QToolButton() + self.toggle_button.setCheckable(True) + self.toggle_button.toggled.connect(self._set_settings_visible) + header.addWidget(self.toggle_button) + self.title_label = QLabel("New representative analysis") + self.title_label.setStyleSheet("font-weight: 600;") + header.addWidget(self.title_label, stretch=1) + self.status_label = QLabel("Ready") + self.status_label.setMinimumWidth(190) + header.addWidget(self.status_label) + self.validate_button = QPushButton("Validate") + self.validate_button.clicked.connect(self._validate_from_button) + header.addWidget(self.validate_button) + self.duplicate_button = QPushButton("Duplicate") + self.duplicate_button.clicked.connect( + lambda: self.duplicate_requested.emit(self.item_id) + ) + header.addWidget(self.duplicate_button) + self.remove_button = QPushButton("Remove") + self.remove_button.clicked.connect( + lambda: self.remove_requested.emit(self.item_id) + ) + header.addWidget(self.remove_button) + root.addWidget(self.header_frame) + self.set_selected(False) + + self.progress_bar = QProgressBar() + self.progress_bar.setRange(0, 1) + self.progress_bar.setValue(0) + self.progress_bar.setFormat("%v / %m steps") + root.addWidget(self.progress_bar) + + self.settings_group = QGroupBox( + "Representative Structure Batch Settings" + ) + root.addWidget(self.settings_group) + settings_layout = QVBoxLayout(self.settings_group) + + form = QFormLayout() + self.project_dir_edit = QLineEdit() + self.project_dir_edit.editingFinished.connect(self._on_project_changed) + form.addRow( + "Project folder", + 
self._path_row(self.project_dir_edit, self._choose_project_dir), + ) + self.project_reference_label = QLabel() + self.project_reference_label.setWordWrap(True) + self.project_reference_label.setFrameShape(QFrame.Shape.StyledPanel) + form.addRow("", self.project_reference_label) + + self.clusters_dir_edit = QLineEdit() + self.clusters_dir_edit.editingFinished.connect( + self._on_clusters_changed + ) + form.addRow( + "Project clusters folder", + self._path_row(self.clusters_dir_edit, self._choose_clusters_dir), + ) + + self.output_dir_edit = QLineEdit() + self.output_dir_edit.editingFinished.connect(self._on_editor_changed) + form.addRow( + "Output root", + self._path_row(self.output_dir_edit, self._choose_output_dir), + ) + self.analysis_mode_label = QLabel("All Discovered Stoichiometries") + form.addRow("Analysis mode", self.analysis_mode_label) + settings_layout.addLayout(form) + + def _path_row(self, edit: QLineEdit, slot) -> QWidget: + row_widget = QWidget() + row = QHBoxLayout(row_widget) + row.setContentsMargins(0, 0, 0, 0) + row.addWidget(edit, stretch=1) + button = QPushButton("Browse...") + button.clicked.connect(slot) + row.addWidget(button) + return row_widget + + def _load_item(self, item: RepresentativeFinderBatchItem) -> None: + self._loading = True + self.project_dir_edit.setText( + "" if item.project_dir is None else str(item.project_dir) + ) + self.clusters_dir_edit.setText( + "" if item.clusters_dir is None else str(item.clusters_dir) + ) + self.output_dir_edit.setText( + "" if item.output_dir is None else str(item.output_dir) + ) + if item.output_dir is not None: + self._last_suggested_output_dir = item.output_dir + self._loading = False + self._refresh_header() + self._refresh_project_reference() + self._validate_quietly() + + def _set_settings_visible(self, visible: bool) -> None: + self.settings_group.setVisible(bool(visible)) + self.toggle_button.setChecked(bool(visible)) + self.toggle_button.setText("Hide Settings" if visible else "Settings") 
+ parent_item = self._list_item() + if parent_item is not None: + parent_item.setSizeHint(self.sizeHint()) + + def _list_item(self) -> QListWidgetItem | None: + parent = self.parent() + while parent is not None and not isinstance(parent, QListWidget): + parent = parent.parent() + if not isinstance(parent, QListWidget): + return None + for row in range(parent.count()): + list_item = parent.item(row) + if parent.itemWidget(list_item) is self: + return list_item + return None + + def _choose_project_dir(self) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select SAXSShell project folder", + _dialog_start_dir(self.project_dir_edit.text()), + ) + if not selected: + return + self._load_item( + _queue_item_from_project_defaults( + selected, + item_id=self.item_id, + ) + ) + self._on_editor_changed() + + def _choose_clusters_dir(self) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select project clusters folder", + _dialog_start_dir( + self.clusters_dir_edit.text(), + self.project_dir_edit.text(), + ), + ) + if not selected: + return + self.clusters_dir_edit.setText(selected) + self._on_clusters_changed() + + def _choose_output_dir(self) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select representative output root", + _dialog_start_dir( + self.output_dir_edit.text(), + self.clusters_dir_edit.text(), + ), + ) + if not selected: + return + self.output_dir_edit.setText(selected) + self._on_editor_changed() + + def _on_project_changed(self) -> None: + project_dir = _optional_path(self.project_dir_edit.text()) + if project_dir is None: + self._on_editor_changed() + return + item = _queue_item_from_project_defaults( + project_dir, + item_id=self.item_id, + ) + if item.clusters_dir is not None: + self.clusters_dir_edit.setText(str(item.clusters_dir)) + if item.output_dir is not None: + self.output_dir_edit.setText(str(item.output_dir)) + self._last_suggested_output_dir = item.output_dir + self._validate_quietly() + 
self._on_editor_changed() + + def _on_clusters_changed(self) -> None: + clusters_dir = _optional_path(self.clusters_dir_edit.text()) + project_dir = _optional_path(self.project_dir_edit.text()) + if clusters_dir is not None and project_dir is not None: + try: + suggested = _suggest_batch_output_dir( + project_dir, clusters_dir + ) + current = _optional_path(self.output_dir_edit.text()) + if ( + current is None + or current == self._last_suggested_output_dir + ): + self.output_dir_edit.setText(str(suggested)) + self._last_suggested_output_dir = suggested + except Exception: + pass + self._validate_quietly() + self._on_editor_changed() + + def _validate_from_button(self) -> None: + try: + self.validate_paths() + self._on_editor_changed() + except Exception as exc: + QMessageBox.warning( + self, + "Unable to validate representative batch item", + str(exc), + ) + self.set_status("Validation failed") + self._on_editor_changed() + + def _validate_quietly(self) -> None: + if not self.clusters_dir_edit.text().strip(): + return + try: + self.validate_paths() + except Exception: + self.set_status("Validation failed") + + def _on_editor_changed(self, *_args) -> None: + if self._loading: + return + try: + self.collect_item() + if self.status_label.text() in {"Validation failed", "Failed"}: + self.set_status("Ready") + except Exception: + self._refresh_header() + self._refresh_project_reference() + self.settings_changed.emit(self.item_id) + + def _refresh_header(self) -> None: + self.title_label.setText(self._item.display_name()) + + def _refresh_project_reference(self) -> None: + self.project_reference_label.setText( + _project_reference_text( + _optional_path(self.project_dir_edit.text()) + ) + ) + + +class RepresentativeFinderBatchWorker(QObject): + item_started = Signal(str, int, int) + item_progress = Signal(str, int, int, str) + item_finished = Signal(str, object) + item_failed = Signal(str, str) + log = Signal(str) + status = Signal(str) + project_results_changed = 
Signal(str) + finished = Signal(object) + failed = Signal(str, str) + + def __init__( + self, + queue_entries: list[tuple[str, RepresentativeFinderBatchJob]], + ) -> None: + super().__init__() + self.queue_entries = list(queue_entries) + self._cancel_requested = threading.Event() + + def request_cancel(self) -> None: + self._cancel_requested.set() + + @Slot() + def run(self) -> None: + results: list[RepresentativeFinderBatchResult] = [] + total_items = len(self.queue_entries) + for index, (item_id, job) in enumerate( + self.queue_entries, + start=1, + ): + if self._cancel_requested.is_set(): + self.log.emit("Batch queue stopped before the next project.") + break + self.item_started.emit(item_id, index, total_items) + self.status.emit( + f"Running {index}/{total_items}: {job.project_dir.name}" + ) + try: + result = self._run_job(item_id, job) + except Exception as exc: + message = str(exc) + self.item_failed.emit(item_id, message) + self.failed.emit(item_id, message) + return + results.append(result) + self.item_finished.emit(item_id, result) + self.project_results_changed.emit(str(result.project_dir)) + self.status.emit("Representative structure batch queue finished") + self.finished.emit(results) + + def _run_job( + self, + item_id: str, + job: RepresentativeFinderBatchJob, + ) -> RepresentativeFinderBatchResult: + self.log.emit( + f"[{job.project_dir.name}] Starting representative analysis." 
+ ) + last_progress_emit = 0.0 + + def on_progress(processed: int, total: int, message: str) -> None: + nonlocal last_progress_emit + now = time.monotonic() + is_terminal_update = total > 0 and processed >= total + if not is_terminal_update and now - last_progress_emit < 0.15: + return + last_progress_emit = now + self.item_progress.emit(item_id, processed, total, message) + + def on_log(message: str) -> None: + self.log.emit(f"[{job.project_dir.name}] {message}") + + summary = run_representativefinder_run_config( + job.project_dir, + job.config, + log_callback=on_log, + progress_callback=on_progress, + ) + result = RepresentativeFinderBatchResult( + project_dir=job.project_dir, + clusters_dir=job.clusters_dir, + output_dir=job.output_dir, + completed_count=summary.completed_count, + failed_count=summary.failed_count, + skipped_count=len(summary.skipped_existing), + ) + if result.failed_count: + self.log.emit( + f"[{job.project_dir.name}] Completed with " + f"{result.failed_count} failed stoichiometry run(s)." + ) + else: + self.log.emit( + f"[{job.project_dir.name}] Completed " + f"{result.completed_count} representative selection(s)." 
    def closeEvent(self, event) -> None:
        """Close the window, waiting for an active queue run to end first.

        With no running thread the event goes to the base class. Otherwise a
        cancel is requested, the window is hidden, and the method keeps the
        event loop pumping until the worker thread has stopped, then accepts
        the event.
        """
        if self._run_thread is not None and self._run_thread.isRunning():
            # The worker only checks the cancel flag between projects, so the
            # active project runs to completion before the thread stops.
            self._request_cancel()
            self.hide()
            while (
                self._run_thread is not None and self._run_thread.isRunning()
            ):
                # Deliver queued signals; the thread's finished handler
                # (_cleanup_run_thread) sets _run_thread to None.
                QApplication.processEvents()
                # Re-check because processEvents() may have cleared it.
                if self._run_thread is not None:
                    self._run_thread.wait(50)
            event.accept()
            return
        super().closeEvent(event)
parent=self.queue_list, + ) + widget.settings_changed.connect(self._on_item_settings_changed) + widget.remove_requested.connect(self._remove_item) + widget.duplicate_requested.connect(self._duplicate_item) + self._widgets_by_id[resolved_item.item_id] = widget + list_item.setSizeHint(widget.sizeHint()) + self.queue_list.setItemWidget(list_item, widget) + self.queue_list.setCurrentItem(list_item) + self._refresh_order_labels() + return widget + + def queue_jobs_in_order( + self, + ) -> list[tuple[str, RepresentativeFinderBatchJob]]: + preset = self._selected_preset() + if preset is None: + raise ValueError("Choose a bondanalysis preset before running.") + settings = _settings_from_preset(preset) + entries: list[tuple[str, RepresentativeFinderBatchJob]] = [] + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + item_id = str(list_item.data(Qt.ItemDataRole.UserRole)) + widget = self._widgets_by_id[item_id] + entries.append((item_id, widget.job(settings=settings))) + return entries + + def _build_ui(self) -> None: + self.setWindowTitle("SAXSShell Representative Structures Batch Queue") + self.setWindowIcon(load_saxshell_icon()) + self.resize(1120, 840) + + central = QWidget() + root = QVBoxLayout(central) + root.setContentsMargins(8, 8, 8, 8) + root.setSpacing(8) + + controls = QHBoxLayout() + self.add_current_button = QPushButton("Add Current Project") + self.add_current_button.clicked.connect(self._add_current_project) + controls.addWidget(self.add_current_button) + self.add_project_button = QPushButton("Add Projects...") + self.add_project_button.clicked.connect(self._choose_projects_to_add) + controls.addWidget(self.add_project_button) + controls.addStretch(1) + root.addLayout(controls) + + preset_group = QGroupBox("Batch Settings") + preset_layout = QFormLayout(preset_group) + preset_row = QHBoxLayout() + self.preset_combo = QComboBox() + self.preset_combo.currentIndexChanged.connect( + lambda _index: self._refresh_preset_summary() 
+ ) + preset_row.addWidget(self.preset_combo, stretch=1) + self.reload_presets_button = QPushButton("Reload Presets") + self.reload_presets_button.clicked.connect(self._reload_presets) + preset_row.addWidget(self.reload_presets_button) + preset_widget = QWidget() + preset_widget.setLayout(preset_row) + preset_layout.addRow("Bondanalysis preset", preset_widget) + preset_layout.addRow( + "Analysis mode", + QLabel("All Discovered Stoichiometries"), + ) + self.preset_summary_label = QLabel() + self.preset_summary_label.setWordWrap(True) + preset_layout.addRow("", self.preset_summary_label) + root.addWidget(preset_group) + + self.queue_list = QListWidget() + self.queue_list.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + self.queue_list.setDragDropMode( + QAbstractItemView.DragDropMode.InternalMove + ) + self.queue_list.setDefaultDropAction(Qt.DropAction.MoveAction) + self.queue_list.setAlternatingRowColors(True) + self.queue_list.setStyleSheet( + "QListWidget::item:selected { background: transparent; }" + "QListWidget::item:hover { background: transparent; }" + "QListWidget::item { margin: 3px; }" + ) + self.queue_list.model().rowsMoved.connect(self._refresh_order_labels) + self.queue_list.itemSelectionChanged.connect( + self._refresh_item_selection_styles + ) + root.addWidget(self.queue_list, stretch=1) + + run_group = QGroupBox("Execute Queue") + run_layout = QVBoxLayout(run_group) + run_buttons = QHBoxLayout() + self.run_button = QPushButton("Run Complete Queue") + self.run_button.clicked.connect(self._start_queue) + run_buttons.addWidget(self.run_button) + self.cancel_button = QPushButton("Stop Queue") + self.cancel_button.setEnabled(False) + self.cancel_button.clicked.connect(self._request_cancel) + run_buttons.addWidget(self.cancel_button) + run_buttons.addStretch(1) + run_layout.addLayout(run_buttons) + self.queue_status_label = QLabel("Queue idle") + run_layout.addWidget(self.queue_status_label) + self.console = QTextEdit() + 
self.console.setReadOnly(True) + self.console.setMinimumHeight(150) + run_layout.addWidget(self.console) + root.addWidget(run_group) + + self.setCentralWidget(central) + self.statusBar().showMessage("Ready") + + def _reload_presets(self) -> None: + current_name = self._selected_preset_name() + self._presets = load_presets() + self.preset_combo.blockSignals(True) + self.preset_combo.clear() + selected_index = 0 + for index, name in enumerate(ordered_preset_names(self._presets)): + preset = self._presets[name] + label = f"{name} (Built-in)" if preset.builtin else name + self.preset_combo.addItem(label, name) + if name == current_name: + selected_index = index + if self.preset_combo.count() > 0: + self.preset_combo.setCurrentIndex(selected_index) + self.preset_combo.blockSignals(False) + self._refresh_preset_summary() + + def _selected_preset_name(self) -> str | None: + if not hasattr(self, "preset_combo"): + return None + payload = self.preset_combo.currentData() + return None if payload is None else str(payload) + + def _selected_preset(self) -> BondAnalysisPreset | None: + preset_name = self._selected_preset_name() + if preset_name is None: + return None + return self._presets.get(preset_name) + + def _refresh_preset_summary(self) -> None: + preset = self._selected_preset() + if preset is None: + self.preset_summary_label.setText( + "No bondanalysis preset is selected." + ) + return + self.preset_summary_label.setText( + f"Using {preset.name}: {len(preset.bond_pairs)} bond pair(s), " + f"{len(preset.angle_triplets)} angle triplet(s). Advanced " + "representative scoring and solvent shell builder settings use " + "their defaults." 
+ ) + + def _add_current_project(self) -> None: + if ( + self._initial_project_dir is None + and self._initial_clusters_dir is None + ): + QMessageBox.information( + self, + "No active project", + "The main UI did not provide an active project reference.", + ) + return + item = ( + _queue_item_from_project_defaults(self._initial_project_dir) + if self._initial_project_dir is not None + else RepresentativeFinderBatchItem(item_id=_new_item_id()) + ) + if self._initial_clusters_dir is not None: + output_dir = None + if self._initial_project_dir is not None: + try: + output_dir = _suggest_batch_output_dir( + self._initial_project_dir, + self._initial_clusters_dir, + ) + except Exception: + output_dir = item.output_dir + item = replace( + item, + clusters_dir=self._initial_clusters_dir, + output_dir=output_dir or item.output_dir, + ) + self.add_queue_item(item) + + def _choose_projects_to_add(self) -> None: + selected_dirs = _choose_existing_directories( + self, + title="Select SAXSShell project folders", + start_dir=self._initial_project_dir or Path.home(), + ) + if not selected_dirs: + return + for project_dir in selected_dirs: + self.add_queue_item(_queue_item_from_project_defaults(project_dir)) + + def _on_item_settings_changed(self, _item_id: str) -> None: + self._refresh_order_labels() + + def _refresh_order_labels(self, *_args) -> None: + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + item_id = str(list_item.data(Qt.ItemDataRole.UserRole)) + widget = self._widgets_by_id.get(item_id) + if widget is None: + continue + widget.title_label.setText( + f"{row + 1}. 
    def _start_queue(self) -> None:
        """Validate the queue and start it on a background worker thread.

        Shows an information dialog when the queue is empty and a warning
        dialog when the jobs cannot be built; in both cases nothing starts.
        Otherwise the console is cleared, the UI is locked, every row is
        reset to "Queued", and a worker holding the job list is moved to a
        new ``QThread`` and started.
        """
        if self.queue_list.count() == 0:
            QMessageBox.information(
                self,
                "Representative structure batch queue",
                "Add at least one project before running the queue.",
            )
            return
        try:
            # Builds every job up front, so one invalid row blocks the run.
            entries = self.queue_jobs_in_order()
        except Exception as exc:
            QMessageBox.warning(
                self,
                "Invalid representative structure batch settings",
                str(exc),
            )
            return

        self.console.clear()
        self._set_running(True)
        self.queue_status_label.setText(
            f"Running 0/{len(entries)} queued project(s)"
        )
        for widget in self._widgets_by_id.values():
            widget.set_progress(0, 1)
            widget.set_status("Queued")

        self._run_thread = QThread(self)
        self._run_worker = RepresentativeFinderBatchWorker(entries)
        self._run_worker.moveToThread(self._run_thread)
        # The worker's run() slot is triggered when the thread starts.
        self._run_thread.started.connect(self._run_worker.run)
        self._run_worker.item_started.connect(self._on_item_started)
        self._run_worker.item_progress.connect(self._on_item_progress)
        self._run_worker.item_finished.connect(self._on_item_finished)
        self._run_worker.item_failed.connect(self._on_item_failed)
        self._run_worker.log.connect(self._append_log)
        self._run_worker.status.connect(self._on_status)
        # Re-emit per-project completion on the window's own signal.
        self._run_worker.project_results_changed.connect(
            self.project_results_changed.emit
        )
        self._run_worker.finished.connect(self._on_queue_finished)
        self._run_worker.failed.connect(self._on_queue_failed)
        # Either terminal signal asks the thread to quit.
        self._run_worker.finished.connect(self._run_thread.quit)
        self._run_worker.failed.connect(self._run_thread.quit)
        # Once the thread has finished, drop the thread/worker references
        # and schedule the thread object for deletion.
        self._run_thread.finished.connect(self._cleanup_run_thread)
        self._run_thread.finished.connect(self._run_thread.deleteLater)
        self._run_thread.start()
+ ) + if self._run_worker is not None: + self._run_worker.request_cancel() + + def _append_log(self, message: str) -> None: + self.console.append(message) + + def _on_status(self, message: str) -> None: + self.statusBar().showMessage(message) + self.queue_status_label.setText(message) + + def _on_item_started( + self, + item_id: str, + index: int, + total: int, + ) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_status(f"Running {index}/{total}") + widget.set_progress(0, 1) + self.queue_status_label.setText( + f"Running {index}/{total} queued project(s)" + ) + + def _on_item_progress( + self, + item_id: str, + processed: int, + total: int, + message: str, + ) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_progress(processed, total) + widget.set_status(message) + + def _on_item_finished( + self, + item_id: str, + result: RepresentativeFinderBatchResult, + ) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is None: + return + widget.set_progress( + result.completed_count, max(result.completed_count, 1) + ) + widget.set_status( + f"Complete: {result.completed_count} selected" + + ( + "" + if result.failed_count == 0 + else f", {result.failed_count} failed" + ) + ) + + def _on_item_failed(self, item_id: str, message: str) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_status("Failed") + self._append_log(message) + + def _on_queue_finished(self, results: object) -> None: + self._set_running(False) + result_count = len(results) if isinstance(results, list) else 0 + self.queue_status_label.setText( + f"Queue finished: {result_count} project(s) processed" + ) + self.statusBar().showMessage( + "Representative structure batch queue finished" + ) + + def _on_queue_failed(self, item_id: str, message: str) -> None: + self._set_running(False) + self.queue_status_label.setText("Queue stopped after a failure") + 
self.statusBar().showMessage( + "Representative structure batch queue failed", + 5000, + ) + QMessageBox.warning( + self, + "Representative structure batch queue failed", + f"Queue item {item_id} failed:\n{message}", + ) + + def _cleanup_run_thread(self) -> None: + self._run_thread = None + self._run_worker = None + + +def launch_representativefinder_batch_queue_ui( + initial_project_dir: str | Path | None = None, + *, + initial_clusters_dir: str | Path | None = None, + initial_input_path: str | Path | None = None, +) -> int: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication([]) + configure_saxshell_application(app) + window = RepresentativeFinderBatchQueueWindow( + initial_project_dir=initial_project_dir, + initial_clusters_dir=initial_clusters_dir, + initial_input_path=initial_input_path, + ) + window.show() + return int(app.exec()) + + +__all__ = [ + "RepresentativeFinderBatchItem", + "RepresentativeFinderBatchItemWidget", + "RepresentativeFinderBatchJob", + "RepresentativeFinderBatchQueueWindow", + "RepresentativeFinderBatchResult", + "RepresentativeFinderBatchWorker", + "launch_representativefinder_batch_queue_ui", +] diff --git a/src/saxshell/representativefinder/workflow.py b/src/saxshell/representativefinder/workflow.py index a42f423..abb447a 100644 --- a/src/saxshell/representativefinder/workflow.py +++ b/src/saxshell/representativefinder/workflow.py @@ -24,7 +24,7 @@ describe_parsed_contrast_structure, estimate_pair_contact_distance_medians, ) -from saxshell.saxs.debye import load_structure_file +from saxshell.saxs.debye import load_structure_file, scan_structure_elements from saxshell.saxs.stoichiometry import parse_stoich_label _STRUCTURE_SUFFIXES = {".pdb", ".xyz"} @@ -1039,19 +1039,37 @@ def analyze_representative_structure_folder( else Path(project_dir).expanduser().resolve() ) - measured_structures, skipped_files, processed_work = ( - _measure_candidate_entries( - candidates_to_measure, 
- analyzer=analyzer, - include_parsed_structure=settings.solvent_weight > 0.0, - parallel_workers=parallel_workers, - progress_callback=progress_callback, - log_callback=log_callback, - cancel_callback=cancel_callback, - processed_work=processed_work, - total_work=total_work, - ) + single_atom_structures = _inspect_single_atom_candidate_entries( + candidates_to_measure, + progress_callback=progress_callback, + cancel_callback=cancel_callback, + processed_work=processed_work, + total_work=total_work, ) + if single_atom_structures is None: + measured_structures, skipped_files, processed_work = ( + _measure_candidate_entries( + candidates_to_measure, + analyzer=analyzer, + include_parsed_structure=settings.solvent_weight > 0.0, + parallel_workers=parallel_workers, + progress_callback=progress_callback, + log_callback=log_callback, + cancel_callback=cancel_callback, + processed_work=processed_work, + total_work=total_work, + ) + ) + else: + measured_structures = single_atom_structures + skipped_files = [] + processed_work += len(single_atom_structures) + _emit_progress( + progress_callback, + processed_work, + total_work, + "Inspected single-atom candidate structure files.", + ) measured_candidates = [ measured.candidate for measured in measured_structures ] @@ -1608,6 +1626,81 @@ def _effective_parallel_workers( return max(1, min(int(item_count), requested, 32)) +def _inspect_single_atom_candidate_entries( + entries: tuple[RepresentativeFinderFolderCandidate, ...], + *, + progress_callback: RepresentativeFinderProgressCallback | None, + cancel_callback: RepresentativeFinderCancelCallback | None, + processed_work: int, + total_work: int, +) -> list[_MeasuredCandidateStructure] | None: + if not entries: + return None + _emit_progress( + progress_callback, + processed_work, + total_work, + "Inspecting candidate atom counts...", + ) + scanned_elements_by_index: dict[int, tuple[str, ...]] = {} + element_signatures: set[tuple[tuple[str, int], ...]] = set() + for index, 
entry in enumerate(entries): + _raise_if_cancelled(cancel_callback) + try: + elements = tuple( + str(element).strip() + for element in scan_structure_elements(entry.file_path) + if str(element).strip() + ) + except Exception: + return None + if len(elements) != 1: + return None + element_counts = Counter(elements) + element_signatures.add(tuple(sorted(element_counts.items()))) + if len(element_signatures) > 1: + return None + scanned_elements_by_index[index] = elements + + measured_structures: list[_MeasuredCandidateStructure] = [] + for index, entry in enumerate(entries): + _raise_if_cancelled(cancel_callback) + try: + coordinates, loaded_elements = load_structure_file(entry.file_path) + except Exception: + return None + elements = tuple(str(element).strip() for element in loaded_elements) + if len(elements) != 1: + return None + if elements != scanned_elements_by_index[index]: + return None + coordinates_array = np.asarray(coordinates, dtype=float) + element_counts = Counter(elements) + candidate = RepresentativeFinderCandidate( + file_path=entry.file_path, + relative_label=entry.relative_label, + motif_label=entry.motif_label, + atom_count=1, + element_counts=dict(sorted(element_counts.items())), + bond_values={}, + angle_values={}, + solvent_metrics={}, + solvent_atom_count=0, + direct_solvent_atom_count=0, + outer_solvent_atom_count=0, + mean_direct_solvent_coordination=0.0, + ) + measured_structures.append( + _MeasuredCandidateStructure( + candidate=candidate, + coordinates=coordinates_array, + elements=elements, + parsed_structure=None, + ) + ) + return measured_structures + + def _measure_candidate_entries( entries: tuple[RepresentativeFinderFolderCandidate, ...], *, diff --git a/src/saxshell/saxs/_model_templates/__init__.py b/src/saxshell/saxs/_model_templates/__init__.py index 6c1f999..2fc7161 100644 --- a/src/saxshell/saxs/_model_templates/__init__.py +++ b/src/saxshell/saxs/_model_templates/__init__.py @@ -163,7 +163,11 @@ def list_template_specs( ): 
if path.stem in seen_names: continue - specs.append(load_template_spec(path.stem, resolved_dir)) + spec = load_template_spec(path.stem, resolved_dir) + if spec.deprecated and not include_deprecated: + seen_names.add(path.stem) + continue + specs.append(spec) seen_names.add(path.stem) return specs @@ -190,7 +194,10 @@ def load_template_spec( name=template_name, module_path=module_path, metadata_path=metadata_path if metadata_path.is_file() else None, - deprecated=module_path.parent.name == "_deprecated", + deprecated=( + module_path.parent.name == "_deprecated" + or bool(metadata["deprecated"]) + ), display_name=str(metadata["display_name"]), description=str(metadata["description"]), lmfit_model_name=directives["model_lmfit"], @@ -326,6 +333,7 @@ def _load_template_metadata( "the _model_templates folder to provide a friendly display " "name and a detailed description." ), + "deprecated": False, "cluster_geometry_support": TemplateClusterGeometrySupport( supported=False ), @@ -361,6 +369,7 @@ def _load_template_metadata( return { "display_name": display_name, "description": description, + "deprecated": bool(payload.get("deprecated", False)), "cluster_geometry_support": cluster_geometry_support, "solution_scattering_support": solution_scattering_support, "prefit_support": prefit_support, diff --git a/src/saxshell/saxs/_model_templates/template_pydream_charged_monosq_normalized_scaled_solvent.json b/src/saxshell/saxs/_model_templates/template_pydream_charged_monosq_normalized_scaled_solvent.json new file mode 100644 index 0000000..544a8b1 --- /dev/null +++ b/src/saxshell/saxs/_model_templates/template_pydream_charged_monosq_normalized_scaled_solvent.json @@ -0,0 +1,18 @@ +{ + "display_name": "pyDREAM Charged MonoSQ Normalized (Scaled Solvent Weight)", + "description": "pyDREAM Charged MonoSQ Normalized (Scaled Solvent Weight)\n\nPurpose:\nMonodisperse charged hard-sphere SAXS template for cluster libraries whose intercluster correlations are better represented by 
screened Coulomb repulsion than by neutral hard-sphere packing alone. It keeps the scaled-solvent convention from the scaled-solvent MonoSQ templates: the MD-derived solute branch and the weighted solvent trace are combined first, and the global model scale and offset are applied only to the assembled model curve.\n\nStructure Factor:\nThe solute branch uses the Hayter-Penfold rescaled mean spherical approximation (RMSA) structure factor for charged spheres, following the SasView hayter_msa model. The effective charged-sphere radius is eff_r, the particle volume fraction is vol_frac, charge is the sphere charge in elementary-charge units, temperature is in kelvin, concentration_salt is the molar concentration of added 1:1 electrolyte, and dielectconst is the solvent relative dielectric constant. The calculation derives the Debye screening length from the fitted charge, volume fraction, salt concentration, temperature, and dielectric constant.\n\nForm Factor:\nThe form-factor side remains the SAXSShell cluster-trace mixture. Each MD-derived averaged cluster profile I_i(q) is treated as an individual component trace, the raw weights w_i form a linear mixture, and the monodisperse charged-sphere S_RMSA(q) modulates that combined solute trace.\n\nModel Equation:\nI_mix(q) = sum_i w_i I_i(q)\nI_solute(q) = I_mix(q) S_RMSA(q; eff_r, vol_frac, charge, temperature, concentration_salt, dielectconst)\nI_raw(q) = I_solute(q) + solv_w * I_solv(q)\nI_model(q) = scale * I_raw(q) + offset\n\nCalculator Integration:\nThe solution-scattering estimator can pre-populate vol_frac from the physical solute-associated volume fraction and solv_w from the combined solvent-background multiplier. 
Because the solvent contribution lives inside I_raw(q), this template marks the solvent branch as globally scaled.\n\nLikelihood Convention:\nThe pyDREAM likelihood compares the unmodified experimental intensity values directly against I_model(q), with a point-normalized Gaussian log-likelihood and fixed sigma of 1e-4.\n\nNumerical Scope:\nThis is a charged-sphere template. charge is constrained to be positive because the Hayter-MSA calculation is not intended for neutral particles; use a neutral MonoSQ hard-sphere template when charge is effectively zero. The charge upper bound follows the SasView stability guidance of 200 e.\n\nModel Parameters:\nw1, w2, ..., wN: Relative contribution of each MD-derived component profile.\nsolv_w: Solvent-background multiplier applied to the solvent trace before the global scale is applied; constrained to [0, 1] and fixed by default for calculator-seeded workflows.\noffset: Additive model baseline applied after the global model scale.\neff_r: Effective charged-sphere radius in Angstrom used by the RMSA structure factor.\nvol_frac: Charged-sphere particle volume fraction used by the RMSA structure factor; the solution-scattering estimator targets this field with the physical solute-associated volume fraction.\ncharge: Charged-sphere charge in elementary-charge units; constrained to (0, 200].\ntemperature: Absolute temperature in kelvin for Debye screening.\nconcentration_salt: Added 1:1 electrolyte concentration in mol/L.\ndielectconst: Relative dielectric constant of the solvent.\nscale: Multiplicative factor applied to the raw solute + weighted-solvent model curve.", + "capabilities": { + "solution_scattering_estimator": { + "volume_fraction_target": { + "parameter": "vol_frac", + "fraction_kind": "solute", + "source": "physical" + }, + "solvent_contribution_scale_mode": "global_scale" + }, + "prefit": { + "auto_apply_autoscale_on_load": true, + "autoscale_bounds_mode": "adaptive" + } + } +} diff --git 
a/src/saxshell/saxs/_model_templates/template_pydream_charged_monosq_normalized_scaled_solvent.py b/src/saxshell/saxs/_model_templates/template_pydream_charged_monosq_normalized_scaled_solvent.py new file mode 100644 index 0000000..a301bd0 --- /dev/null +++ b/src/saxshell/saxs/_model_templates/template_pydream_charged_monosq_normalized_scaled_solvent.py @@ -0,0 +1,738 @@ +import math + +import numpy as np +from scipy.stats import norm + +# ============================================================= +# model_lmfit: lmfit_model_profile +# model_pydream: log_likelihood_charged_monosq_scaled_solvent +# inputs_lmfit: q, solvent_data, model_data, params +# inputs_pydream: q, solvent_data, model_data, params +# param_columns: Structure, Motif, Param, Value, Vary, Min, Max +# +# param: solv_w,1.0,False,0.0,1.0 +# param: offset,0,True,-20,30 +# param: eff_r,20.75,True,1.0,200.0 +# param: vol_frac,0.0192,False,1e-6,0.5 +# param: charge,19.0,True,1e-6,200.0 +# param: temperature,298.0,False,1.0,450.0 +# param: concentration_salt,0.0,False,0.0,5.0 +# param: dielectconst,78.0,False,1.0,200.0 +# param: scale,5e-4,True,1e-8,5e-3 +# +# Charged MonoSQ normalized, scaled-solvent variant: +# I_raw(q) = sum_i w_i I_i(q) S_RMSA(q) + solv_w * I_solvent(q) +# I_model(q) = scale * I_raw(q) + offset +# +# The Hayter-Penfold RMSA S(q) implementation below follows the SasView +# sasmodels hayter_msa structure-factor kernel: +# https://www.sasview.org/docs/user/models/hayter_msa.html +# ============================================================= + +_ELEMENTARY_CHARGE_C = 1.602189e-19 +_BOLTZMANN_J_PER_K = 1.380662e-23 +_VACUUM_PERMITTIVITY = 8.85418782e-12 +_AVOGADRO = 6.022e23 + + +def _validate_hayter_inputs( + radius_effective, + volfraction, + charge, + temperature, + concentration_salt, + dielectconst, +): + radius_effective = float(radius_effective) + volfraction = float(volfraction) + charge = float(charge) + temperature = float(temperature) + concentration_salt = 
float(concentration_salt) + dielectconst = float(dielectconst) + + if radius_effective <= 0.0: + raise ValueError("eff_r must be positive") + if not (0.0 < volfraction < 0.74): + raise ValueError("vol_frac must satisfy 0 < vol_frac < 0.74") + if not (0.0 < charge <= 200.0): + raise ValueError("charge must satisfy 0 < charge <= 200") + if temperature <= 0.0: + raise ValueError("temperature must be positive") + if concentration_salt < 0.0: + raise ValueError("concentration_salt must be non-negative") + if dielectconst <= 0.0: + raise ValueError("dielectconst must be positive") + + return ( + radius_effective, + volfraction, + charge, + temperature, + concentration_salt, + dielectconst, + ) + + +def calc_hayter_msa_sq( + radius_effective, + volfraction, + charge, + temperature, + concentration_salt, + dielectconst, + q_values, +): + """Return the Hayter-Penfold RMSA charged-sphere structure + factor.""" + ( + radius_effective, + volfraction, + charge, + temperature, + concentration_salt, + dielectconst, + ) = _validate_hayter_inputs( + radius_effective, + volfraction, + charge, + temperature, + concentration_salt, + dielectconst, + ) + q_values = np.asarray(q_values, dtype=float) + return np.asarray( + [ + _hayter_msa_iq( + float(q_value), + radius_effective, + volfraction, + charge, + temperature, + concentration_salt, + dielectconst, + ) + for q_value in q_values + ], + dtype=float, + ) + + +def _hayter_msa_iq( + q_value, + radius_effective, + volfraction, + charge_number, + temperature, + concentration_salt, + dielectconst, +): + g = [float(value) for value in range(1, 18)] + + diameter_angstrom = 2.0 * radius_effective + beta = 1.0 / (_BOLTZMANN_J_PER_K * temperature) + permittivity = dielectconst * _VACUUM_PERMITTIVITY + charge_coulomb = charge_number * _ELEMENTARY_CHARGE_C + diameter_m = diameter_angstrom * 1.0e-10 + particle_volume = (4.0 * math.pi / 3.0) * (diameter_m / 2.0) ** 3 + salt_number_density = concentration_salt * _AVOGADRO * 1.0e3 + + ionic_strength 
= ( + 0.5 + * _ELEMENTARY_CHARGE_C + * _ELEMENTARY_CHARGE_C + * (charge_number * volfraction / particle_volume) + ) + ionic_strength += ( + 0.5 + * _ELEMENTARY_CHARGE_C + * _ELEMENTARY_CHARGE_C + * (2.0 * salt_number_density) + ) + kappa = math.sqrt(2.0 * beta * ionic_strength / permittivity) + + kappa_diameter = kappa * diameter_m + g[5] = ( + beta + * charge_coulomb + * charge_coulomb + / (math.pi * permittivity * diameter_m * (2.0 + kappa_diameter) ** 2) + ) + g[6] = kappa_diameter + g[4] = volfraction + + ss = g[4] ** (1.0 / 3.0) + g[9] = 2.0 * ss * g[5] * math.exp(g[6] - g[6] / ss) + + ierr = _sqcoef(0, g) + if ierr < 0: + return math.nan + return _sqhcal(q_value * diameter_angstrom, g) + + +def _sqcoef(ir, g): + max_iterations = 40 + accuracy = 5.0e-6 + f1 = 0.0 + f2 = 0.0 + + ig = 1 + if g[6] >= (1.0 + 8.0 * g[4]): + ig = 0 + g[15] = g[14] + g[16] = g[4] + ir = _sqfun(1, ir, g) + g[14] = g[15] + g[4] = g[16] + if ir < 0 or g[14] >= 0.0: + return ir + + g[10] = min(g[4], 0.20) + if ig != 1 or g[9] >= 0.15: + ii = 0 + while True: + ii += 1 + if ii > max_iterations: + return -1 + if g[10] <= 0.0: + g[10] = g[4] / ii + if g[10] > 0.6: + g[10] = 0.35 / ii + + e1 = g[10] + g[15] = f1 + g[16] = e1 + ir = _sqfun(2, ir, g) + if ir < 0: + return ir + f1 = g[15] + e1 = g[16] + + e2 = g[10] * 1.01 + g[15] = f2 + g[16] = e2 + ir = _sqfun(2, ir, g) + if ir < 0: + return ir + f2 = g[15] + e2 = g[16] + + denominator = f2 - f1 + if denominator == 0.0: + return -1 + e2 = e1 - (e2 - e1) * f1 / denominator + g[10] = e2 + delta = abs((e2 - e1) / e1) if e1 != 0.0 else abs(e2 - e1) + if delta <= accuracy: + break + + g[15] = g[14] + g[16] = e2 + ir = _sqfun(4, ir, g) + if ir < 0: + return ir + g[14] = g[15] + ir = ii + if ig != 1 or g[10] >= g[4]: + return ir + + g[15] = g[14] + g[16] = g[4] + ir = _sqfun(3, ir, g) + g[14] = g[15] + g[4] = g[16] + if ir >= 0 and g[14] < 0.0: + ir = -3 + return ir + + +def _sqfun(ix, ir, g): + accuracy = 1.0e-6 + max_iterations = 40 + + a2 = a3 = 
b2 = b3 = v2 = v3 = p2 = p3 = 0.0 + + reta = g[16] + eta2 = reta * reta + eta3 = eta2 * reta + e12 = 12.0 * reta + e24 = e12 + e12 + g[13] = (g[4] / g[16]) ** (1.0 / 3.0) + g[12] = g[6] / g[13] + ibig = 1 if (g[12] > 15.0 and ix == 1) else 0 + + g[11] = g[5] * g[13] * math.exp(g[6] - g[12]) + rgek = g[11] + rak = g[12] + ak2 = rak * rak + ak1 = 1.0 + rak + dak2 = 1.0 / ak2 + dak4 = dak2 * dak2 + d = 1.0 - reta + d2 = d * d + dak = d / rak + dd2 = 1.0 / d2 + dd4 = dd2 * dd2 + dd45 = dd4 * 2.0e-1 + eta3d = 3.0 * reta + eta6d = eta3d + eta3d + eta32 = eta3 + eta3 + eta2d = reta + 2.0 + eta2d2 = eta2d * eta2d + eta21 = 2.0 * reta + 1.0 + eta22 = eta21 * eta21 + + al1 = -eta21 * dak + al2 = (14.0 * eta2 - 4.0 * reta - 1.0) * dak2 + al3 = 36.0 * eta2 * dak4 + + be1 = -(eta2 + 7.0 * reta + 1.0) * dak + be2 = 9.0 * reta * (eta2 + 4.0 * reta - 2.0) * dak2 + be3 = 12.0 * reta * (2.0 * eta2 + 8.0 * reta - 1.0) * dak4 + + vu1 = -(eta3 + 3.0 * eta2 + 45.0 * reta + 5.0) * dak + vu2 = (eta32 + 3.0 * eta2 + 42.0 * reta - 20.0) * dak2 + vu3 = (eta32 + 30.0 * reta - 5.0) * dak4 + vu4 = vu1 + e24 * rak * vu3 + vu5 = eta6d * (vu2 + 4.0 * vu3) + + ph1 = eta6d / rak + ph2 = d - e12 * dak2 + + ta1 = (reta + 5.0) / (5.0 * rak) + ta2 = eta2d * dak2 + ta3 = -e12 * rgek * (ta1 + ta2) + ta4 = eta3d * ak2 * (ta1 * ta1 - ta2 * ta2) + ta5 = eta3d * (reta + 8.0) * 1.0e-1 - 2.0 * eta22 * dak2 + + ex1 = math.exp(rak) + ex2 = math.exp(-rak) if g[12] < 20.0 else 0.0 + sk = 0.5 * (ex1 - ex2) + ck = 0.5 * (ex1 + ex2) + ckma = ck - 1.0 - rak * sk + skma = sk - rak * ck + + a1 = (e24 * rgek * (al1 + al2 + ak1 * al3) - eta22) * dd4 + if ibig == 0: + a2 = e24 * (al3 * skma + al2 * sk - al1 * ck) * dd4 + a3 = ( + e24 + * (eta22 * dak2 - 0.5 * d2 + al3 * ckma - al1 * sk + al2 * ck) + * dd4 + ) + + b1 = (1.5 * reta * eta2d2 - e12 * rgek * (be1 + be2 + ak1 * be3)) * dd4 + if ibig == 0: + b2 = e12 * (-be3 * skma - be2 * sk + be1 * ck) * dd4 + b3 = ( + e12 + * ( + 0.5 * d2 * eta2d + - eta3d * eta2d2 * dak2 + - 
be3 * ckma + + be1 * sk + - be2 * ck + ) + * dd4 + ) + + v1 = ( + eta21 * (eta2 - 2.0 * reta + 10.0) * 2.5e-1 - rgek * (vu4 + vu5) + ) * dd45 + if ibig == 0: + v2 = (vu4 * ck - vu5 * sk) * dd45 + v3 = ( + (eta3 - 6.0 * eta2 + 5.0) * d + - eta6d * (2.0 * eta3 - 3.0 * eta2 + 18.0 * reta + 10.0) * dak2 + + e24 * vu3 + + vu4 * sk + - vu5 * ck + ) * dd45 + + pp1 = ph1 * ph1 + pp2 = ph2 * ph2 + pp = pp1 + pp2 + p1p2 = ph1 * ph2 * 2.0 + p1 = (rgek * (pp1 + pp2 - p1p2) - 0.5 * eta2d) * dd2 + if ibig == 0: + p2 = (pp * sk + p1p2 * ck) * dd2 + p3 = (pp * ck + p1p2 * sk + pp1 - pp2) * dd2 + + t1 = ta3 + ta4 * a1 + ta5 * b1 + if ibig != 0: + v3 = ( + (eta3 - 6.0 * eta2 + 5.0) * d + - eta6d * (2.0 * eta3 - 3.0 * eta2 + 18.0 * reta + 10.0) * dak2 + + e24 * vu3 + ) * dd45 + t3 = ta4 * a3 + ta5 * b3 + e12 * ta2 + t3 += -4.0e-1 * reta * (reta + 10.0) - 1.0 + p3 = (pp1 - pp2) * dd2 + b3 = e12 * (0.5 * d2 * eta2d - eta3d * eta2d2 * dak2 + be3) * dd4 + a3 = e24 * (eta22 * dak2 - 0.5 * d2 - al3) * dd4 + um6 = t3 * a3 - e12 * v3 * v3 + um5 = t1 * a3 + a1 * t3 - e24 * v1 * v3 + um4 = t1 * a1 - e12 * v1 * v1 + lam6 = e12 * p3 * p3 + lam5 = e24 * p1 * p3 - b3 - b3 - ak2 + lam4 = e12 * p1 * p1 - b1 - b1 + w56 = um5 * lam6 - lam5 * um6 + w46 = um4 * lam6 - lam4 * um6 + fa = -w46 / w56 + ca = -fa + g[3] = fa + g[2] = ca + g[1] = b1 + b3 * fa + g[0] = a1 + a3 * fa + g[8] = v1 + v3 * fa + g[14] = -(p1 + p3 * fa) + g[15] = 0.0 if abs(g[14]) < 1.0e-3 else g[14] + g[10] = g[16] + else: + t2 = ta4 * a2 + ta5 * b2 + e12 * (ta1 * ck - ta2 * sk) + t3 = ta4 * a3 + ta5 * b3 + t3 += e12 * (ta1 * sk - ta2 * (ck - 1.0)) + t3 += -4.0e-1 * reta * (reta + 10.0) - 1.0 + + um1 = t2 * a2 - e12 * v2 * v2 + um2 = t1 * a2 + t2 * a1 - e24 * v1 * v2 + um3 = t2 * a3 + t3 * a2 - e24 * v2 * v3 + um4 = t1 * a1 - e12 * v1 * v1 + um5 = t1 * a3 + t3 * a1 - e24 * v1 * v3 + um6 = t3 * a3 - e12 * v3 * v3 + + if ix in {1, 3}: + lam1 = e12 * p2 * p2 + lam2 = e24 * p1 * p2 - b2 - b2 + lam3 = e24 * p2 * p3 + lam4 = e12 * p1 * p1 - 
b1 - b1 + lam5 = e24 * p1 * p3 - b3 - b3 - ak2 + lam6 = e12 * p3 * p3 + + w16 = um1 * lam6 - lam1 * um6 + w15 = um1 * lam5 - lam1 * um5 + w14 = um1 * lam4 - lam1 * um4 + w13 = um1 * lam3 - lam1 * um3 + w12 = um1 * lam2 - lam1 * um2 + w26 = um2 * lam6 - lam2 * um6 + w25 = um2 * lam5 - lam2 * um5 + w24 = um2 * lam4 - lam2 * um4 + w36 = um3 * lam6 - lam3 * um6 + w35 = um3 * lam5 - lam3 * um5 + w34 = um3 * lam4 - lam3 * um4 + w32 = um3 * lam2 - lam3 * um2 + w46 = um4 * lam6 - lam4 * um6 + w56 = um5 * lam6 - lam5 * um6 + w3526 = w35 + w26 + w3425 = w34 + w25 + + w4 = w16 * w16 - w13 * w36 + w3 = 2.0 * w16 * w15 - w13 * w3526 - w12 * w36 + w2 = w15 * w15 + 2.0 * w16 * w14 - w13 * w3425 - w12 * w3526 + w1 = 2.0 * w15 * w14 - w13 * w24 - w12 * w3425 + w0 = w14 * w14 - w12 * w24 + + if ix == 1: + fap = (w14 - w34 - w46) / (w12 - w15 + w35 - w26 + w56 - w32) + else: + g[14] = 0.5 * eta2d * dd2 * math.exp(-rgek) + if 0.0 <= g[11] <= 2.0 and g[12] <= 1.0: + e24g = e24 * rgek * math.exp(rak) + pwk = math.sqrt(e24g) + qpw = ( + (1.0 - math.sqrt(1.0 + 2.0 * d2 * d * pwk / eta22)) + * eta21 + / d + ) + g[14] = -qpw * qpw / e24 + 0.5 * eta2d * dd2 + pg = p1 + g[14] + ca = ak2 * pg + 2.0 * (b3 * pg - b1 * p3) + ca += e12 * g[14] * g[14] * p3 + ca = -ca / (ak2 * p2 + 2.0 * (b3 * p2 - b2 * p3)) + fap = -(pg + p2 * ca) / p3 + + ii = 0 + while True: + ii += 1 + if ii > max_iterations: + return -2 + fa = fap + fun = w0 + (w1 + (w2 + (w3 + w4 * fa) * fa) * fa) * fa + fund = w1 + (2.0 * w2 + (3.0 * w3 + 4.0 * w4 * fa) * fa) * fa + fap = fa - fun / fund + delta = abs((fap - fa) / fa) if fa != 0.0 else abs(fap - fa) + if delta <= accuracy: + break + + ir += ii + fa = fap + ca = -(w16 * fa * fa + w15 * fa + w14) / (w13 * fa + w12) + g[14] = -(p1 + p2 * ca + p3 * fa) + g[15] = 0.0 if abs(g[14]) < 1.0e-3 else g[14] + g[10] = g[16] + else: + ca = ak2 * p1 + 2.0 * (b3 * p1 - b1 * p3) + ca = -ca / (ak2 * p2 + 2.0 * (b3 * p2 - b2 * p3)) + fa = -(p1 + p2 * ca) / p3 + if ix == 2: + g[15] = ( + um1 * 
ca * ca + + (um2 + um3 * fa) * ca + + um4 + + um5 * fa + + um6 * fa * fa + ) + if ix == 4: + g[15] = -(p1 + p2 * ca + p3 * fa) + + g[3] = fa + g[2] = ca + g[1] = b1 + b2 * ca + b3 * fa + g[0] = a1 + a2 * ca + a3 * fa + g[8] = (v1 + v2 * ca + v3 * fa) / g[0] + + g24 = e24 * rgek * ex1 + g[7] = (rak * ak2 * g[2] - g24) / (ak2 * g24) + return ir + + +def _sqhcal(qq, g): + etaz = g[10] + akz = g[12] + gekz = g[11] + e24 = 24.0 * etaz + x1 = math.exp(akz) + x2 = math.exp(-akz) if g[12] < 20.0 else 0.0 + ck = 0.5 * (x1 + x2) + sk = 0.5 * (x1 - x2) + ak2 = akz * akz + + qk = qq / g[13] + q2k = qk * qk + if qk <= 1.0e-8: + return -1.0 / g[0] + if qk <= 0.01: + aqk = g[0] * (8.0 + 2.0 * etaz) + 6.0 * g[1] - 12.0 * g[3] + aqk -= ( + 24.0 + * ( + gekz * (1.0 + akz) + - ck * akz * g[2] + + g[3] * (ck - 1.0) + + (g[2] - g[3] * akz) * sk + ) + / ak2 + ) + aqk += q2k * ( + -((g[0] * (48.0 + 15.0 * etaz) + 40.0 * g[1]) / 60.0) + + g[3] + + (4.0 / ak2) + * ( + gekz * (9.0 + 7.0 * akz) + + ck * (9.0 * g[3] - 7.0 * g[2] * akz) + + sk * (9.0 * g[2] - 7.0 * g[3] * akz) + ) + ) + return 1.0 / (1.0 - g[10] * aqk) + + qk2 = 1.0 / q2k + qk3 = qk2 / qk + qqk = 1.0 / (qk * (q2k + ak2)) + sink = math.sin(qk) + cosk = math.cos(qk) + asink = akz * sink + qcosk = qk * cosk + + aqk = g[0] * (sink - qcosk) + aqk += g[1] * ((2.0 * qk2 - 1.0) * qcosk + 2.0 * sink - 2.0 / qk) + inter = 24.0 * qk3 + 4.0 * (1.0 - 6.0 * qk2) * sink + aqk += ( + 0.5 + * etaz + * g[0] + * (inter - (1.0 - 12.0 * qk2 + 24.0 * qk2 * qk2) * qcosk) + ) + aqk *= qk3 + aqk += g[2] * (ck * asink - sk * qcosk) * qqk + aqk += g[3] * (sk * asink - qk * (ck * cosk - 1.0)) * qqk + aqk += g[3] * (cosk - 1.0) * qk2 + aqk -= gekz * (asink + qcosk) * qqk + return 1.0 / (1.0 - e24 * aqk) + + +def _bounded_solvent_weight(value): + return float(np.clip(float(value), 0.0, 1.0)) + + +def _weight_keys_from_params(params): + return sorted( + (key for key in params if key.startswith("w") and key[1:].isdigit()), + key=lambda key: int(key[1:]), + ) 
+ + +def structure_factor_profile(q, solvent_data, model_data, **params): + """Return the pure charged hard-sphere RMSA structure-factor + trace.""" + del solvent_data, model_data + return calc_hayter_msa_sq( + params["eff_r"], + params["vol_frac"], + params["charge"], + params["temperature"], + params["concentration_salt"], + params["dielectconst"], + np.asarray(q, dtype=float), + ) + + +def raw_charged_monosq_scaled_solvent_profile( + q_values, + solvent_intensities, + component_intensities, + weights, + solv_w, + eff_r, + vol_frac, + charge, + temperature, + concentration_salt, + dielectconst, +): + """Return the unscaled charged-S(Q) solute plus weighted solvent.""" + q_values = np.asarray(q_values, dtype=float) + mixture = np.zeros_like(q_values, dtype=float) + for weight, component in zip(weights, component_intensities): + mixture += float(weight) * np.asarray(component, dtype=float) + + structure_factor = calc_hayter_msa_sq( + eff_r, + vol_frac, + charge, + temperature, + concentration_salt, + dielectconst, + q_values, + ) + solvent_contribution = _bounded_solvent_weight(solv_w) * np.asarray( + solvent_intensities, + dtype=float, + ) + return mixture * structure_factor + solvent_contribution + + +def charged_monosq_scaled_solvent_profile( + q_values, + solvent_intensities, + component_intensities, + weights, + solv_w, + eff_r, + vol_frac, + charge, + temperature, + concentration_salt, + dielectconst, + scale, + offset, +): + """Apply the global scale and offset to the charged MonoSQ model.""" + raw_model = raw_charged_monosq_scaled_solvent_profile( + q_values, + solvent_intensities, + component_intensities, + weights, + solv_w, + eff_r, + vol_frac, + charge, + temperature, + concentration_salt, + dielectconst, + ) + return float(scale) * raw_model + float(offset) + + +def lmfit_model_profile(q, solvent_data, model_data, **params): + """Evaluate the charged scaled-solvent MonoSQ SAXS model for + lmfit.""" + weight_keys = _weight_keys_from_params(params) + 
weights = [params[key] for key in weight_keys] + + return charged_monosq_scaled_solvent_profile( + q, + solvent_data, + model_data, + weights, + params["solv_w"], + params["eff_r"], + params["vol_frac"], + params["charge"], + params["temperature"], + params["concentration_salt"], + params["dielectconst"], + params["scale"], + params["offset"], + ) + + +def model_charged_monosq_scaled_solvent(params): + """Return the forward model intensity for pyDREAM.""" + global q_values + global theoretical_intensities + global solvent_intensities + + n_profiles = len(theoretical_intensities) + + weights = params[:n_profiles] + solv_w = params[n_profiles] + offset = params[n_profiles + 1] + eff_r = params[n_profiles + 2] + vol_frac = params[n_profiles + 3] + charge = params[n_profiles + 4] + temperature = params[n_profiles + 5] + concentration_salt = params[n_profiles + 6] + dielectconst = params[n_profiles + 7] + scale = params[n_profiles + 8] + + return charged_monosq_scaled_solvent_profile( + q_values, + solvent_intensities, + theoretical_intensities, + weights, + solv_w, + eff_r, + vol_frac, + charge, + temperature, + concentration_salt, + dielectconst, + scale, + offset, + ) + + +def log_likelihood_charged_monosq_scaled_solvent(params): + """Return the normalized Gaussian log-likelihood for pyDREAM.""" + global experimental_intensities + + try: + model_intensity = model_charged_monosq_scaled_solvent(params) + except (OverflowError, ValueError, FloatingPointError): + return -np.inf + if not np.all(np.isfinite(model_intensity)): + return -np.inf + + experimental = np.asarray(experimental_intensities, dtype=float) + n_points = len(experimental) + log_likelihood = np.sum( + norm.logpdf( + experimental, + loc=model_intensity, + scale=1e-4, + ) + ) + + if n_points == 0: + return log_likelihood + + return log_likelihood / n_points diff --git a/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent_model_scale.json 
b/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent_model_scale.json new file mode 100644 index 0000000..e4f659e --- /dev/null +++ b/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent_model_scale.json @@ -0,0 +1,18 @@ +{ + "display_name": "pyDREAM MonoSQ Normalized (Scaled Solvent Weight, Model Scale/Offset)", + "description": "pyDREAM MonoSQ Normalized (Scaled Solvent Weight, Model Scale/Offset)\n\nPurpose:\nCorrected scaled-solvent MonoSQ variant for workflows that need the global scale and additive offset to be applied explicitly to the model curve rather than to the experimental SAXS trace. Existing MonoSQ templates remain unchanged for project compatibility.\n\nStructure Factor:\nMonodisperse hard-sphere structure factor evaluated with calc_monodisperse_sq using the Percus-Yevick approximation. eff_r controls the effective hard-sphere radius and vol_frac controls the packing term.\n\nForm Factor:\nWeighted mixture of MD-derived SAXS component profiles assembled from the project's averaged cluster scattering curves. The weighted mixture is multiplied by the hard-sphere structure factor and combined with solv_w times the solvent trace before the global model transform is applied.\n\nModel Equation:\nI_raw(q) = sum_i w_i I_i(q) S_HS(q; eff_r, vol_frac) + solv_w * I_solv(q)\nI_model(q) = scale * I_raw(q) + offset\n\nLikelihood Convention:\nThe pyDREAM likelihood compares the unmodified experimental intensity values directly against I_model(q), with a point-normalized Gaussian log-likelihood. scale and offset are applied only inside the forward model.\n\nCalculator Integration:\nThe solution-scattering estimator can pre-populate vol_frac from the physical solute-associated volume fraction and solv_w from the solvent-background multiplier. 
Since the solvent contribution is inside I_raw(q), the metadata marks the solvent branch as globally scaled.\n\nPrefit Startup:\nWhen this template is loaded with experimental data available, Prefit can apply its autoscale estimate to the scale and offset parameters unless a saved Best Prefit or current Prefit state already exists. Bounds are centered around the autoscale result.\n\nModel Parameters:\nw1, w2, ..., wN: Relative contribution of each MD-derived component profile.\nsolv_w: Solvent-background multiplier applied to the solvent trace before the model scale is applied; constrained to [0, 1].\noffset: Additive model baseline applied after the model scale.\neff_r: Effective hard-sphere radius used in the structure-factor calculation.\nvol_frac: Hard-sphere packing fraction used in the Percus-Yevick structure factor.\nscale: Multiplicative factor applied to the raw model curve.", + "capabilities": { + "solution_scattering_estimator": { + "volume_fraction_target": { + "parameter": "vol_frac", + "fraction_kind": "solute", + "source": "physical" + }, + "solvent_contribution_scale_mode": "global_scale" + }, + "prefit": { + "auto_apply_autoscale_on_load": true, + "autoscale_bounds_mode": "adaptive" + } + } +} diff --git a/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent_model_scale.py b/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent_model_scale.py new file mode 100644 index 0000000..1148ca9 --- /dev/null +++ b/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent_model_scale.py @@ -0,0 +1,221 @@ +import numpy as np +from scipy.stats import norm + +# ============================================== +# model_lmfit: lmfit_model_profile +# model_pydream: log_likelihood_monosq_scaled_solvent_model_scale +# inputs_lmfit: q, solvent_data, model_data, params +# inputs_pydream: q, solvent_data, model_data, params +# param_columns: Structure, Motif, Param, Value, Vary, 
Min, Max +# +# param: solv_w,1.0,False,0.0,1.0 +# param: offset,0,True,-20,30 +# param: eff_r,3.0,True,3,20 +# param: vol_frac,0.0,False,0.0,0.5 +# param: scale,5e-4,True,1e-8,5e-3 +# +# MonoSQ normalized, scaled-solvent, model-scaled variant: +# I_raw(q) = I_solute(q) + solv_w * I_solvent(q) +# I_model(q) = scale * I_raw(q) + offset +# +# The experimental trace is compared directly against I_model(q); scale and +# offset are never applied to experimental_intensities in this template. +# Existing templates are intentionally left unchanged for compatibility. +# ============================================== + + +def calc_monodisperse_sq(r, vol_frac, q_values): + """Return the hard-sphere Percus-Yevick structure factor.""" + q_values = np.asarray(q_values, dtype=float) + r = float(r) + vol_frac = float(vol_frac) + a = 2.0 * q_values * r + a_safe = np.where(np.abs(a) < 1e-12, 1e-12, a) + + alpha = (1.0 + 2.0 * vol_frac) ** 2 / (1.0 - vol_frac) ** 4 + beta = ( + -6.0 * vol_frac * (1.0 + vol_frac / 2.0) ** 2 / (1.0 - vol_frac) ** 4 + ) + gamma = ( + 0.5 * vol_frac * (1.0 + 2.0 * vol_frac) ** 2 / (1.0 - vol_frac) ** 4 + ) + + g1 = alpha / a_safe**2 * (np.sin(a_safe) - a_safe * np.cos(a_safe)) + g2 = ( + beta + / a_safe**3 + * ( + 2.0 * a_safe * np.sin(a_safe) + + (2.0 - a_safe**2) * np.cos(a_safe) + - 2.0 + ) + ) + g3 = ( + gamma + / a_safe**5 + * ( + -(a_safe**4) * np.cos(a_safe) + + 4.0 + * ( + (3.0 * a_safe**2 - 6.0) * np.cos(a_safe) + + (a_safe**3 - 6.0 * a_safe) * np.sin(a_safe) + + 6.0 + ) + ) + ) + g = g1 + g2 + g3 + sq = 1.0 / (1.0 + 24.0 * vol_frac * (g / a_safe)) + + if np.any(np.abs(a) < 1e-12): + sq = np.asarray(sq, dtype=float) + sq[np.abs(a) < 1e-12] = (1.0 - vol_frac) ** 4 / ( + 1.0 + 2.0 * vol_frac + ) ** 2 + + return sq + + +def _bounded_solvent_weight(value): + return float(np.clip(float(value), 0.0, 1.0)) + + +def _weight_keys_from_params(params): + return sorted( + (key for key in params if key.startswith("w") and key[1:].isdigit()), + key=lambda key: 
int(key[1:]), + ) + + +def structure_factor_profile(q, solvent_data, model_data, **params): + """Return the pure hard-sphere structure-factor trace S(q).""" + del solvent_data, model_data + return calc_monodisperse_sq( + params["eff_r"], + params["vol_frac"], + np.asarray(q, dtype=float), + ) + + +def raw_monosq_scaled_solvent_profile( + q_values, + solvent_intensities, + component_intensities, + weights, + solv_w, + eff_r, + vol_frac, +): + """Return the unscaled solute-plus-weighted-solvent model branch.""" + q_values = np.asarray(q_values, dtype=float) + mixture = np.zeros_like(q_values, dtype=float) + for weight, component in zip(weights, component_intensities): + mixture += float(weight) * np.asarray(component, dtype=float) + + solute_intensity = mixture * calc_monodisperse_sq( + eff_r, + vol_frac, + q_values, + ) + solvent_contribution = _bounded_solvent_weight(solv_w) * np.asarray( + solvent_intensities, + dtype=float, + ) + return solute_intensity + solvent_contribution + + +def scaled_monosq_model_profile( + q_values, + solvent_intensities, + component_intensities, + weights, + solv_w, + eff_r, + vol_frac, + scale, + offset, +): + """Apply the fit transform to the model curve, not the data + curve.""" + raw_model = raw_monosq_scaled_solvent_profile( + q_values, + solvent_intensities, + component_intensities, + weights, + solv_w, + eff_r, + vol_frac, + ) + return float(scale) * raw_model + float(offset) + + +def lmfit_model_profile(q, solvent_data, model_data, **params): + """Evaluate the model-scaled MonoSQ SAXS model for lmfit.""" + weight_keys = _weight_keys_from_params(params) + weights = [params[key] for key in weight_keys] + + return scaled_monosq_model_profile( + q, + solvent_data, + model_data, + weights, + params["solv_w"], + params["eff_r"], + params["vol_frac"], + params["scale"], + params["offset"], + ) + + +def model_monosq_scaled_solvent_model_scale(params): + """Return the forward model intensity for pyDREAM.""" + global q_values + global 
theoretical_intensities + global solvent_intensities + + n_profiles = len(theoretical_intensities) + + weights = params[:n_profiles] + solv_w = params[n_profiles] + offset = params[n_profiles + 1] + eff_r = params[n_profiles + 2] + vol_frac = params[n_profiles + 3] + scale = params[n_profiles + 4] + + return scaled_monosq_model_profile( + q_values, + solvent_intensities, + theoretical_intensities, + weights, + solv_w, + eff_r, + vol_frac, + scale, + offset, + ) + + +def log_likelihood_monosq_scaled_solvent_model_scale(params): + """Return the normalized Gaussian log-likelihood for pyDREAM.""" + global experimental_intensities + + try: + model_intensity = model_monosq_scaled_solvent_model_scale(params) + except (ValueError, FloatingPointError): + return -np.inf + if not np.all(np.isfinite(model_intensity)): + return -np.inf + + experimental = np.asarray(experimental_intensities, dtype=float) + n_points = len(experimental) + log_likelihood = np.sum( + norm.logpdf( + experimental, + loc=model_intensity, + scale=1e-4, + ) + ) + + if n_points == 0: + return log_likelihood + + return log_likelihood / n_points diff --git a/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs_legacy.json b/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs_legacy.json index 2518ac9..b75597f 100644 --- a/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs_legacy.json +++ b/src/saxshell/saxs/_model_templates/template_pydream_poly_lma_hs_legacy.json @@ -1,5 +1,6 @@ { "display_name": "pyDREAM Poly LMA Hard-Sphere (deprecated)", + "deprecated": true, "description": "pyDREAM Poly LMA Hard-Sphere (deprecated)\n\nThis legacy template preserves the earlier mixed sphere/ellipsoid-equivalent-sphere behavior for backwards compatibility with older projects.\n\nFor new work, prefer one of the newer split templates:\n- pyDREAM Poly LMA Hard-Sphere for the strict sphere-only hard-sphere workflow\n- pyDREAM Poly LMA Hard-Sphere/Ellipsoid Mix (Approx.) 
for the mixed-shape approximate workflow\n\nLike the mixed approximate template, ellipsoid rows are reduced to an equivalent-sphere interaction radius before the hard-sphere Percus-Yevick structure factor is evaluated.\n", "capabilities": { "cluster_geometry_metadata": { diff --git a/src/saxshell/saxs/ui/__init__.py b/src/saxshell/saxs/ui/__init__.py index 81ae6b1..6784c8a 100644 --- a/src/saxshell/saxs/ui/__init__.py +++ b/src/saxshell/saxs/ui/__init__.py @@ -3,9 +3,11 @@ __all__ = [ "DistributionSetupWindow", "ExperimentalDataHeaderDialog", + "ExperimentalDataOverlayWindow", "PriorHistogramWindow", "SAXSProgressDialog", "SAXSMainWindow", + "launch_experimental_data_overlay_ui", "launch_saxs_ui", ] @@ -19,6 +21,21 @@ def __getattr__(name: str): from .experimental_data_loader import ExperimentalDataHeaderDialog return ExperimentalDataHeaderDialog + if name in { + "ExperimentalDataOverlayWindow", + "launch_experimental_data_overlay_ui", + }: + from .experimental_overlay_window import ( + ExperimentalDataOverlayWindow, + launch_experimental_data_overlay_ui, + ) + + return { + "ExperimentalDataOverlayWindow": ExperimentalDataOverlayWindow, + "launch_experimental_data_overlay_ui": ( + launch_experimental_data_overlay_ui + ), + }[name] if name == "PriorHistogramWindow": from .prior_histogram_window import PriorHistogramWindow diff --git a/src/saxshell/saxs/ui/experimental_data_loader.py b/src/saxshell/saxs/ui/experimental_data_loader.py index 7477d60..38f0955 100644 --- a/src/saxshell/saxs/ui/experimental_data_loader.py +++ b/src/saxshell/saxs/ui/experimental_data_loader.py @@ -34,12 +34,22 @@ def __init__( file_path: str | Path, parent: QWidget | None = None, *, + title: str = "Check Experimental Data File", + independent_column_label: str = "q column", + dependent_column_label: str = "Intensity column", + error_column_label: str = "Error column", + intro_text: str | None = None, initial_header_rows: int | None = None, initial_q_column: int | None = None, 
initial_intensity_column: int | None = None, initial_error_column: int | None = None, ) -> None: super().__init__(parent) + self._dialog_title = title + self._independent_column_label = independent_column_label + self._dependent_column_label = dependent_column_label + self._error_column_label = error_column_label + self._intro_text = intro_text self.file_path = Path(file_path).expanduser().resolve() self._accepted_summary: ExperimentalDataSummary | None = None self._preview_lines = self._read_preview_lines() @@ -74,14 +84,18 @@ def error_column(self) -> int | None: return None if data is None else int(data) def _build_ui(self) -> None: - self.setWindowTitle("Check Experimental Data File") + self.setWindowTitle(self._dialog_title) self.resize(900, 720) root = QVBoxLayout(self) intro_label = QLabel( - "The selected file could not be parsed directly. Adjust the " - "number of header rows to skip, confirm which columns correspond " - "to q, intensity, and error, and then load the file again." + self._intro_text + or ( + "The selected file could not be parsed directly. Adjust the " + "number of header rows to skip, confirm which columns " + "correspond to q, intensity, and error, and then load the " + "file again." 
+ ) ) intro_label.setWordWrap(True) root.addWidget(intro_label) @@ -99,13 +113,16 @@ def _build_ui(self) -> None: form.addRow("Header rows", self.header_rows_spin) self.q_column_combo = QComboBox() - form.addRow("q column", self.q_column_combo) + form.addRow(self._independent_column_label, self.q_column_combo) self.intensity_column_combo = QComboBox() - form.addRow("Intensity column", self.intensity_column_combo) + form.addRow( + self._dependent_column_label, + self.intensity_column_combo, + ) self.error_column_combo = QComboBox() - form.addRow("Error column", self.error_column_combo) + form.addRow(self._error_column_label, self.error_column_combo) root.addLayout(form) self.preview_box = QPlainTextEdit() diff --git a/src/saxshell/saxs/ui/experimental_overlay_window.py b/src/saxshell/saxs/ui/experimental_overlay_window.py new file mode 100644 index 0000000..8e3b2e6 --- /dev/null +++ b/src/saxshell/saxs/ui/experimental_overlay_window.py @@ -0,0 +1,999 @@ +from __future__ import annotations + +import sys +from collections.abc import Iterable +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +from matplotlib import colormaps +from matplotlib.backends.backend_qtagg import ( + FigureCanvasQTAgg, + NavigationToolbar2QT, +) +from matplotlib.colors import to_hex +from matplotlib.figure import Figure +from PySide6.QtCore import Qt +from PySide6.QtGui import QColor +from PySide6.QtWidgets import ( + QAbstractItemView, + QApplication, + QCheckBox, + QColorDialog, + QComboBox, + QDialog, + QDoubleSpinBox, + QFileDialog, + QFormLayout, + QGroupBox, + QHBoxLayout, + QHeaderView, + QLabel, + QMainWindow, + QMessageBox, + QPushButton, + QSizePolicy, + QSplitter, + QTableWidget, + QTableWidgetItem, + QVBoxLayout, + QWidget, +) + +from saxshell.plotting import Q_A_INVERSE_LABEL +from saxshell.saxs.project_manager import ( + ExperimentalDataSummary, + load_experimental_data_file, +) +from saxshell.saxs.ui.branding import ( + 
configure_saxshell_application, + prepare_saxshell_application_identity, +) +from saxshell.saxs.ui.experimental_data_loader import ( + ExperimentalDataHeaderDialog, +) + + +@dataclass(slots=True) +class ExperimentalOverlayTrace: + path: Path + summary: ExperimentalDataSummary + label: str + color: str + visible: bool = True + axis: str = "left" + + +class ExperimentalDataOverlayWindow(QMainWindow): + """Overlay multiple experimental data files with shared header + parsing.""" + + SHOW_COLUMN = 0 + LABEL_COLUMN = 1 + AXIS_COLUMN = 2 + COLOR_COLUMN = 3 + POINTS_COLUMN = 4 + Q_RANGE_COLUMN = 5 + COLUMNS_COLUMN = 6 + + def __init__( + self, + *, + initial_paths: Iterable[str | Path] | None = None, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self.setWindowTitle("Experimental Data Overlay") + self.resize(1180, 760) + self.traces: list[ExperimentalOverlayTrace] = [] + self._updating_table = False + self._left_axis = None + self._right_axis = None + + self._build_ui() + self._refresh_q_range_controls() + self._refresh_trace_table() + self._refresh_plot() + + if initial_paths is not None: + self.add_data_files(initial_paths) + + def _build_ui(self) -> None: + root = QSplitter(Qt.Orientation.Horizontal) + self.setCentralWidget(root) + + controls = QWidget() + controls.setMinimumWidth(255) + controls.setMaximumWidth(340) + controls_layout = QVBoxLayout(controls) + + file_group = QGroupBox("Data Files") + file_layout = QVBoxLayout(file_group) + self.add_files_button = QPushButton("Add Data Files...") + self.add_files_button.clicked.connect(self._choose_data_files) + self.remove_files_button = QPushButton("Remove Selected") + self.remove_files_button.clicked.connect(self._remove_selected_traces) + self.clear_files_button = QPushButton("Clear Traces") + self.clear_files_button.clicked.connect(self._clear_traces) + file_layout.addWidget(self.add_files_button) + file_layout.addWidget(self.remove_files_button) + 
file_layout.addWidget(self.clear_files_button) + controls_layout.addWidget(file_group) + + range_group = QGroupBox("q-Range") + range_layout = QFormLayout(range_group) + self.full_q_range_checkbox = QCheckBox("Use full loaded range") + self.full_q_range_checkbox.setChecked(True) + self.full_q_range_checkbox.toggled.connect( + self._on_full_q_range_toggled + ) + range_layout.addRow(self.full_q_range_checkbox) + + self.q_min_spin = QDoubleSpinBox() + self.q_min_spin.setDecimals(6) + self.q_min_spin.setRange(-1.0e12, 1.0e12) + self.q_min_spin.setSingleStep(0.01) + self.q_min_spin.valueChanged.connect(self._on_q_range_changed) + range_layout.addRow("q min", self.q_min_spin) + + self.q_max_spin = QDoubleSpinBox() + self.q_max_spin.setDecimals(6) + self.q_max_spin.setRange(-1.0e12, 1.0e12) + self.q_max_spin.setSingleStep(0.01) + self.q_max_spin.valueChanged.connect(self._on_q_range_changed) + range_layout.addRow("q max", self.q_max_spin) + + self.use_loaded_range_button = QPushButton("Use Loaded Range") + self.use_loaded_range_button.clicked.connect(self._use_loaded_q_range) + range_layout.addRow(self.use_loaded_range_button) + controls_layout.addWidget(range_group) + + axes_group = QGroupBox("Axes") + axes_layout = QVBoxLayout(axes_group) + scale_button_row = QHBoxLayout() + self.log_x_axis_button = QPushButton("Log X: On") + self.log_x_axis_button.setCheckable(True) + self.log_x_axis_button.setChecked(True) + self.log_x_axis_button.toggled.connect(self._set_log_x_axis_enabled) + self.log_y_axis_button = QPushButton("Log Y: On") + self.log_y_axis_button.setCheckable(True) + self.log_y_axis_button.setChecked(True) + self.log_y_axis_button.toggled.connect(self._set_log_y_axis_enabled) + scale_button_row.addWidget(self.log_x_axis_button) + scale_button_row.addWidget(self.log_y_axis_button) + axes_layout.addLayout(scale_button_row) + self.align_y_axes_checkbox = QCheckBox( + "Rescale right axis to left data" + ) + self.align_y_axes_checkbox.setChecked(True) + 
self.align_y_axes_checkbox.toggled.connect(self._refresh_plot) + self.rescale_axes_button = QPushButton("Rescale Axes") + self.rescale_axes_button.clicked.connect( + self._rescale_axes_to_current_q_range + ) + axes_layout.addWidget(self.align_y_axes_checkbox) + axes_layout.addWidget(self.rescale_axes_button) + controls_layout.addWidget(axes_group) + + self.status_label = QLabel("Open experimental data files to overlay.") + self.status_label.setWordWrap(True) + controls_layout.addWidget(self.status_label) + controls_layout.addStretch(1) + root.addWidget(controls) + + plot_panel = QWidget() + plot_layout = QVBoxLayout(plot_panel) + self.figure = Figure(figsize=(7.6, 5.2), tight_layout=True) + self.canvas = FigureCanvasQTAgg(self.figure) + self.canvas.setSizePolicy( + QSizePolicy.Policy.Expanding, + QSizePolicy.Policy.Expanding, + ) + self.toolbar = NavigationToolbar2QT(self.canvas, self) + plot_layout.addWidget(self.toolbar) + plot_layout.addWidget(self.canvas, stretch=1) + + self.trace_table = QTableWidget(0, 7) + self.trace_table.setHorizontalHeaderLabels( + [ + "Show", + "Dataset", + "Axis", + "Color", + "Points", + "q range", + "Columns", + ] + ) + self.trace_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + self.trace_table.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + self.trace_table.setEditTriggers( + QAbstractItemView.EditTrigger.DoubleClicked + | QAbstractItemView.EditTrigger.SelectedClicked + ) + self.trace_table.itemChanged.connect(self._on_trace_item_changed) + self.trace_table.cellClicked.connect(self._on_trace_cell_clicked) + self.trace_table.itemSelectionChanged.connect( + self._refresh_action_state + ) + horizontal_header = self.trace_table.horizontalHeader() + horizontal_header.setSectionResizeMode( + self.SHOW_COLUMN, + QHeaderView.ResizeMode.ResizeToContents, + ) + horizontal_header.setSectionResizeMode( + self.LABEL_COLUMN, + QHeaderView.ResizeMode.Stretch, + ) + 
horizontal_header.setSectionResizeMode( + self.AXIS_COLUMN, + QHeaderView.ResizeMode.ResizeToContents, + ) + horizontal_header.setSectionResizeMode( + self.COLOR_COLUMN, + QHeaderView.ResizeMode.ResizeToContents, + ) + horizontal_header.setSectionResizeMode( + self.POINTS_COLUMN, + QHeaderView.ResizeMode.ResizeToContents, + ) + horizontal_header.setSectionResizeMode( + self.Q_RANGE_COLUMN, + QHeaderView.ResizeMode.ResizeToContents, + ) + horizontal_header.setSectionResizeMode( + self.COLUMNS_COLUMN, + QHeaderView.ResizeMode.Stretch, + ) + self.trace_table.setMinimumHeight(170) + plot_layout.addWidget(self.trace_table) + root.addWidget(plot_panel) + root.setStretchFactor(0, 0) + root.setStretchFactor(1, 1) + + def _choose_data_files(self) -> None: + paths, _selected_filter = QFileDialog.getOpenFileNames( + self, + "Open Experimental Data Files", + "", + "Data files (*.txt *.dat *.iq);;All files (*)", + ) + if paths: + self.add_data_files(paths) + + def add_data_files(self, paths: Iterable[str | Path]) -> int: + added = 0 + failures: list[str] = [] + for raw_path in paths: + file_path = Path(raw_path).expanduser().resolve() + summary = self._load_data_file(file_path) + if summary is None: + failures.append(file_path.name) + continue + self.traces.append( + ExperimentalOverlayTrace( + path=file_path, + summary=summary, + label=file_path.name, + color=self._next_trace_color(), + ) + ) + added += 1 + + if added: + self._refresh_q_range_controls() + self._refresh_trace_table() + self._refresh_plot() + self.status_label.setText( + f"Loaded {added} data file{'s' if added != 1 else ''}." 
+ ) + if failures: + QMessageBox.warning( + self, + "Experimental Data Load", + "Could not load: " + ", ".join(failures), + ) + self._refresh_action_state() + return added + + def _load_data_file( + self, + file_path: Path, + ) -> ExperimentalDataSummary | None: + try: + return load_experimental_data_file(file_path, skiprows=0) + except Exception: + dialog = ExperimentalDataHeaderDialog(file_path, self) + if dialog.exec() != QDialog.DialogCode.Accepted: + return None + return dialog.accepted_summary + + def _next_trace_color(self) -> str: + color = colormaps["tab10"](len(self.traces) % 10) + return str(to_hex(color)) + + def _refresh_trace_table(self) -> None: + self._updating_table = True + self.trace_table.blockSignals(True) + try: + self.trace_table.setRowCount(len(self.traces)) + for row, trace in enumerate(self.traces): + self.trace_table.setItem( + row, + self.SHOW_COLUMN, + self._build_visibility_item(trace), + ) + self.trace_table.setItem( + row, + self.LABEL_COLUMN, + self._build_label_item(trace), + ) + self.trace_table.setCellWidget( + row, + self.AXIS_COLUMN, + self._build_axis_combo(row, trace), + ) + self.trace_table.setItem( + row, + self.COLOR_COLUMN, + self._build_color_item(trace), + ) + self.trace_table.setItem( + row, + self.POINTS_COLUMN, + self._read_only_item(str(len(trace.summary.q_values))), + ) + self.trace_table.setItem( + row, + self.Q_RANGE_COLUMN, + self._read_only_item(self._trace_q_range_text(trace)), + ) + self.trace_table.setItem( + row, + self.COLUMNS_COLUMN, + self._read_only_item(self._trace_column_text(trace)), + ) + finally: + self.trace_table.blockSignals(False) + self._updating_table = False + self._refresh_action_state() + + def _build_visibility_item( + self, + trace: ExperimentalOverlayTrace, + ) -> QTableWidgetItem: + item = QTableWidgetItem() + item.setTextAlignment(Qt.AlignmentFlag.AlignCenter) + item.setFlags( + Qt.ItemFlag.ItemIsSelectable + | Qt.ItemFlag.ItemIsEnabled + | Qt.ItemFlag.ItemIsUserCheckable + ) + 
item.setCheckState( + Qt.CheckState.Checked if trace.visible else Qt.CheckState.Unchecked + ) + return item + + def _build_label_item( + self, + trace: ExperimentalOverlayTrace, + ) -> QTableWidgetItem: + item = QTableWidgetItem(trace.label) + item.setToolTip(str(trace.path)) + item.setFlags( + Qt.ItemFlag.ItemIsSelectable + | Qt.ItemFlag.ItemIsEnabled + | Qt.ItemFlag.ItemIsEditable + ) + return item + + def _build_axis_combo( + self, + row: int, + trace: ExperimentalOverlayTrace, + ) -> QComboBox: + combo = QComboBox() + combo.addItem("Left Y", "left") + combo.addItem("Right Y", "right") + combo.setCurrentIndex(combo.findData(trace.axis)) + combo.currentIndexChanged.connect( + lambda _index, trace_index=row, widget=combo: self._set_trace_axis( + trace_index, + str(widget.currentData()), + ) + ) + return combo + + def _build_color_item( + self, + trace: ExperimentalOverlayTrace, + ) -> QTableWidgetItem: + item = QTableWidgetItem(trace.color) + item.setTextAlignment(Qt.AlignmentFlag.AlignCenter) + item.setFlags(Qt.ItemFlag.ItemIsSelectable | Qt.ItemFlag.ItemIsEnabled) + item.setBackground(QColor(trace.color)) + item.setToolTip("Click to choose a custom trace color.") + return item + + def _read_only_item(self, text: str) -> QTableWidgetItem: + item = QTableWidgetItem(text) + item.setFlags(Qt.ItemFlag.ItemIsSelectable | Qt.ItemFlag.ItemIsEnabled) + return item + + def _on_trace_item_changed(self, item: QTableWidgetItem) -> None: + if self._updating_table: + return + row = item.row() + if row < 0 or row >= len(self.traces): + return + trace = self.traces[row] + if item.column() == self.SHOW_COLUMN: + trace.visible = item.checkState() == Qt.CheckState.Checked + self._refresh_plot() + return + if item.column() == self.LABEL_COLUMN: + updated = item.text().strip() + if updated: + trace.label = updated + self._refresh_plot() + else: + self._refresh_trace_table() + + def _on_trace_cell_clicked(self, row: int, column: int) -> None: + if column != self.COLOR_COLUMN: + return 
+ if row < 0 or row >= len(self.traces): + return + initial_color = QColor(self.traces[row].color) + chosen = QColorDialog.getColor( + initial_color, + self, + f"Choose color for {self.traces[row].label}", + ) + if chosen.isValid(): + self._set_trace_color(row, chosen.name()) + + def _set_trace_axis(self, row: int, axis: str) -> None: + if self._updating_table or row < 0 or row >= len(self.traces): + return + if axis not in {"left", "right"}: + return + self.traces[row].axis = axis + self._refresh_plot() + + def _set_trace_color(self, row: int, color: str) -> None: + if row < 0 or row >= len(self.traces): + return + self.traces[row].color = color + self._refresh_trace_table() + self._refresh_plot() + + def _remove_selected_traces(self) -> None: + selected_rows = { + index.row() + for index in self.trace_table.selectionModel().selectedRows() + } + if not selected_rows and self.trace_table.currentRow() >= 0: + selected_rows = {self.trace_table.currentRow()} + if not selected_rows: + return + for row in sorted(selected_rows, reverse=True): + if 0 <= row < len(self.traces): + del self.traces[row] + self._refresh_q_range_controls() + self._refresh_trace_table() + self._refresh_plot() + self.status_label.setText("Removed selected trace(s).") + + def _clear_traces(self) -> None: + if not self.traces: + return + self.traces.clear() + self._refresh_q_range_controls() + self._refresh_trace_table() + self._refresh_plot() + self.status_label.setText("Cleared plotted traces.") + + def _on_q_range_changed(self, *_args: object) -> None: + self._refresh_plot() + + def _on_full_q_range_toggled(self, checked: bool) -> None: + self.q_min_spin.setEnabled(not checked and bool(self.traces)) + self.q_max_spin.setEnabled(not checked and bool(self.traces)) + if checked: + self._set_spin_values_to_loaded_q_range() + self._on_q_range_changed() + + def _use_loaded_q_range(self) -> None: + self.full_q_range_checkbox.setChecked(True) + self._set_spin_values_to_loaded_q_range() + 
self._on_q_range_changed() + + def _refresh_q_range_controls(self) -> None: + has_traces = bool(self.traces) + self.full_q_range_checkbox.setEnabled(has_traces) + self.use_loaded_range_button.setEnabled(has_traces) + self.q_min_spin.setEnabled( + has_traces and not self.full_q_range_checkbox.isChecked() + ) + self.q_max_spin.setEnabled( + has_traces and not self.full_q_range_checkbox.isChecked() + ) + self._set_spin_values_to_loaded_q_range() + self._refresh_action_state() + + def _set_spin_values_to_loaded_q_range(self) -> None: + q_bounds = self._loaded_q_bounds() + self.q_min_spin.blockSignals(True) + self.q_max_spin.blockSignals(True) + try: + if q_bounds is None: + self.q_min_spin.setValue(0.0) + self.q_max_spin.setValue(0.0) + return + q_min, q_max = q_bounds + padding = max((q_max - q_min) * 0.05, 1.0e-9) + self.q_min_spin.setRange(q_min - padding, q_max + padding) + self.q_max_spin.setRange(q_min - padding, q_max + padding) + if self.full_q_range_checkbox.isChecked(): + self.q_min_spin.setValue(q_min) + self.q_max_spin.setValue(q_max) + else: + self.q_min_spin.setValue( + max(q_min, min(self.q_min_spin.value(), q_max)) + ) + self.q_max_spin.setValue( + min(q_max, max(self.q_max_spin.value(), q_min)) + ) + finally: + self.q_min_spin.blockSignals(False) + self.q_max_spin.blockSignals(False) + + def _loaded_q_bounds(self) -> tuple[float, float] | None: + q_segments: list[np.ndarray] = [] + for trace in self.traces: + q_values = np.asarray(trace.summary.q_values, dtype=float) + q_values = q_values[np.isfinite(q_values)] + if q_values.size: + q_segments.append(q_values) + if not q_segments: + return None + q_values = np.concatenate(q_segments) + return float(np.nanmin(q_values)), float(np.nanmax(q_values)) + + def _active_q_bounds(self) -> tuple[float, float] | None: + loaded_bounds = self._loaded_q_bounds() + if loaded_bounds is None: + return None + if self.full_q_range_checkbox.isChecked(): + return loaded_bounds + q_min, q_max = sorted( + 
(float(self.q_min_spin.value()), float(self.q_max_spin.value())) + ) + return q_min, q_max + + def _refresh_action_state(self) -> None: + has_traces = bool(self.traces) + has_selection = False + selection_model = self.trace_table.selectionModel() + if selection_model is not None: + has_selection = bool(selection_model.selectedRows()) + self.remove_files_button.setEnabled(has_traces and has_selection) + self.clear_files_button.setEnabled(has_traces) + has_right_axis = self._has_visible_right_axis_trace() + self.align_y_axes_checkbox.setEnabled(has_right_axis) + self.rescale_axes_button.setEnabled(has_right_axis) + + def _has_visible_right_axis_trace(self) -> bool: + return any( + trace.visible and trace.axis == "right" for trace in self.traces + ) + + def _rescale_axes_to_current_q_range(self) -> None: + if not self._has_visible_right_axis_trace(): + return + if not self.align_y_axes_checkbox.isChecked(): + self.align_y_axes_checkbox.setChecked(True) + return + self._refresh_plot() + + def _set_log_x_axis_enabled(self, checked: bool) -> None: + self.log_x_axis_button.setText(f"Log X: {'On' if checked else 'Off'}") + self._refresh_plot() + + def _set_log_y_axis_enabled(self, checked: bool) -> None: + self.log_y_axis_button.setText(f"Log Y: {'On' if checked else 'Off'}") + self._refresh_plot() + + def _refresh_plot(self) -> None: + for axis in list(self.figure.axes): + try: + axis.set_xscale("linear") + axis.set_yscale("linear") + except Exception: + continue + self.figure.clear() + self._left_axis = None + self._right_axis = None + + visible_traces = [trace for trace in self.traces if trace.visible] + if not visible_traces: + axis = self.figure.add_subplot(111) + axis.text( + 0.5, + 0.5, + "Open experimental data files to overlay traces.", + ha="center", + va="center", + transform=axis.transAxes, + wrap=True, + ) + axis.set_axis_off() + self.figure.tight_layout() + self.canvas.draw_idle() + self._refresh_action_state() + return + + self._left_axis = 
self.figure.add_subplot(111) + right_traces = [ + trace for trace in visible_traces if trace.axis == "right" + ] + if right_traces: + self._right_axis = self._left_axis.twinx() + + self._apply_axis_scale(self._left_axis) + if self._right_axis is not None: + self._apply_axis_scale(self._right_axis) + + plotted_lines: list[object] = [] + for trace in visible_traces: + target_axis = ( + self._right_axis + if trace.axis == "right" and self._right_axis is not None + else self._left_axis + ) + line = self._plot_trace(target_axis, trace) + if line is not None: + plotted_lines.append(line) + + q_bounds = self._active_plot_q_bounds() + if q_bounds is not None: + q_min, q_max = q_bounds + self._left_axis.set_xlim(q_min, q_max) + self._autoscale_axis_y( + self._left_axis, + q_min, + q_max, + log_scale=self._log_y_axis_enabled(), + ) + if self._right_axis is not None: + self._right_axis.set_xlim(q_min, q_max) + self._autoscale_axis_y( + self._right_axis, + q_min, + q_max, + log_scale=self._log_y_axis_enabled(), + ) + if self.align_y_axes_checkbox.isChecked(): + self._normalize_secondary_axis( + self._left_axis, + self._right_axis, + q_min, + q_max, + ) + + self._left_axis.set_xlabel(Q_A_INVERSE_LABEL) + self._left_axis.set_ylabel("Left axis intensity") + self._left_axis.grid(True, alpha=0.25, linewidth=0.8) + if self._right_axis is not None: + self._right_axis.set_ylabel("Right axis intensity") + + if plotted_lines: + self._left_axis.legend( + plotted_lines, + [line.get_label() for line in plotted_lines], + loc="best", + fontsize=9, + framealpha=0.9, + ) + + self.figure.tight_layout() + self.canvas.draw_idle() + self._refresh_action_state() + + def _apply_axis_scale(self, axis) -> None: + axis.set_xscale("log" if self._log_x_axis_enabled() else "linear") + axis.set_yscale("log" if self._log_y_axis_enabled() else "linear") + + def _log_x_axis_enabled(self) -> bool: + return self.log_x_axis_button.isChecked() + + def _log_y_axis_enabled(self) -> bool: + return 
self.log_y_axis_button.isChecked() + + def _active_plot_q_bounds(self) -> tuple[float, float] | None: + bounds = self._active_q_bounds() + if bounds is None or not self._log_x_axis_enabled(): + return bounds + q_min, q_max = bounds + if q_max <= 0.0: + return None + positive_segments: list[np.ndarray] = [] + for trace in self.traces: + if not trace.visible: + continue + q_values = np.asarray(trace.summary.q_values, dtype=float) + q_values = q_values[np.isfinite(q_values) & (q_values > 0.0)] + if q_values.size: + positive_segments.append(q_values) + if not positive_segments: + return None + positive_q = np.concatenate(positive_segments) + lower = max(q_min, float(np.nanmin(positive_q))) + upper = max(q_max, lower * (1.0 + 1.0e-9)) + return lower, upper + + def _plot_trace( + self, + axis, + trace: ExperimentalOverlayTrace, + ): + q_values = np.asarray(trace.summary.q_values, dtype=float) + intensities = np.asarray(trace.summary.intensities, dtype=float) + mask = np.isfinite(q_values) & np.isfinite(intensities) + if self._log_x_axis_enabled(): + mask &= q_values > 0.0 + if self._log_y_axis_enabled(): + mask &= intensities > 0.0 + if not np.any(mask): + return None + (line,) = axis.plot( + q_values[mask], + intensities[mask], + color=trace.color, + linewidth=1.6, + label=trace.label, + ) + return line + + def _normalize_secondary_axis( + self, + left_axis, + right_axis, + q_min: float, + q_max: float, + ) -> None: + log_scale = self._log_y_axis_enabled() + left_values = self._axis_y_values( + left_axis, + q_min, + q_max, + log_scale=log_scale, + ) + right_values = self._axis_y_values( + right_axis, + q_min, + q_max, + log_scale=log_scale, + ) + if left_values.size == 0 or right_values.size == 0: + return + right_limits = self._aligned_y_limits( + left_axis.get_ylim(), + float(np.nanmin(left_values)), + float(np.nanmax(left_values)), + float(np.nanmin(right_values)), + float(np.nanmax(right_values)), + log_scale=log_scale, + ) + right_axis.set_ylim(right_limits) + + 
@staticmethod + def _axis_y_values( + axis, + q_min: float, + q_max: float, + *, + log_scale: bool = False, + ) -> np.ndarray: + y_segments: list[np.ndarray] = [] + for line in axis.get_lines(): + if not line.get_visible(): + continue + x_data = np.asarray(line.get_xdata(orig=False), dtype=float) + y_data = np.asarray(line.get_ydata(orig=False), dtype=float) + mask = ( + np.isfinite(x_data) + & np.isfinite(y_data) + & (x_data >= q_min) + & (x_data <= q_max) + ) + if log_scale: + mask &= y_data > 0.0 + if np.any(mask): + y_segments.append(y_data[mask]) + if not y_segments: + return np.asarray([], dtype=float) + return np.concatenate(y_segments) + + @staticmethod + def _autoscale_axis_y( + axis, + q_min: float, + q_max: float, + *, + log_scale: bool, + ) -> None: + y_values = ExperimentalDataOverlayWindow._axis_y_values( + axis, + q_min, + q_max, + log_scale=log_scale, + ) + if y_values.size == 0: + return + y_min = float(np.nanmin(y_values)) + y_max = float(np.nanmax(y_values)) + if np.isclose(y_min, y_max): + if log_scale and y_min > 0.0: + axis.set_ylim(y_min / 1.15, y_max * 1.15) + else: + padding = max(abs(y_min) * 0.05, 1e-12) + axis.set_ylim(y_min - padding, y_max + padding) + return + if log_scale: + axis.set_ylim(y_min / 1.15, y_max * 1.15) + else: + padding = 0.05 * (y_max - y_min) + axis.set_ylim(y_min - padding, y_max + padding) + + @staticmethod + def _aligned_y_limits( + left_limits: tuple[float, float], + left_min: float, + left_max: float, + right_min: float, + right_max: float, + *, + log_scale: bool, + ) -> tuple[float, float]: + if log_scale: + if ( + min( + left_limits[0], + left_limits[1], + left_min, + left_max, + right_min, + right_max, + ) + <= 0.0 + ): + log_scale = False + if not log_scale: + left_low, left_high = left_limits + data_low, data_high = sorted((left_min, left_max)) + right_low_data, right_high_data = sorted((right_min, right_max)) + if np.isclose(right_high_data, right_low_data): + padding = max(abs(right_low_data) * 0.05, 
1e-12) + right_low_data -= padding + right_high_data += padding + if np.isclose(left_high, left_low) or np.isclose( + data_high, + data_low, + ): + padding = max(abs(right_low_data) * 0.1, 1e-12) + return ( + right_low_data - padding, + right_high_data + padding, + ) + p0 = (data_low - left_low) / (left_high - left_low) + p1 = (data_high - left_low) / (left_high - left_low) + if np.isclose(p1, p0): + padding = max(abs(right_low_data) * 0.1, 1e-12) + return ( + right_low_data - padding, + right_high_data + padding, + ) + delta = (right_high_data - right_low_data) / (p1 - p0) + right_low = right_low_data - p0 * delta + right_high = right_low + delta + return right_low, right_high + + left_logs = np.log10(np.asarray(left_limits, dtype=float)) + data_logs = np.log10( + np.asarray(sorted((left_min, left_max)), dtype=float) + ) + right_logs = np.log10( + np.asarray(sorted((right_min, right_max)), dtype=float) + ) + if np.isclose(left_logs[1], left_logs[0]) or np.isclose( + data_logs[1], + data_logs[0], + ): + return right_min / 1.2, right_max * 1.2 + p0 = (data_logs[0] - left_logs[0]) / (left_logs[1] - left_logs[0]) + p1 = (data_logs[1] - left_logs[0]) / (left_logs[1] - left_logs[0]) + if np.isclose(p1, p0): + return right_min / 1.2, right_max * 1.2 + delta = (right_logs[1] - right_logs[0]) / (p1 - p0) + right_low_log = right_logs[0] - p0 * delta + right_high_log = right_low_log + delta + return 10**right_low_log, 10**right_high_log + + @staticmethod + def _summary_column_label( + summary: ExperimentalDataSummary, + column_index: int | None, + ) -> str: + if column_index is None: + return "None" + if 0 <= column_index < len(summary.column_names): + return summary.column_names[column_index] + return f"Column {column_index + 1}" + + def _trace_column_text(self, trace: ExperimentalOverlayTrace) -> str: + summary = trace.summary + q_label = self._summary_column_label(summary, summary.q_column) + i_label = self._summary_column_label( + summary, + summary.intensity_column, + ) 
+ text = f"q={q_label}, I={i_label}" + if summary.error_column is not None: + error_label = self._summary_column_label( + summary, + summary.error_column, + ) + text += f", err={error_label}" + text += f"; header rows={summary.header_rows}" + return text + + @staticmethod + def _trace_q_range_text(trace: ExperimentalOverlayTrace) -> str: + q_values = np.asarray(trace.summary.q_values, dtype=float) + q_values = q_values[np.isfinite(q_values)] + if q_values.size == 0: + return "--" + q_min = float(np.nanmin(q_values)) + q_max = float(np.nanmax(q_values)) + return f"{q_min:.6g} - {q_max:.6g}" + + +_OPEN_WINDOWS: list[ExperimentalDataOverlayWindow] = [] + + +def _forget_open_window(window: ExperimentalDataOverlayWindow) -> None: + _OPEN_WINDOWS[:] = [ + existing for existing in _OPEN_WINDOWS if existing is not window + ] + + +def launch_experimental_data_overlay_ui( + *, + initial_paths: Iterable[str | Path] | None = None, +) -> ExperimentalDataOverlayWindow: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication(sys.argv) + configure_saxshell_application(app) + window = ExperimentalDataOverlayWindow(initial_paths=initial_paths) + window.show() + window.raise_() + _OPEN_WINDOWS.append(window) + window.destroyed.connect( + lambda _obj=None, win=window: _forget_open_window(win) + ) + return window + + +__all__ = [ + "ExperimentalDataOverlayWindow", + "ExperimentalOverlayTrace", + "launch_experimental_data_overlay_ui", +] diff --git a/src/saxshell/saxs/ui/main_window.py b/src/saxshell/saxs/ui/main_window.py index 5ed6417..af29ada 100644 --- a/src/saxshell/saxs/ui/main_window.py +++ b/src/saxshell/saxs/ui/main_window.py @@ -1131,6 +1131,7 @@ def __init__( self._base_font_point_size = self._resolve_base_font_point_size() self._scale_shortcuts: list[QShortcut] = [] self._child_tool_windows: list[object] = [] + self._single_instance_child_tool_windows: dict[str, object] = {} self._contrast_mode_tool_window: object | 
None = None self._solute_volume_fraction_tool_window: ( SoluteVolumeFractionToolWindow | None @@ -1510,6 +1511,58 @@ def _build_menu_bar(self) -> None: self.pdfsetup_action.triggered.connect(self._open_pdfsetup_tool) self.pdf_menu.addAction(self.pdfsetup_action) + self.batch_processing_menu = self.tools_menu.addMenu( + "Batch Processing" + ) + self.mdtrajectory_batch_queue_action = QAction( + "Open MD Trajectory Batch Queue", + self, + ) + self.mdtrajectory_batch_queue_action.triggered.connect( + self._open_mdtrajectory_batch_queue_tool + ) + self.batch_processing_menu.addAction( + self.mdtrajectory_batch_queue_action + ) + + self.xyz2pdb_batch_queue_action = QAction( + "Open XYZ -> PDB Batch Queue", + self, + ) + self.xyz2pdb_batch_queue_action.triggered.connect( + self._open_xyz2pdb_batch_queue_tool + ) + self.batch_processing_menu.addAction(self.xyz2pdb_batch_queue_action) + + self.cluster_batch_queue_action = QAction( + "Open Cluster Extraction Batch Queue", + self, + ) + self.cluster_batch_queue_action.triggered.connect( + self._open_cluster_batch_queue_tool + ) + self.batch_processing_menu.addAction(self.cluster_batch_queue_action) + + self.representativefinder_batch_queue_action = QAction( + "Open Representative Structures Batch Queue", + self, + ) + self.representativefinder_batch_queue_action.triggered.connect( + self._open_representative_batch_queue_tool + ) + self.batch_processing_menu.addAction( + self.representativefinder_batch_queue_action + ) + + self.pdf_batch_queue_action = QAction( + "Open PDF Batch Queue", + self, + ) + self.pdf_batch_queue_action.triggered.connect( + self._open_pdf_batch_queue_tool + ) + self.batch_processing_menu.addAction(self.pdf_batch_queue_action) + self.fullrmc_action = QAction("Open RMC Setup (fullrmc)", self) self.fullrmc_action.triggered.connect(self._open_fullrmc_tool) self.pdf_menu.addAction(self.fullrmc_action) @@ -1524,6 +1577,15 @@ def _build_menu_bar(self) -> None: ) 
self.visualization_menu.addAction(self.structure_viewer_action) + self.experimental_overlay_action = QAction( + "Experimental Data Overlay", + self, + ) + self.experimental_overlay_action.triggered.connect( + self._open_experimental_data_overlay_tool + ) + self.visualization_menu.addAction(self.experimental_overlay_action) + self.blenderxyz_action = QAction( "Open Blender XYZ Renderer", self, @@ -1616,6 +1678,38 @@ def _build_menu_bar(self) -> None: ) self.xray_toolkit_menu.addAction(self.fluorescence_estimate_action) self.cli_setup_menu = self.tools_menu.addMenu("CLI Setup") + self.xyz2pdb_cli_setup_action = QAction( + "Open XYZ -> PDB CLI Setup (Beta)", + self, + ) + self.xyz2pdb_cli_setup_action.triggered.connect( + self._open_xyz2pdb_cli_setup_tool + ) + self.cli_setup_menu.addAction(self.xyz2pdb_cli_setup_action) + self.cluster_cli_setup_action = QAction( + "Open Cluster Extraction CLI Setup (Beta)", + self, + ) + self.cluster_cli_setup_action.triggered.connect( + self._open_cluster_cli_setup_tool + ) + self.cli_setup_menu.addAction(self.cluster_cli_setup_action) + self.clusterdynamics_cli_setup_action = QAction( + "Open Cluster Dynamics CLI Setup (Beta)", + self, + ) + self.clusterdynamics_cli_setup_action.triggered.connect( + self._open_clusterdynamics_cli_setup_tool + ) + self.cli_setup_menu.addAction(self.clusterdynamics_cli_setup_action) + self.clusterdynamicsml_cli_setup_action = QAction( + "Open Cluster Dynamics ML CLI Setup (Beta)", + self, + ) + self.clusterdynamicsml_cli_setup_action.triggered.connect( + self._open_clusterdynamicsml_cli_setup_tool + ) + self.cli_setup_menu.addAction(self.clusterdynamicsml_cli_setup_action) self.representative_cli_setup_action = QAction( "Open Representative CLI Setup (Beta)", self, @@ -10404,8 +10498,17 @@ def _connect_debye_waller_updates(self, window: object) -> None: return signal.connect(self._on_debye_waller_project_saved) - def _track_child_tool_window(self, window: object) -> None: + def 
_track_child_tool_window( + self, + window: object, + *, + single_instance_key: str | None = None, + ) -> None: if window in self._child_tool_windows: + if single_instance_key is not None: + self._single_instance_child_tool_windows[ + single_instance_key + ] = window return if isinstance(window, QWidget): window.setAttribute(Qt.WidgetAttribute.WA_DeleteOnClose, True) @@ -10419,6 +10522,10 @@ def _track_child_tool_window(self, window: object) -> None: ) ) self._child_tool_windows.append(window) + if single_instance_key is not None: + self._single_instance_child_tool_windows[single_instance_key] = ( + window + ) def _forget_child_tool_window(self, window: object) -> None: self._child_tool_windows = [ @@ -10426,6 +10533,60 @@ def _forget_child_tool_window(self, window: object) -> None: for existing in self._child_tool_windows if existing is not window ] + self._single_instance_child_tool_windows = { + key: existing + for key, existing in self._single_instance_child_tool_windows.items() + if existing is not window + } + + def _focus_single_instance_child_tool_window( + self, + single_instance_key: str, + tool_label: str, + ) -> bool: + window = self._single_instance_child_tool_windows.get( + single_instance_key + ) + if window is None: + return False + try: + for method_name in ("show", "raise_", "activateWindow"): + method = getattr(window, method_name, None) + if callable(method): + method() + except RuntimeError: + self._forget_child_tool_window(window) + return False + self.statusBar().showMessage( + f"{tool_label} is already open; focused the existing window.", + 5000, + ) + return True + + def _block_conflicting_child_tool_window( + self, + conflict_key: str, + *, + requested_tool_label: str, + open_tool_label: str, + ) -> bool: + window = self._single_instance_child_tool_windows.get(conflict_key) + if window is None: + return False + try: + for method_name in ("show", "raise_", "activateWindow"): + method = getattr(window, method_name, None) + if 
callable(method): + method() + except RuntimeError: + self._forget_child_tool_window(window) + return False + self.statusBar().showMessage( + f"{open_tool_label} is already open; close it before opening " + f"{requested_tool_label}.", + 5000, + ) + return True def _close_child_tool_windows(self) -> bool: for window in list(self._child_tool_windows): @@ -10559,6 +10720,11 @@ def _on_debye_waller_project_saved(self, payload: object) -> None: ) def _open_mdtrajectory_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "mdtrajectory", + "MD trajectory extraction", + ): + return from saxshell.mdtrajectory.ui.main_window import ( launch_mdtrajectory_app, ) @@ -10580,7 +10746,10 @@ def _open_mdtrajectory_tool(self) -> None: energy_file=energy_file, ) self._connect_project_path_updates(window) - self._track_child_tool_window(window) + self._track_child_tool_window( + window, + single_instance_key="mdtrajectory", + ) if project_dir is not None: self.statusBar().showMessage( f"Opened MD trajectory extraction for {project_dir}" @@ -10588,7 +10757,52 @@ def _open_mdtrajectory_tool(self) -> None: else: self.statusBar().showMessage("Opened MD trajectory extraction") + def _open_mdtrajectory_batch_queue_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "mdtrajectory_batch_queue", + "MD trajectory batch queue", + ): + return + from saxshell.mdtrajectory.ui.batch_queue_window import ( + MDTrajectoryBatchQueueWindow, + ) + + settings = self._active_project_launch_settings() + project_dir = None + trajectory_file = None + topology_file = None + energy_file = None + if settings is not None: + project_dir = Path(settings.project_dir).resolve() + trajectory_file = settings.resolved_trajectory_file + topology_file = settings.resolved_topology_file + energy_file = settings.resolved_energy_file + window = MDTrajectoryBatchQueueWindow( + initial_project_dir=project_dir, + initial_trajectory_file=trajectory_file, + 
initial_topology_file=topology_file, + initial_energy_file=energy_file, + ) + self._connect_project_path_updates(window) + window.show() + window.raise_() + self._track_child_tool_window( + window, + single_instance_key="mdtrajectory_batch_queue", + ) + if project_dir is not None: + self.statusBar().showMessage( + f"Opened MD trajectory batch queue for {project_dir}" + ) + else: + self.statusBar().showMessage("Opened MD trajectory batch queue") + def _open_cluster_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "cluster", + "Cluster extraction", + ): + return from saxshell.cluster.ui.main_window import ClusterMainWindow settings = self._active_project_launch_settings() @@ -10604,7 +10818,10 @@ def _open_cluster_tool(self) -> None: self._connect_project_path_updates(window) window.show() window.raise_() - self._track_child_tool_window(window) + self._track_child_tool_window( + window, + single_instance_key="cluster", + ) if frames_dir is not None: self.statusBar().showMessage( f"Opened cluster extraction for {frames_dir}" @@ -10612,7 +10829,51 @@ def _open_cluster_tool(self) -> None: else: self.statusBar().showMessage("Opened cluster extraction") + def _open_cluster_batch_queue_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "cluster_batch_queue", + "Cluster extraction batch queue", + ): + return + from saxshell.cluster.ui.batch_queue_window import ( + ClusterBatchQueueWindow, + ) + + settings = self._active_project_launch_settings() + frames_dir = None + project_dir = None + if settings is not None: + frames_dir = ( + settings.resolved_pdb_frames_dir + or settings.resolved_frames_dir + ) + project_dir = Path(settings.project_dir).resolve() + window = ClusterBatchQueueWindow( + initial_project_dir=project_dir, + initial_frames_dir=frames_dir, + ) + self._connect_project_path_updates(window) + window.show() + window.raise_() + self._track_child_tool_window( + window, + single_instance_key="cluster_batch_queue", + ) 
+ if project_dir is not None: + self.statusBar().showMessage( + f"Opened cluster extraction batch queue for {project_dir}" + ) + else: + self.statusBar().showMessage( + "Opened cluster extraction batch queue" + ) + def _open_xyz2pdb_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "xyz2pdb", + "XYZ -> PDB conversion", + ): + return from saxshell.xyz2pdb.ui.main_window import launch_xyz2pdb_ui settings = self._active_project_launch_settings() @@ -10626,7 +10887,10 @@ def _open_xyz2pdb_tool(self) -> None: project_dir=project_dir, ) self._connect_project_path_updates(window) - self._track_child_tool_window(window) + self._track_child_tool_window( + window, + single_instance_key="xyz2pdb", + ) if input_path is not None: self.statusBar().showMessage( f"Opened XYZ -> PDB conversion for {input_path}" @@ -10634,7 +10898,187 @@ def _open_xyz2pdb_tool(self) -> None: else: self.statusBar().showMessage("Opened XYZ -> PDB conversion") + def _open_xyz2pdb_batch_queue_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "xyz2pdb_batch_queue", + "XYZ -> PDB batch queue", + ): + return + from saxshell.xyz2pdb.ui.batch_queue_window import ( + XYZToPDBBatchQueueWindow, + ) + + settings = self._active_project_launch_settings() + input_path = None + project_dir = None + if settings is not None: + input_path = settings.resolved_frames_dir + project_dir = Path(settings.project_dir).resolve() + window = XYZToPDBBatchQueueWindow( + initial_project_dir=project_dir, + initial_input_path=input_path, + ) + self._connect_project_path_updates(window) + window.show() + window.raise_() + self._track_child_tool_window( + window, + single_instance_key="xyz2pdb_batch_queue", + ) + if project_dir is not None: + self.statusBar().showMessage( + f"Opened XYZ -> PDB batch queue for {project_dir}" + ) + else: + self.statusBar().showMessage("Opened XYZ -> PDB batch queue") + + def _open_xyz2pdb_cli_setup_tool(self) -> None: + if 
self._focus_single_instance_child_tool_window( + "xyz2pdb_cli_setup", + "XYZ -> PDB CLI setup", + ): + return + from saxshell.xyz2pdb.ui.run_file_window import ( + launch_xyz2pdb_run_file_ui, + ) + + settings = self._active_project_launch_settings() + input_path = None + project_dir = None + if settings is not None: + input_path = settings.resolved_frames_dir + project_dir = Path(settings.project_dir).resolve() + window = launch_xyz2pdb_run_file_ui( + initial_project_dir=project_dir, + initial_input_path=input_path, + ) + self._track_child_tool_window( + window, + single_instance_key="xyz2pdb_cli_setup", + ) + if project_dir is not None: + self.statusBar().showMessage( + "Opened XYZ -> PDB CLI setup for " f"{project_dir}" + ) + else: + self.statusBar().showMessage("Opened XYZ -> PDB CLI setup") + + def _open_cluster_cli_setup_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "cluster_cli_setup", + "Cluster extraction CLI setup", + ): + return + from saxshell.cluster.ui.run_file_window import ( + launch_cluster_run_file_ui, + ) + + settings = self._active_project_launch_settings() + frames_dir = None + project_dir = None + if settings is not None: + frames_dir = ( + settings.resolved_pdb_frames_dir + or settings.resolved_frames_dir + ) + project_dir = Path(settings.project_dir).resolve() + window = launch_cluster_run_file_ui( + initial_project_dir=project_dir, + initial_frames_dir=frames_dir, + ) + self._track_child_tool_window( + window, + single_instance_key="cluster_cli_setup", + ) + if project_dir is not None: + self.statusBar().showMessage( + "Opened cluster extraction CLI setup for " f"{project_dir}" + ) + else: + self.statusBar().showMessage("Opened cluster extraction CLI setup") + + def _open_clusterdynamics_cli_setup_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "clusterdynamics_cli_setup", + "Cluster dynamics CLI setup", + ): + return + from saxshell.clusterdynamics.ui.run_file_window import ( + 
launch_clusterdynamics_run_file_ui, + ) + + settings = self._active_project_launch_settings() + project_dir = None + frames_dir = None + energy_file = None + if settings is not None: + project_dir = Path(settings.project_dir).resolve() + frames_dir = settings.resolved_frames_dir + energy_file = settings.resolved_energy_file + window = launch_clusterdynamics_run_file_ui( + initial_project_dir=project_dir, + initial_frames_dir=frames_dir, + initial_energy_file=energy_file, + ) + self._track_child_tool_window( + window, + single_instance_key="clusterdynamics_cli_setup", + ) + if project_dir is not None: + self.statusBar().showMessage( + "Opened cluster dynamics CLI setup for " f"{project_dir}" + ) + else: + self.statusBar().showMessage("Opened cluster dynamics CLI setup") + + def _open_clusterdynamicsml_cli_setup_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "clusterdynamicsml_cli_setup", + "Cluster dynamics ML CLI setup", + ): + return + from saxshell.clusterdynamicsml.ui.run_file_window import ( + launch_clusterdynamicsml_run_file_ui, + ) + + settings = self._active_project_launch_settings() + project_dir = None + frames_dir = None + clusters_dir = None + energy_file = None + experimental_data_file = None + if settings is not None: + project_dir = Path(settings.project_dir).resolve() + frames_dir = settings.resolved_frames_dir + clusters_dir = settings.resolved_clusters_dir + energy_file = settings.resolved_energy_file + experimental_data_file = settings.resolved_experimental_data_path + window = launch_clusterdynamicsml_run_file_ui( + initial_project_dir=project_dir, + initial_frames_dir=frames_dir, + initial_energy_file=energy_file, + initial_clusters_dir=clusters_dir, + initial_experimental_data_file=experimental_data_file, + ) + self._track_child_tool_window( + window, + single_instance_key="clusterdynamicsml_cli_setup", + ) + if project_dir is not None: + self.statusBar().showMessage( + "Opened cluster dynamics ML CLI setup for " 
f"{project_dir}" + ) + else: + self.statusBar().showMessage( + "Opened cluster dynamics ML CLI setup" + ) + def _open_bondanalysis_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "bondanalysis", + "Bond analysis", + ): + return from saxshell.bondanalysis.ui.main_window import BondAnalysisMainWindow settings = self._active_project_launch_settings() @@ -10649,7 +11093,10 @@ def _open_bondanalysis_tool(self) -> None: ) window.show() window.raise_() - self._track_child_tool_window(window) + self._track_child_tool_window( + window, + single_instance_key="bondanalysis", + ) if clusters_dir: self.statusBar().showMessage( f"Opened bond analysis for {clusters_dir}" @@ -10658,6 +11105,11 @@ def _open_bondanalysis_tool(self) -> None: self.statusBar().showMessage("Opened bond analysis") def _open_debye_waller_analysis_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "debye_waller_analysis", + "Debye-Waller analysis", + ): + return from saxshell.saxs.debye_waller.ui.main_window import ( DEBYE_WALLER_WINDOW_LOAD_TOTAL_STEPS, launch_debye_waller_analysis_ui, @@ -10725,7 +11177,10 @@ def on_startup_log(message: str) -> None: self._close_progress_dialog() self._connect_project_path_updates(window) self._connect_debye_waller_updates(window) - self._track_child_tool_window(window) + self._track_child_tool_window( + window, + single_instance_key="debye_waller_analysis", + ) if clusters_dir is not None: self.statusBar().showMessage( f"Opened Debye-Waller analysis for {clusters_dir}" @@ -10734,6 +11189,11 @@ def on_startup_log(message: str) -> None: self.statusBar().showMessage("Opened Debye-Waller analysis") def _open_clusterdynamics_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "clusterdynamics", + "Cluster dynamics", + ): + return from saxshell.clusterdynamics.ui.main_window import ( ClusterDynamicsMainWindow, ) @@ -10753,7 +11213,10 @@ def _open_clusterdynamics_tool(self) -> None: ) window.show() 
window.raise_() - self._track_child_tool_window(window) + self._track_child_tool_window( + window, + single_instance_key="clusterdynamics", + ) if project_dir is not None: self.statusBar().showMessage( f"Opened cluster dynamics for {project_dir}" @@ -10762,6 +11225,11 @@ def _open_clusterdynamics_tool(self) -> None: self.statusBar().showMessage("Opened cluster dynamics") def _open_clusterdynamicsml_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "clusterdynamicsml", + "Cluster dynamics (ML)", + ): + return from saxshell.clusterdynamicsml.ui.main_window import ( ClusterDynamicsMLMainWindow, ) @@ -10787,7 +11255,10 @@ def _open_clusterdynamicsml_tool(self) -> None: ) window.show() window.raise_() - self._track_child_tool_window(window) + self._track_child_tool_window( + window, + single_instance_key="clusterdynamicsml", + ) if project_dir is not None: self.statusBar().showMessage( f"Opened cluster dynamics (ML) for {project_dir}" @@ -10796,6 +11267,11 @@ def _open_clusterdynamicsml_tool(self) -> None: self.statusBar().showMessage("Opened cluster dynamics (ML)") def _open_fullrmc_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "fullrmc", + "RMC setup", + ): + return from saxshell.fullrmc.ui.main_window import RMCSetupMainWindow project_dir = None @@ -10804,7 +11280,10 @@ def _open_fullrmc_tool(self) -> None: window = RMCSetupMainWindow(initial_project_dir=project_dir) window.show() window.raise_() - self._track_child_tool_window(window) + self._track_child_tool_window( + window, + single_instance_key="fullrmc", + ) if project_dir: self.statusBar().showMessage( f"Opened fullrmc setup for {project_dir}" @@ -10813,6 +11292,17 @@ def _open_fullrmc_tool(self) -> None: self.statusBar().showMessage("Opened fullrmc setup") def _open_pdfsetup_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "pdfsetup", + "PDF calculation", + ): + return + if self._block_conflicting_child_tool_window( + 
"pdf_batch_queue", + requested_tool_label="PDF calculation", + open_tool_label="PDF batch queue", + ): + return from saxshell.pdf.debyer.ui.main_window import DebyerPDFMainWindow settings = self._active_project_launch_settings() @@ -10827,7 +11317,10 @@ def _open_pdfsetup_tool(self) -> None: ) window.show() window.raise_() - self._track_child_tool_window(window) + self._track_child_tool_window( + window, + single_instance_key="pdfsetup", + ) if project_dir is not None: self.statusBar().showMessage( f"Opened PDF calculation for {project_dir}" @@ -10835,6 +11328,45 @@ def _open_pdfsetup_tool(self) -> None: else: self.statusBar().showMessage("Opened PDF calculation") + def _open_pdf_batch_queue_tool(self) -> None: + if self._focus_single_instance_child_tool_window( + "pdf_batch_queue", + "PDF batch queue", + ): + return + if self._block_conflicting_child_tool_window( + "pdfsetup", + requested_tool_label="PDF batch queue", + open_tool_label="PDF calculation", + ): + return + from saxshell.pdf.debyer.ui.batch_queue_window import ( + DebyerPDFBatchQueueWindow, + ) + + settings = self._active_project_launch_settings() + project_dir = None + frames_dir = None + if settings is not None: + project_dir = Path(settings.project_dir).resolve() + frames_dir = settings.resolved_frames_dir + window = DebyerPDFBatchQueueWindow( + initial_project_dir=project_dir, + initial_frames_dir=frames_dir, + ) + window.show() + window.raise_() + self._track_child_tool_window( + window, + single_instance_key="pdf_batch_queue", + ) + if project_dir is not None: + self.statusBar().showMessage( + f"Opened PDF batch queue for {project_dir}" + ) + else: + self.statusBar().showMessage("Opened PDF batch queue") + def _open_blenderxyz_tool(self) -> None: from saxshell.toolbox.blender.ui.main_window import ( launch_blender_xyz_renderer_ui, @@ -10853,7 +11385,21 @@ def _open_structure_viewer_tool(self) -> None: self._track_child_tool_window(window) self.statusBar().showMessage("Opened Structure Viewer") + 
def _open_representative_batch_queue_tool(self) -> None:
    """Open the representative-structures batch queue as a child tool.

    Returns immediately when ``_focus_single_instance_child_tool_window``
    reports a truthy result for the ``"representativefinder_batch_queue"``
    key. Otherwise the queue window is created (prefilled from
    ``self.current_settings`` when set), its ``project_results_changed``
    signal is wired to
    ``_handle_representative_structure_results_changed`` when the window
    exposes one, and the window is shown, tracked, and announced in the
    status bar.
    """
    queue_already_handled = self._focus_single_instance_child_tool_window(
        "representativefinder_batch_queue",
        "Representative structures batch queue",
    )
    if queue_already_handled:
        return

    # Imported lazily so the early-return path never loads the queue UI.
    from saxshell.representativefinder.ui.batch_queue_window import (
        RepresentativeFinderBatchQueueWindow,
    )

    active_settings = self.current_settings
    if active_settings is None:
        project_dir = None
        initial_clusters_dir = None
    else:
        project_dir = Path(active_settings.project_dir).resolve()
        initial_clusters_dir = active_settings.resolved_clusters_dir

    queue_window = RepresentativeFinderBatchQueueWindow(
        initial_project_dir=project_dir,
        initial_clusters_dir=initial_clusters_dir,
    )

    # The signal is optional: connect only when the window exposes an
    # object with a ``connect`` attribute under that name.
    results_signal = getattr(queue_window, "project_results_changed", None)
    if results_signal is not None:
        if hasattr(results_signal, "connect"):
            results_signal.connect(
                self._handle_representative_structure_results_changed
            )

    queue_window.show()
    queue_window.raise_()
    self._track_child_tool_window(
        queue_window,
        single_instance_key="representativefinder_batch_queue",
    )

    if project_dir is None:
        status_text = "Opened representative structures batch queue"
    else:
        status_text = (
            "Opened representative structures batch queue for "
            f"{project_dir}"
        )
    self.statusBar().showMessage(status_text)
@dataclass(frozen=True, slots=True)
class PeriodicElement:
    """Immutable record describing one element cell of the periodic table.

    ``period`` and ``group`` are 1-based layout coordinates:
    ``PeriodicTableWidget`` places each button at grid row ``period - 1``
    and grid column ``group - 1``.
    """

    # Element symbol as shown on the button, e.g. "Fe".
    symbol: str
    # Element name used in the button tooltip, e.g. "Iron".
    name: str
    # 1-based layout row. The layout table uses 9 for Ce..Lu and 10 for
    # Th..Lr so those series sit in separate rows below the main table.
    period: int
    # 1-based layout column (1..18 in the layout table).
    group: int
+ ("Se", "Selenium", 4, 16), + ("Br", "Bromine", 4, 17), + ("Kr", "Krypton", 4, 18), + ("Rb", "Rubidium", 5, 1), + ("Sr", "Strontium", 5, 2), + ("Y", "Yttrium", 5, 3), + ("Zr", "Zirconium", 5, 4), + ("Nb", "Niobium", 5, 5), + ("Mo", "Molybdenum", 5, 6), + ("Tc", "Technetium", 5, 7), + ("Ru", "Ruthenium", 5, 8), + ("Rh", "Rhodium", 5, 9), + ("Pd", "Palladium", 5, 10), + ("Ag", "Silver", 5, 11), + ("Cd", "Cadmium", 5, 12), + ("In", "Indium", 5, 13), + ("Sn", "Tin", 5, 14), + ("Sb", "Antimony", 5, 15), + ("Te", "Tellurium", 5, 16), + ("I", "Iodine", 5, 17), + ("Xe", "Xenon", 5, 18), + ("Cs", "Cesium", 6, 1), + ("Ba", "Barium", 6, 2), + ("La", "Lanthanum", 6, 3), + ("Hf", "Hafnium", 6, 4), + ("Ta", "Tantalum", 6, 5), + ("W", "Tungsten", 6, 6), + ("Re", "Rhenium", 6, 7), + ("Os", "Osmium", 6, 8), + ("Ir", "Iridium", 6, 9), + ("Pt", "Platinum", 6, 10), + ("Au", "Gold", 6, 11), + ("Hg", "Mercury", 6, 12), + ("Tl", "Thallium", 6, 13), + ("Pb", "Lead", 6, 14), + ("Bi", "Bismuth", 6, 15), + ("Po", "Polonium", 6, 16), + ("At", "Astatine", 6, 17), + ("Rn", "Radon", 6, 18), + ("Fr", "Francium", 7, 1), + ("Ra", "Radium", 7, 2), + ("Ac", "Actinium", 7, 3), + ("Rf", "Rutherfordium", 7, 4), + ("Db", "Dubnium", 7, 5), + ("Sg", "Seaborgium", 7, 6), + ("Bh", "Bohrium", 7, 7), + ("Hs", "Hassium", 7, 8), + ("Mt", "Meitnerium", 7, 9), + ("Ds", "Darmstadtium", 7, 10), + ("Rg", "Roentgenium", 7, 11), + ("Cn", "Copernicium", 7, 12), + ("Nh", "Nihonium", 7, 13), + ("Fl", "Flerovium", 7, 14), + ("Mc", "Moscovium", 7, 15), + ("Lv", "Livermorium", 7, 16), + ("Ts", "Tennessine", 7, 17), + ("Og", "Oganesson", 7, 18), + ("Ce", "Cerium", 9, 4), + ("Pr", "Praseodymium", 9, 5), + ("Nd", "Neodymium", 9, 6), + ("Pm", "Promethium", 9, 7), + ("Sm", "Samarium", 9, 8), + ("Eu", "Europium", 9, 9), + ("Gd", "Gadolinium", 9, 10), + ("Tb", "Terbium", 9, 11), + ("Dy", "Dysprosium", 9, 12), + ("Ho", "Holmium", 9, 13), + ("Er", "Erbium", 9, 14), + ("Tm", "Thulium", 9, 15), + ("Yb", "Ytterbium", 9, 16), + ("Lu", 
def element_by_symbol(symbol: str) -> PeriodicElement | None:
    """Look up a periodic-table entry by element symbol.

    The symbol is first passed through ``_normalized_symbol`` and then
    compared against each entry of ``PERIODIC_TABLE_ELEMENTS`` in order.

    Returns:
        The first matching ``PeriodicElement``, or ``None`` when no entry
        has the normalized symbol.
    """
    wanted = _normalized_symbol(symbol)
    for candidate in PERIODIC_TABLE_ELEMENTS:
        if candidate.symbol == wanted:
            return candidate
    return None
def _normalized_symbol(symbol: str) -> str:
    """Return *symbol* in element-symbol capitalization.

    The input is converted with ``str()``, stripped, and reduced to its
    alphabetic characters. The first remaining letter is upper-cased and
    the rest are lower-cased (``" fe "`` -> ``"Fe"``, ``"Na+"`` -> ``"Na"``).
    An input with no alphabetic characters yields ``""``.
    """
    letters = [char for char in str(symbol).strip() if char.isalpha()]
    if not letters:
        return ""
    head, *tail = letters
    # With a single letter ``tail`` is empty, so this also covers the
    # one-character case.
    return head.upper() + "".join(tail).lower()
b/src/saxshell/xyz2pdb/__init__.py @@ -1,5 +1,17 @@ """Headless and Qt interfaces for xyz-to-pdb conversion workflows.""" +from .run_config import ( + DEFAULT_RUN_FILE_NAME, + XYZToPDBRunConfig, + XYZToPDBRunExecutionSummary, + build_xyz2pdb_run_config, + default_xyz2pdb_run_file_path, + load_xyz2pdb_run_config, + path_text_for_run_config, + resolve_run_config_path, + run_xyz2pdb_run_config, + save_xyz2pdb_run_config, +) from .workflow import ( AnchorPairDefinition, ConvertedResidue, @@ -38,10 +50,20 @@ "XYZToPDBInspectionResult", "XYZToPDBPreviewResult", "XYZToPDBWorkflow", + "XYZToPDBRunConfig", + "XYZToPDBRunExecutionSummary", + "DEFAULT_RUN_FILE_NAME", + "build_xyz2pdb_run_config", "create_reference_molecule", "default_reference_library_dir", + "default_xyz2pdb_run_file_path", "list_reference_library", + "load_xyz2pdb_run_config", "next_available_output_dir", + "path_text_for_run_config", + "resolve_run_config_path", "resolve_reference_path", + "run_xyz2pdb_run_config", + "save_xyz2pdb_run_config", "suggest_output_dir", ] diff --git a/src/saxshell/xyz2pdb/cli.py b/src/saxshell/xyz2pdb/cli.py index 4489a15..5681891 100644 --- a/src/saxshell/xyz2pdb/cli.py +++ b/src/saxshell/xyz2pdb/cli.py @@ -6,6 +6,11 @@ from saxshell.version import __version__ +from .run_config import ( + default_xyz2pdb_run_file_path, + load_xyz2pdb_run_config, + run_xyz2pdb_run_config, +) from .workflow import ( XYZToPDBWorkflow, create_reference_molecule, @@ -31,6 +36,24 @@ def build_parser() -> argparse.ArgumentParser: subparsers = parser.add_subparsers(dest="command") + setup_ui_parser = subparsers.add_parser( + "setup-ui", + help="Launch the project-backed run-file setup UI.", + ) + setup_ui_parser.add_argument( + "project_dir", + nargs="?", + type=Path, + help="Optional SAXSShell project folder.", + ) + setup_ui_parser.add_argument( + "--input-path", + type=Path, + default=None, + help="Optional XYZ file or folder to prefill.", + ) + 
def _handle_setup_ui(args: argparse.Namespace) -> int:
    """Launch the project-backed run-file setup UI for the ``setup-ui`` command.

    Reuses an existing ``QApplication`` when one is already running.
    Otherwise the application identity is prepared, a new
    ``QApplication`` is created, and its event loop is run.

    Returns:
        The event-loop exit code when this function created the
        application, otherwise ``0``.
    """
    from PySide6.QtWidgets import QApplication

    from saxshell.saxs.ui.branding import prepare_saxshell_application_identity

    from .ui.run_file_window import launch_xyz2pdb_run_file_ui

    existing_app = QApplication.instance()
    owns_event_loop = existing_app is None
    if owns_event_loop:
        prepare_saxshell_application_identity()
        qt_app = QApplication(sys.argv)
    else:
        qt_app = existing_app

    project_dir = getattr(args, "project_dir", None)
    input_path = getattr(args, "input_path", None)
    # NOTE(review): the launcher's return value is not retained here;
    # presumably the launcher keeps the window alive itself -- confirm.
    launch_xyz2pdb_run_file_ui(
        initial_project_dir=project_dir,
        initial_input_path=input_path,
    )

    if not owns_event_loop:
        return 0
    assert qt_app is not None
    return int(qt_app.exec())
run_xyz2pdb_run_config( + project_dir, + config, + run_file_path=run_file, + log_callback=print, + progress_callback=_print_progress, + ) + print("") + print("XYZ to PDB project run complete.") + print(f"Output directory: {summary.output_dir}") + print(f"Files written: {summary.written_count}") + if summary.written_files: + print(f"First file: {summary.written_files[0]}") + print(f"Last file: {summary.written_files[-1]}") + print(f"Project file updated: {summary.project_file}") + return 0 + + def _handle_reference_list(args: argparse.Namespace) -> int: library_dir = ( default_reference_library_dir() @@ -342,3 +429,13 @@ def _handle_reference_add(args: argparse.Namespace) -> int: print(f"Residue name: {result.residue_name}") print(f"Atom count: {result.atom_count}") return 0 + + +def _resolve_run_file(project_dir: Path, run_file: Path | None) -> Path: + if run_file is None: + return default_xyz2pdb_run_file_path(project_dir) + return Path(run_file).expanduser().resolve() + + +def _print_progress(processed: int, total: int, message: str) -> None: + print(f"{processed}/{total} {message}") diff --git a/src/saxshell/xyz2pdb/reference_library/dmso_md.json b/src/saxshell/xyz2pdb/reference_library/dmso_md.json new file mode 100644 index 0000000..bffb9d2 --- /dev/null +++ b/src/saxshell/xyz2pdb/reference_library/dmso_md.json @@ -0,0 +1,4 @@ +{ + "backbone_pairs": [["O1", "S1"]], + "residue_name": "DMS" +} diff --git a/src/saxshell/xyz2pdb/reference_library/dmso_md.pdb b/src/saxshell/xyz2pdb/reference_library/dmso_md.pdb new file mode 100644 index 0000000..f5b3422 --- /dev/null +++ b/src/saxshell/xyz2pdb/reference_library/dmso_md.pdb @@ -0,0 +1,10 @@ +ATOM 1 S1 DMS X 1 15.320 -0.182 -1.718 1.00 0.00 S +ATOM 2 C1 DMS X 1 15.808 -1.950 -1.576 1.00 0.00 C +ATOM 3 H1 DMS X 1 16.777 -1.984 -1.035 1.00 0.00 H +ATOM 4 H2 DMS X 1 15.068 -2.654 -1.197 1.00 0.00 H +ATOM 5 H3 DMS X 1 16.088 -2.155 -2.607 1.00 0.00 H +ATOM 6 C2 DMS X 1 15.580 0.404 -0.053 1.00 0.00 C +ATOM 7 H4 DMS X 1 
14.793 -0.035 0.558 1.00 0.00 H +ATOM 8 H5 DMS X 1 16.569 0.136 0.380 1.00 0.00 H +ATOM 9 H6 DMS X 1 15.561 1.491 -0.234 1.00 0.00 H +ATOM 10 O1 DMS X 1 13.776 -0.152 -1.909 1.00 0.00 O diff --git a/src/saxshell/xyz2pdb/run_config.py b/src/saxshell/xyz2pdb/run_config.py new file mode 100644 index 0000000..df87c14 --- /dev/null +++ b/src/saxshell/xyz2pdb/run_config.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Callable + +from saxshell.xyz2pdb.mapping_workflow import ( + FreeAtomMappingInput, + MoleculeMappingInput, + ReferenceBondToleranceInput, + XYZToPDBMappingWorkflow, +) + +DEFAULT_RUN_FILE_NAME = "xyz2pdb_cli_run.json" +RUN_CONFIG_VERSION = 1 +XYZToPDBRunLogCallback = Callable[[str], None] +XYZToPDBRunProgressCallback = Callable[[int, int, str], None] + + +@dataclass(slots=True) +class XYZToPDBRunConfig: + input_path: str + output_dir: str | None + reference_library_dir: str | None = None + molecule_inputs: tuple[MoleculeMappingInput, ...] = () + free_atom_inputs: tuple[FreeAtomMappingInput, ...] 
@classmethod
def from_dict(cls, payload: dict[str, object]) -> "XYZToPDBRunConfig":
    """Build a run config from a decoded JSON object.

    Text fields are read through ``_optional_text`` so that a JSON
    ``null`` is treated the same as a missing or blank value. Calling
    ``str()`` on ``None`` would otherwise produce the literal text
    ``"None"``, which used to pass the ``input_path`` check and become a
    bogus path, hydrogen mode, or timestamp.

    Args:
        payload: Mapping loaded from a run file.

    Returns:
        A new ``XYZToPDBRunConfig``. ``hydrogen_mode`` falls back to
        ``"leave_unassigned"``, ``created_at`` falls back to the current
        local time, and ``selected_solution_index`` is clamped to be
        non-negative, with ``null`` read as ``0``.

    Raises:
        ValueError: If ``input_path`` is missing, ``null``, or blank, or
            if ``selected_solution_index`` cannot be converted to an
            integer.
    """
    input_path = _optional_text(payload.get("input_path"))
    if input_path is None:
        raise ValueError("XYZ-to-PDB run file is missing input_path.")

    # A JSON null index means "not chosen"; int(None) would raise.
    raw_solution_index = payload.get("selected_solution_index")
    if raw_solution_index is None:
        raw_solution_index = 0

    return cls(
        input_path=input_path,
        output_dir=_optional_text(payload.get("output_dir")),
        reference_library_dir=_optional_text(
            payload.get("reference_library_dir")
        ),
        molecule_inputs=tuple(
            _molecule_input_from_dict(item)
            for item in _dict_items(payload.get("molecule_inputs"))
        ),
        free_atom_inputs=tuple(
            _free_atom_input_from_dict(item)
            for item in _dict_items(payload.get("free_atom_inputs"))
        ),
        hydrogen_mode=(
            _optional_text(payload.get("hydrogen_mode"))
            or "leave_unassigned"
        ),
        selected_solution_index=max(int(raw_solution_index), 0),
        assertion_mode=bool(payload.get("assertion_mode", False)),
        pbc_params=_pbc_params_from_dict(payload.get("pbc_params")),
        created_at=(
            _optional_text(payload.get("created_at"))
            or datetime.now().isoformat(timespec="seconds")
        ),
    )
def run_xyz2pdb_run_config(
    project_dir: str | Path,
    config: XYZToPDBRunConfig,
    *,
    run_file_path: str | Path | None = None,
    log_callback: XYZToPDBRunLogCallback | None = None,
    progress_callback: XYZToPDBRunProgressCallback | None = None,
) -> XYZToPDBRunExecutionSummary:
    """Run an XYZ-to-PDB conversion described by *config* and register it.

    Paths stored in *config* are resolved against *project_dir* with
    ``resolve_run_config_path``. The conversion runs through
    ``XYZToPDBMappingWorkflow.export_with_mapping``. Afterwards the
    project is loaded with ``SAXSProjectManager``, its ``pdb_frames_dir``
    is set to the resolved output directory, and the project is saved.

    Args:
        project_dir: Project folder used as the base for relative paths.
        config: Run configuration to execute.
        run_file_path: Optional path of the run file, recorded in the
            summary. A leading ``~`` is expanded, as for the other paths
            handled by this module.
        log_callback: Optional sink for log messages.
        progress_callback: Optional progress sink, forwarded to the
            workflow.

    Returns:
        A summary with the output directory, the written files, and the
        saved project file.

    Raises:
        ValueError: If ``config.input_path`` is blank.
    """
    from saxshell.saxs.project_manager import SAXSProjectManager

    resolved_project_dir = Path(project_dir).expanduser().resolve()
    input_path = resolve_run_config_path(
        config.input_path,
        project_dir=resolved_project_dir,
    )
    if input_path is None:
        raise ValueError("XYZ-to-PDB run file is missing input_path.")
    output_dir = resolve_run_config_path(
        config.output_dir,
        project_dir=resolved_project_dir,
    )
    reference_library_dir = resolve_run_config_path(
        config.reference_library_dir,
        project_dir=resolved_project_dir,
    )
    _emit_log(log_callback, f"Starting XYZ-to-PDB conversion: {input_path}")
    workflow = XYZToPDBMappingWorkflow(
        input_path,
        reference_library_dir=reference_library_dir,
        output_dir=output_dir,
    )
    result = workflow.export_with_mapping(
        molecule_inputs=config.molecule_inputs,
        free_atom_inputs=config.free_atom_inputs,
        hydrogen_mode=config.hydrogen_mode,
        pbc_params=config.pbc_params,
        selected_solution_index=config.selected_solution_index,
        output_dir=output_dir,
        assert_molecule_shapes=config.assertion_mode,
        progress_callback=progress_callback,
        log_callback=log_callback,
    )

    # Point the project at the freshly written PDB frames and save it.
    manager = SAXSProjectManager()
    settings = manager.load_project(resolved_project_dir)
    settings.pdb_frames_dir = str(result.output_dir.expanduser().resolve())
    project_file = manager.save_project(settings)
    _emit_log(
        log_callback,
        "Registered converted PDB frames with project: "
        f"{settings.pdb_frames_dir}",
    )

    # expanduser() keeps "~/run.json" from resolving under the current
    # working directory, matching the other path handling in this module.
    resolved_run_file = (
        None
        if run_file_path is None
        else Path(run_file_path).expanduser().resolve()
    )
    return XYZToPDBRunExecutionSummary(
        project_dir=resolved_project_dir,
        run_file_path=resolved_run_file,
        output_dir=result.output_dir,
        written_files=result.written_files,
        project_file=project_file,
    )
_dict_items(payload.get("bond_tolerances")) + ), + tight_pass_scale=float(payload.get("tight_pass_scale", 0.85)), + relaxed_pass_scale=float(payload.get("relaxed_pass_scale", 1.35)), + max_assignment_distance=( + None + if payload.get("max_assignment_distance") is None + else float(payload.get("max_assignment_distance")) + ), + max_missing_hydrogens=max( + int(payload.get("max_missing_hydrogens", 0)), + 0, + ), + ) + + +def _free_atom_input_to_dict(item: FreeAtomMappingInput) -> dict[str, object]: + return { + "element": item.element, + "residue_name": item.residue_name, + } + + +def _free_atom_input_from_dict( + payload: dict[str, object] +) -> FreeAtomMappingInput: + return FreeAtomMappingInput( + element=str(payload.get("element", "")).strip(), + residue_name=str(payload.get("residue_name", "")).strip(), + ) + + +def _pbc_params_from_dict(value: object) -> dict[str, float | str]: + if not isinstance(value, dict): + return {} + parsed: dict[str, float | str] = {} + for key in ("a", "b", "c", "alpha", "beta", "gamma"): + if value.get(key) is not None: + parsed[key] = float(value[key]) + if value.get("space_group") is not None: + parsed["space_group"] = str(value["space_group"]) + return parsed + + +def _dict_items(value: object) -> tuple[dict[str, object], ...]: + if not isinstance(value, list): + return () + return tuple(dict(item) for item in value if isinstance(item, dict)) + + +def _optional_text(value: object) -> str | None: + if value is None: + return None + text = str(value).strip() + return text or None + + +def _emit_log( + callback: XYZToPDBRunLogCallback | None, + message: str, +) -> None: + if callback is not None: + callback(str(message).strip()) + + +__all__ = [ + "DEFAULT_RUN_FILE_NAME", + "XYZToPDBRunConfig", + "XYZToPDBRunExecutionSummary", + "build_xyz2pdb_run_config", + "default_xyz2pdb_run_file_path", + "load_xyz2pdb_run_config", + "path_text_for_run_config", + "resolve_run_config_path", + "run_xyz2pdb_run_config", + "save_xyz2pdb_run_config", 
+] diff --git a/src/saxshell/xyz2pdb/ui/__init__.py b/src/saxshell/xyz2pdb/ui/__init__.py index 1f7d4a4..f8b0389 100644 --- a/src/saxshell/xyz2pdb/ui/__init__.py +++ b/src/saxshell/xyz2pdb/ui/__init__.py @@ -1,5 +1,17 @@ """Qt widgets for the xyz2pdb application.""" +from .batch_queue_window import ( + XYZToPDBBatchQueueWindow, + launch_xyz2pdb_batch_queue_ui, +) from .main_window import XYZToPDBMainWindow, launch_xyz2pdb_ui +from .run_file_window import XYZToPDBRunFileWindow, launch_xyz2pdb_run_file_ui -__all__ = ["XYZToPDBMainWindow", "launch_xyz2pdb_ui"] +__all__ = [ + "XYZToPDBBatchQueueWindow", + "XYZToPDBMainWindow", + "XYZToPDBRunFileWindow", + "launch_xyz2pdb_batch_queue_ui", + "launch_xyz2pdb_run_file_ui", + "launch_xyz2pdb_ui", +] diff --git a/src/saxshell/xyz2pdb/ui/batch_queue_window.py b/src/saxshell/xyz2pdb/ui/batch_queue_window.py new file mode 100644 index 0000000..5402d40 --- /dev/null +++ b/src/saxshell/xyz2pdb/ui/batch_queue_window.py @@ -0,0 +1,1437 @@ +from __future__ import annotations + +import re +import threading +import uuid +from dataclasses import dataclass, replace +from pathlib import Path + +from PySide6.QtCore import QObject, Qt, QThread, Signal, Slot +from PySide6.QtWidgets import ( + QAbstractItemView, + QApplication, + QComboBox, + QFileDialog, + QFormLayout, + QFrame, + QGridLayout, + QGroupBox, + QHBoxLayout, + QHeaderView, + QLabel, + QLineEdit, + QListView, + QListWidget, + QListWidgetItem, + QMainWindow, + QMessageBox, + QProgressBar, + QPushButton, + QSizePolicy, + QTableWidget, + QTableWidgetItem, + QTextEdit, + QToolButton, + QTreeView, + QVBoxLayout, + QWidget, +) + +from saxshell.saxs.project_manager import ( + SAXSProjectManager, + build_project_paths, +) +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) +from saxshell.xyz2pdb.mapping_workflow import ( + FreeAtomMappingInput, + MoleculeMappingInput, + XYZToPDBMappingWorkflow, +) 
def _validated_residue_code(value: str, field_name: str) -> str:
    """Return *value* as a three-letter upper-case residue code.

    The value is stripped and upper-cased before it is checked, so
    ``" dms "`` is accepted and returned as ``"DMS"``.

    Raises:
        ValueError: If the normalized value is not exactly three letters
            A-Z. The message starts with *field_name*.
    """
    candidate = value.strip().upper()
    if re.fullmatch(r"[A-Z]{3}", candidate) is not None:
        return candidate
    raise ValueError(
        f"{field_name} must be exactly three capital letters."
    )
+ project_file = build_project_paths(project_dir).project_file + if project_file.is_file(): + return f"Project reference: {project_file}" + return f"Project reference: no project file found at {project_file}" + + +def _dialog_start_dir(*candidates: str | Path | None) -> str: + for candidate in candidates: + if candidate is None: + continue + path = Path(candidate).expanduser() + if path.is_file(): + return str(path.parent) + if path.is_dir(): + return str(path) + return str(Path.home()) + + +def _choose_existing_directories( + parent: QWidget, + *, + title: str, + start_dir: str | Path, +) -> tuple[Path, ...]: + dialog = QFileDialog(parent, title, str(start_dir)) + dialog.setFileMode(QFileDialog.FileMode.Directory) + dialog.setOption(QFileDialog.Option.ShowDirsOnly, True) + dialog.setOption(QFileDialog.Option.DontUseNativeDialog, True) + for view in dialog.findChildren(QListView) + dialog.findChildren( + QTreeView + ): + view.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + if dialog.exec() != int(QFileDialog.DialogCode.Accepted): + return () + return tuple( + Path(path).expanduser().resolve() for path in dialog.selectedFiles() + ) + + +@dataclass(slots=True, frozen=True) +class XYZToPDBBatchJob: + project_dir: Path + input_path: Path + reference_library_dir: Path + molecule_inputs: tuple[MoleculeMappingInput, ...] + free_atom_inputs: tuple[FreeAtomMappingInput, ...] + hydrogen_mode: str = "leave_unassigned" + + +@dataclass(slots=True) +class XYZToPDBBatchResult: + project_dir: Path + input_path: Path + output_dir: Path + written_count: int + + +@dataclass(slots=True) +class XYZToPDBBatchItem: + item_id: str + project_dir: Path | None = None + input_path: Path | None = None + reference_library_dir: Path = default_reference_library_dir() + molecule_inputs: tuple[MoleculeMappingInput, ...] = () + free_atom_inputs: tuple[FreeAtomMappingInput, ...] 
= () + hydrogen_mode: str = "leave_unassigned" + + def display_name(self) -> str: + if self.project_dir is not None: + return self.project_dir.name + if self.input_path is not None: + return self.input_path.name + return "New XYZ -> PDB conversion" + + def to_job(self) -> XYZToPDBBatchJob: + project_dir = _required_project_dir( + "" if self.project_dir is None else str(self.project_dir) + ) + input_path = _required_existing_input_path( + "" if self.input_path is None else str(self.input_path) + ) + library_dir = Path(self.reference_library_dir).expanduser().resolve() + if not library_dir.is_dir(): + raise ValueError( + f"Reference library folder does not exist: {library_dir}" + ) + if not self.molecule_inputs and not self.free_atom_inputs: + raise ValueError( + "Add at least one reference molecule or free atom mapping." + ) + return XYZToPDBBatchJob( + project_dir=project_dir, + input_path=input_path, + reference_library_dir=library_dir, + molecule_inputs=tuple(self.molecule_inputs), + free_atom_inputs=tuple(self.free_atom_inputs), + hydrogen_mode=self.hydrogen_mode or "leave_unassigned", + ) + + +def _queue_item_from_project_defaults( + project_dir: str | Path, + *, + item_id: str | None = None, + reference_library_dir: str | Path | None = None, +) -> XYZToPDBBatchItem: + resolved_project_dir = Path(project_dir).expanduser().resolve() + item = XYZToPDBBatchItem( + item_id=item_id or _new_item_id(), + project_dir=resolved_project_dir, + reference_library_dir=( + default_reference_library_dir() + if reference_library_dir is None + else Path(reference_library_dir).expanduser().resolve() + ), + ) + try: + settings = SAXSProjectManager().load_project(resolved_project_dir) + except Exception: + return item + return replace(item, input_path=settings.resolved_frames_dir) + + +class XYZToPDBBatchItemWidget(QFrame): + settings_changed = Signal(str) + remove_requested = Signal(str) + duplicate_requested = Signal(str) + + def __init__( + self, + item: XYZToPDBBatchItem, + *, 
+ parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._item = item + self._loading = False + self._selected = False + self._available_elements: tuple[str, ...] = () + self._reference_entries: tuple[ReferenceLibraryEntry, ...] = () + self._build_ui() + self._load_item(item) + self._refresh_reference_entries() + self._set_settings_visible(False) + + @property + def item_id(self) -> str: + return self._item.item_id + + def item(self) -> XYZToPDBBatchItem: + return self._item + + def collect_item(self) -> XYZToPDBBatchItem: + self._item = XYZToPDBBatchItem( + item_id=self._item.item_id, + project_dir=_optional_path(self.project_dir_edit.text()), + input_path=_optional_path(self.input_path_edit.text()), + reference_library_dir=( + _optional_path(self.reference_library_edit.text()) + or default_reference_library_dir() + ), + molecule_inputs=tuple(self._molecule_inputs_from_table()), + free_atom_inputs=tuple(self._free_atom_inputs_from_table()), + hydrogen_mode=str( + self.hydrogen_mode_combo.currentData() or "leave_unassigned" + ), + ) + self._refresh_header() + self._refresh_project_reference() + return self._item + + def job(self) -> XYZToPDBBatchJob: + return self.collect_item().to_job() + + def set_locked(self, locked: bool) -> None: + self.settings_group.setEnabled(not locked) + self.analyze_button.setEnabled(not locked) + self.duplicate_button.setEnabled(not locked) + self.remove_button.setEnabled(not locked) + + def set_status(self, message: str) -> None: + self.status_label.setText(message) + + def set_progress(self, processed: int, total: int) -> None: + self.progress_bar.setRange(0, max(int(total), 1)) + self.progress_bar.setValue(max(int(processed), 0)) + + def set_selected(self, selected: bool) -> None: + self._selected = bool(selected) + self.header_frame.setProperty("selected", self._selected) + self.header_frame.setStyleSheet( + "QFrame#XYZToPDBBatchItemHeader {" + + ( + "background-color: #dce8f7; " "border: 1px solid #8fb0d7;" + 
if self._selected + else "background-color: #f6f8fb; " "border: 1px solid #cfd7e3;" + ) + + "border-radius: 5px;}" + ) + + def analyze_input(self) -> None: + input_path = _required_existing_input_path(self.input_path_edit.text()) + library_dir = ( + _optional_path(self.reference_library_edit.text()) + or default_reference_library_dir() + ) + workflow = XYZToPDBMappingWorkflow( + input_path, + reference_library_dir=library_dir, + ) + analysis = workflow.analyze_input() + self._available_elements = tuple(sorted(analysis.element_counts)) + self._refresh_free_element_combo() + self._refresh_reference_entries() + lines = [ + f"XYZ files: {analysis.inspection.total_files}", + f"Sample frame: {analysis.sample_file.name}", + "Elements: " + + ", ".join( + f"{element} x{count}" + for element, count in sorted(analysis.element_counts.items()) + ), + f"Suggested PDB folder: {suggest_output_dir(input_path)}", + ] + self.analysis_summary_label.setText("\n".join(lines)) + self.set_progress(0, max(analysis.inspection.total_files, 1)) + self.set_status("Input analyzed") + + def _build_ui(self) -> None: + self.setFrameShape(QFrame.Shape.StyledPanel) + self.setSizePolicy( + QSizePolicy.Policy.Expanding, + QSizePolicy.Policy.Fixed, + ) + root = QVBoxLayout(self) + root.setContentsMargins(8, 8, 8, 8) + root.setSpacing(8) + + self.header_frame = QFrame() + self.header_frame.setObjectName("XYZToPDBBatchItemHeader") + header = QHBoxLayout(self.header_frame) + header.setContentsMargins(8, 6, 8, 6) + header.setSpacing(8) + self.toggle_button = QToolButton() + self.toggle_button.setCheckable(True) + self.toggle_button.toggled.connect(self._set_settings_visible) + header.addWidget(self.toggle_button) + self.title_label = QLabel("New XYZ -> PDB conversion") + self.title_label.setStyleSheet("font-weight: 600;") + header.addWidget(self.title_label, stretch=1) + self.status_label = QLabel("Ready") + self.status_label.setMinimumWidth(180) + header.addWidget(self.status_label) + self.analyze_button 
= QPushButton("Analyze") + self.analyze_button.clicked.connect(self._analyze_from_button) + header.addWidget(self.analyze_button) + self.duplicate_button = QPushButton("Duplicate") + self.duplicate_button.clicked.connect( + lambda: self.duplicate_requested.emit(self.item_id) + ) + header.addWidget(self.duplicate_button) + self.remove_button = QPushButton("Remove") + self.remove_button.clicked.connect( + lambda: self.remove_requested.emit(self.item_id) + ) + header.addWidget(self.remove_button) + root.addWidget(self.header_frame) + self.set_selected(False) + + self.progress_bar = QProgressBar() + self.progress_bar.setRange(0, 1) + self.progress_bar.setValue(0) + self.progress_bar.setFormat("%v / %m steps") + root.addWidget(self.progress_bar) + + self.settings_group = QGroupBox("XYZ -> PDB Conversion Settings") + root.addWidget(self.settings_group) + form = QFormLayout(self.settings_group) + + project_row = QWidget() + project_layout = QHBoxLayout(project_row) + project_layout.setContentsMargins(0, 0, 0, 0) + self.project_dir_edit = QLineEdit() + self.project_dir_edit.editingFinished.connect(self._on_editor_changed) + project_layout.addWidget(self.project_dir_edit, stretch=1) + project_button = QPushButton("Browse...") + project_button.clicked.connect(self._choose_project_dir) + project_layout.addWidget(project_button) + form.addRow("Project folder", project_row) + + self.project_reference_label = QLabel() + self.project_reference_label.setWordWrap(True) + self.project_reference_label.setFrameShape(QFrame.Shape.StyledPanel) + form.addRow("", self.project_reference_label) + + input_row = QWidget() + input_layout = QHBoxLayout(input_row) + input_layout.setContentsMargins(0, 0, 0, 0) + self.input_path_edit = QLineEdit() + self.input_path_edit.editingFinished.connect(self._on_editor_changed) + input_layout.addWidget(self.input_path_edit, stretch=1) + input_folder_button = QPushButton("Folder...") + input_folder_button.clicked.connect(self._choose_input_dir) + 
input_layout.addWidget(input_folder_button) + input_file_button = QPushButton("File...") + input_file_button.clicked.connect(self._choose_input_file) + input_layout.addWidget(input_file_button) + form.addRow("XYZ input", input_row) + + library_row = QWidget() + library_layout = QHBoxLayout(library_row) + library_layout.setContentsMargins(0, 0, 0, 0) + self.reference_library_edit = QLineEdit() + self.reference_library_edit.editingFinished.connect( + self._on_reference_library_changed + ) + library_layout.addWidget(self.reference_library_edit, stretch=1) + library_button = QPushButton("Browse...") + library_button.clicked.connect(self._choose_reference_library_dir) + library_layout.addWidget(library_button) + form.addRow("Reference library", library_row) + + self.analysis_summary_label = QLabel( + "Analyze the XYZ input to populate the free-atom element list." + ) + self.analysis_summary_label.setWordWrap(True) + self.analysis_summary_label.setFrameShape(QFrame.Shape.StyledPanel) + form.addRow("", self.analysis_summary_label) + + form.addRow("", self._build_free_atoms_group()) + form.addRow("", self._build_reference_molecules_group()) + + self.hydrogen_mode_combo = QComboBox() + self.hydrogen_mode_combo.addItem( + "Leave unassigned", + "leave_unassigned", + ) + self.hydrogen_mode_combo.addItem( + "Assign orphaned hydrogen", + "assign_orphaned", + ) + self.hydrogen_mode_combo.addItem( + "Restore missing hydrogen", + "restore_missing", + ) + self.hydrogen_mode_combo.currentIndexChanged.connect( + self._on_editor_changed + ) + form.addRow("Hydrogen handling", self.hydrogen_mode_combo) + + def _build_free_atoms_group(self) -> QGroupBox: + group = QGroupBox("Free Atoms") + layout = QVBoxLayout(group) + controls = QGridLayout() + self.free_element_combo = QComboBox() + self.free_element_combo.currentIndexChanged.connect( + self._on_free_element_changed + ) + self.free_residue_edit = QLineEdit() + self.free_residue_edit.setPlaceholderText("SOL") + add_button = 
QPushButton("Add Free Atom") + add_button.clicked.connect(self._add_free_atom) + remove_button = QPushButton("Remove Selected") + remove_button.clicked.connect(self._remove_selected_free_atom) + controls.addWidget(QLabel("Element"), 0, 0) + controls.addWidget(self.free_element_combo, 0, 1) + controls.addWidget(QLabel("Residue"), 0, 2) + controls.addWidget(self.free_residue_edit, 0, 3) + controls.addWidget(add_button, 0, 4) + controls.addWidget(remove_button, 0, 5) + controls.setColumnStretch(1, 1) + controls.setColumnStretch(3, 1) + layout.addLayout(controls) + + self.free_atom_table = QTableWidget(0, 2) + self.free_atom_table.setHorizontalHeaderLabels(["Element", "Residue"]) + self.free_atom_table.verticalHeader().setVisible(False) + self.free_atom_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + self.free_atom_table.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + self.free_atom_table.setEditTriggers( + QAbstractItemView.EditTrigger.NoEditTriggers + ) + header = self.free_atom_table.horizontalHeader() + header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) + header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + self.free_atom_table.setMinimumHeight(120) + layout.addWidget(self.free_atom_table) + return group + + def _build_reference_molecules_group(self) -> QGroupBox: + group = QGroupBox("Reference Molecules") + layout = QVBoxLayout(group) + controls = QGridLayout() + self.reference_combo = QComboBox() + self.reference_combo.currentIndexChanged.connect( + self._on_reference_selection_changed + ) + self.molecule_residue_edit = QLineEdit() + self.molecule_residue_edit.setPlaceholderText("DMF") + add_button = QPushButton("Add Molecule") + add_button.clicked.connect(self._add_molecule) + remove_button = QPushButton("Remove Selected") + remove_button.clicked.connect(self._remove_selected_molecule) + controls.addWidget(QLabel("Reference"), 0, 0) + 
controls.addWidget(self.reference_combo, 0, 1) + controls.addWidget(QLabel("Residue"), 0, 2) + controls.addWidget(self.molecule_residue_edit, 0, 3) + controls.addWidget(add_button, 0, 4) + controls.addWidget(remove_button, 0, 5) + controls.setColumnStretch(1, 1) + controls.setColumnStretch(3, 1) + layout.addLayout(controls) + + self.molecule_table = QTableWidget(0, 2) + self.molecule_table.setHorizontalHeaderLabels(["Reference", "Residue"]) + self.molecule_table.verticalHeader().setVisible(False) + self.molecule_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + self.molecule_table.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + self.molecule_table.setEditTriggers( + QAbstractItemView.EditTrigger.NoEditTriggers + ) + header = self.molecule_table.horizontalHeader() + header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) + header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + self.molecule_table.setMinimumHeight(140) + layout.addWidget(self.molecule_table) + return group + + def _load_item(self, item: XYZToPDBBatchItem) -> None: + self._loading = True + self.project_dir_edit.setText( + "" if item.project_dir is None else str(item.project_dir) + ) + self.input_path_edit.setText( + "" if item.input_path is None else str(item.input_path) + ) + self.reference_library_edit.setText(str(item.reference_library_dir)) + self._set_free_atom_inputs(item.free_atom_inputs) + self._set_molecule_inputs(item.molecule_inputs) + self._set_combo_value(self.hydrogen_mode_combo, item.hydrogen_mode) + self._loading = False + self._refresh_header() + self._refresh_project_reference() + + def _set_free_atom_inputs( + self, + inputs: tuple[FreeAtomMappingInput, ...], + ) -> None: + self.free_atom_table.setRowCount(0) + for item in inputs: + row = self.free_atom_table.rowCount() + self.free_atom_table.insertRow(row) + self.free_atom_table.setItem( + row, + 0, + self._readonly_table_item(item.element), + ) + 
self.free_atom_table.setItem( + row, + 1, + self._readonly_table_item(item.residue_name), + ) + + def _set_molecule_inputs( + self, + inputs: tuple[MoleculeMappingInput, ...], + ) -> None: + self.molecule_table.setRowCount(0) + for item in inputs: + row = self.molecule_table.rowCount() + self.molecule_table.insertRow(row) + self.molecule_table.setItem( + row, + 0, + self._readonly_table_item(item.reference_name), + ) + self.molecule_table.setItem( + row, + 1, + self._readonly_table_item(item.residue_name), + ) + + def _set_settings_visible(self, visible: bool) -> None: + self.settings_group.setVisible(bool(visible)) + self.toggle_button.setChecked(bool(visible)) + self.toggle_button.setText("Hide Settings" if visible else "Settings") + parent_item = self._list_item() + if parent_item is not None: + parent_item.setSizeHint(self.sizeHint()) + + def _list_item(self) -> QListWidgetItem | None: + parent = self.parent() + while parent is not None and not isinstance(parent, QListWidget): + parent = parent.parent() + if not isinstance(parent, QListWidget): + return None + for row in range(parent.count()): + list_item = parent.item(row) + if parent.itemWidget(list_item) is self: + return list_item + return None + + def _choose_project_dir(self) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select SAXSShell project folder", + _dialog_start_dir(self.project_dir_edit.text()), + ) + if not selected: + return + current_library = ( + _optional_path(self.reference_library_edit.text()) + or default_reference_library_dir() + ) + self._load_item( + replace( + _queue_item_from_project_defaults( + selected, + item_id=self.item_id, + reference_library_dir=current_library, + ), + molecule_inputs=tuple(self._molecule_inputs_from_table()), + free_atom_inputs=tuple(self._free_atom_inputs_from_table()), + hydrogen_mode=str( + self.hydrogen_mode_combo.currentData() + or "leave_unassigned" + ), + ) + ) + self._on_editor_changed() + self._analyze_quietly() + + def 
_choose_input_dir(self) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select XYZ frames folder", + _dialog_start_dir( + self.input_path_edit.text(), + self.project_dir_edit.text(), + ), + ) + if not selected: + return + self.input_path_edit.setText(selected) + self._on_editor_changed() + self._analyze_quietly() + + def _choose_input_file(self) -> None: + selected, _filter = QFileDialog.getOpenFileName( + self, + "Select XYZ frame file", + _dialog_start_dir( + self.input_path_edit.text(), + self.project_dir_edit.text(), + ), + "XYZ files (*.xyz);;All files (*)", + ) + if not selected: + return + self.input_path_edit.setText(selected) + self._on_editor_changed() + self._analyze_quietly() + + def _choose_reference_library_dir(self) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select reference library folder", + _dialog_start_dir(self.reference_library_edit.text()), + ) + if not selected: + return + self.reference_library_edit.setText(selected) + self._on_reference_library_changed() + + def _analyze_from_button(self) -> None: + try: + self.analyze_input() + self._on_editor_changed() + except Exception as exc: + QMessageBox.warning(self, "Unable to analyze XYZ input", str(exc)) + self.analysis_summary_label.setText(str(exc)) + self.set_status("Analysis failed") + self._on_editor_changed() + + def _analyze_quietly(self) -> None: + if not self.input_path_edit.text().strip(): + return + try: + self.analyze_input() + except Exception as exc: + self.analysis_summary_label.setText(str(exc)) + self.set_status("Analysis failed") + + def _refresh_reference_entries(self) -> None: + library_dir = ( + _optional_path(self.reference_library_edit.text()) + or default_reference_library_dir() + ) + try: + entries = tuple(list_reference_library(library_dir)) + except Exception: + entries = () + self._reference_entries = entries + current = self.reference_combo.currentData() + self.reference_combo.blockSignals(True) + self.reference_combo.clear() 
+ for entry in entries: + self.reference_combo.addItem(entry.name, entry.name) + if current is not None: + index = self.reference_combo.findData(current) + if index >= 0: + self.reference_combo.setCurrentIndex(index) + self.reference_combo.blockSignals(False) + self._apply_selected_reference_default_residue() + + def _refresh_free_element_combo(self) -> None: + current = self.free_element_combo.currentData() + self.free_element_combo.blockSignals(True) + self.free_element_combo.clear() + for element in self._available_elements: + self.free_element_combo.addItem(element, element) + if current is not None: + index = self.free_element_combo.findData(current) + if index >= 0: + self.free_element_combo.setCurrentIndex(index) + self.free_element_combo.blockSignals(False) + self._on_free_element_changed() + + def _on_reference_library_changed(self) -> None: + self._refresh_reference_entries() + self._on_editor_changed() + + def _on_free_element_changed(self, *_args) -> None: + element = str(self.free_element_combo.currentData() or "").strip() + if element and not self.free_residue_edit.text().strip(): + self.free_residue_edit.setText(_default_free_atom_residue(element)) + + def _on_reference_selection_changed(self, *_args) -> None: + self._apply_selected_reference_default_residue() + + def _apply_selected_reference_default_residue(self) -> None: + reference_name = str(self.reference_combo.currentData() or "").strip() + if not reference_name: + return + entry = next( + ( + entry + for entry in self._reference_entries + if entry.name == reference_name + ), + None, + ) + if entry is None: + return + if not self.molecule_residue_edit.text().strip(): + self.molecule_residue_edit.setText(entry.residue_name) + + def _add_free_atom(self) -> None: + try: + element = str(self.free_element_combo.currentData() or "").strip() + if not element: + raise ValueError("Choose an element to add as a free atom.") + residue = _validated_residue_code( + self.free_residue_edit.text(), + 
"Free-atom residue", + ) + for row in range(self.free_atom_table.rowCount()): + item = self.free_atom_table.item(row, 0) + if item is not None and item.text().strip() == element: + raise ValueError( + f"{element} is already listed as a free atom." + ) + except Exception as exc: + QMessageBox.warning(self, "Unable to add free atom", str(exc)) + return + row = self.free_atom_table.rowCount() + self.free_atom_table.insertRow(row) + self.free_atom_table.setItem( + row, + 0, + self._readonly_table_item(element), + ) + self.free_atom_table.setItem( + row, + 1, + self._readonly_table_item(residue), + ) + self._on_editor_changed() + + def _remove_selected_free_atom(self) -> None: + row = self.free_atom_table.currentRow() + if row < 0: + return + self.free_atom_table.removeRow(row) + self._on_editor_changed() + + def _add_molecule(self) -> None: + try: + reference_name = str( + self.reference_combo.currentData() or "" + ).strip() + if not reference_name: + raise ValueError("Choose a reference molecule first.") + residue = _validated_residue_code( + self.molecule_residue_edit.text(), + "Reference-molecule residue", + ) + for row in range(self.molecule_table.rowCount()): + item = self.molecule_table.item(row, 1) + if item is not None and item.text().strip() == residue: + raise ValueError(f"Residue {residue} is already listed.") + except Exception as exc: + QMessageBox.warning(self, "Unable to add molecule", str(exc)) + return + row = self.molecule_table.rowCount() + self.molecule_table.insertRow(row) + self.molecule_table.setItem( + row, + 0, + self._readonly_table_item(reference_name), + ) + self.molecule_table.setItem( + row, + 1, + self._readonly_table_item(residue), + ) + self._on_editor_changed() + + def _remove_selected_molecule(self) -> None: + row = self.molecule_table.currentRow() + if row < 0: + return + self.molecule_table.removeRow(row) + self._on_editor_changed() + + def _free_atom_inputs_from_table(self) -> list[FreeAtomMappingInput]: + inputs: 
list[FreeAtomMappingInput] = [] + for row in range(self.free_atom_table.rowCount()): + element_item = self.free_atom_table.item(row, 0) + residue_item = self.free_atom_table.item(row, 1) + if element_item is None or residue_item is None: + continue + inputs.append( + FreeAtomMappingInput( + element=element_item.text().strip(), + residue_name=residue_item.text().strip(), + ) + ) + return inputs + + def _molecule_inputs_from_table(self) -> list[MoleculeMappingInput]: + inputs: list[MoleculeMappingInput] = [] + for row in range(self.molecule_table.rowCount()): + reference_item = self.molecule_table.item(row, 0) + residue_item = self.molecule_table.item(row, 1) + if reference_item is None or residue_item is None: + continue + inputs.append( + MoleculeMappingInput( + reference_name=reference_item.text().strip(), + residue_name=residue_item.text().strip(), + ) + ) + return inputs + + def _on_editor_changed(self, *_args) -> None: + if self._loading: + return + try: + self.collect_item() + if self.status_label.text() in {"Analysis failed", "Failed"}: + self.set_status("Ready") + except Exception: + self._refresh_header() + self._refresh_project_reference() + self.settings_changed.emit(self.item_id) + + def _refresh_header(self) -> None: + self.title_label.setText(self._item.display_name()) + + def _refresh_project_reference(self) -> None: + project_dir = _optional_path(self.project_dir_edit.text()) + self.project_reference_label.setText( + _project_reference_text(project_dir) + ) + + @staticmethod + def _readonly_table_item(value: object) -> QTableWidgetItem: + item = QTableWidgetItem(str(value)) + item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable) + return item + + @staticmethod + def _set_combo_value(combo: QComboBox, value: str) -> None: + index = combo.findData(value) + if index < 0: + index = combo.findText(value) + if index >= 0: + combo.setCurrentIndex(index) + + +class XYZToPDBBatchWorker(QObject): + item_started = Signal(str, int, int) + item_progress = 
Signal(str, int, int, str) + item_finished = Signal(str, object) + item_failed = Signal(str, str) + log = Signal(str) + status = Signal(str) + finished = Signal(object) + failed = Signal(str, str) + + def __init__( + self, + queue_entries: list[tuple[str, XYZToPDBBatchJob]], + ) -> None: + super().__init__() + self.queue_entries = list(queue_entries) + self._cancel_requested = threading.Event() + self._project_manager = SAXSProjectManager() + + def request_cancel(self) -> None: + self._cancel_requested.set() + + @Slot() + def run(self) -> None: + results: list[XYZToPDBBatchResult] = [] + total_items = len(self.queue_entries) + for index, (item_id, job) in enumerate( + self.queue_entries, + start=1, + ): + if self._cancel_requested.is_set(): + self.log.emit("Batch queue stopped before the next project.") + break + self.item_started.emit(item_id, index, total_items) + self.status.emit( + f"Running {index}/{total_items}: {job.project_dir.name}" + ) + self.log.emit(f"Starting {index}/{total_items}: {job.project_dir}") + try: + result = self._run_job(item_id, job) + except Exception as exc: + message = str(exc) + self.item_failed.emit(item_id, message) + self.failed.emit(item_id, message) + return + results.append(result) + self.item_finished.emit(item_id, result) + self.status.emit("XYZ -> PDB batch queue finished") + self.finished.emit(results) + + def _run_job( + self, + item_id: str, + job: XYZToPDBBatchJob, + ) -> XYZToPDBBatchResult: + settings = self._project_manager.load_project(job.project_dir) + workflow = XYZToPDBMappingWorkflow( + job.input_path, + reference_library_dir=job.reference_library_dir, + ) + self.item_progress.emit( + item_id, + 0, + 1, + "Preparing mapping", + ) + + def on_progress( + processed: int, + total: int, + message: str, + ) -> None: + self.item_progress.emit(item_id, processed, total, message) + + result: XYZToPDBExportResult = workflow.export_with_mapping( + molecule_inputs=job.molecule_inputs, + free_atom_inputs=job.free_atom_inputs, 
+ hydrogen_mode=job.hydrogen_mode, + progress_callback=on_progress, + log_callback=( + lambda message: self.log.emit( + f"[{job.project_dir.name}] {message}" + ) + ), + cancel_callback=self._cancel_requested.is_set, + ) + settings.pdb_frames_dir = str(result.output_dir.expanduser().resolve()) + self._project_manager.save_project(settings) + self.log.emit( + f"[{job.project_dir.name}] Registered PDB frames folder: " + f"{settings.pdb_frames_dir}" + ) + return XYZToPDBBatchResult( + project_dir=job.project_dir, + input_path=job.input_path, + output_dir=result.output_dir.expanduser().resolve(), + written_count=len(result.written_files), + ) + + +class XYZToPDBBatchQueueWindow(QMainWindow): + """Queue XYZ-to-PDB conversions for multiple projects.""" + + project_paths_registered = Signal(object) + + def __init__( + self, + initial_project_dir: str | Path | None = None, + *, + initial_input_path: str | Path | None = None, + reference_library_dir: str | Path | None = None, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._widgets_by_id: dict[str, XYZToPDBBatchItemWidget] = {} + self._run_thread: QThread | None = None + self._run_worker: XYZToPDBBatchWorker | None = None + self._initial_project_dir = ( + None + if initial_project_dir is None + else Path(initial_project_dir).expanduser().resolve() + ) + self._initial_input_path = ( + None + if initial_input_path is None + else Path(initial_input_path).expanduser().resolve() + ) + self._reference_library_dir = ( + default_reference_library_dir() + if reference_library_dir is None + else Path(reference_library_dir).expanduser().resolve() + ) + self._build_ui() + if ( + self._initial_project_dir is not None + or self._initial_input_path is not None + ): + self._add_current_project() + + def closeEvent(self, event) -> None: + if self._run_thread is not None and self._run_thread.isRunning(): + self._request_cancel() + self.hide() + while ( + self._run_thread is not None and 
self._run_thread.isRunning() + ): + QApplication.processEvents() + if self._run_thread is not None: + self._run_thread.wait(50) + event.accept() + return + super().closeEvent(event) + + def add_queue_item( + self, + item: XYZToPDBBatchItem | None = None, + *, + auto_analyze: bool = False, + ) -> XYZToPDBBatchItemWidget: + resolved_item = item or XYZToPDBBatchItem(item_id=_new_item_id()) + list_item = QListWidgetItem() + list_item.setData(Qt.ItemDataRole.UserRole, resolved_item.item_id) + self.queue_list.addItem(list_item) + widget = XYZToPDBBatchItemWidget( + resolved_item, + parent=self.queue_list, + ) + widget.settings_changed.connect(self._on_item_settings_changed) + widget.remove_requested.connect(self._remove_item) + widget.duplicate_requested.connect(self._duplicate_item) + self._widgets_by_id[resolved_item.item_id] = widget + list_item.setSizeHint(widget.sizeHint()) + self.queue_list.setItemWidget(list_item, widget) + self.queue_list.setCurrentItem(list_item) + self._refresh_order_labels() + if auto_analyze: + widget._analyze_quietly() + return widget + + def queue_jobs_in_order(self) -> list[tuple[str, XYZToPDBBatchJob]]: + entries: list[tuple[str, XYZToPDBBatchJob]] = [] + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + item_id = str(list_item.data(Qt.ItemDataRole.UserRole)) + widget = self._widgets_by_id[item_id] + entries.append((item_id, widget.job())) + return entries + + def _build_ui(self) -> None: + self.setWindowTitle("SAXSShell XYZ -> PDB Batch Queue") + self.setWindowIcon(load_saxshell_icon()) + self.resize(1120, 860) + + central = QWidget() + root = QVBoxLayout(central) + root.setContentsMargins(8, 8, 8, 8) + root.setSpacing(8) + + controls = QHBoxLayout() + self.add_current_button = QPushButton("Add Current Project") + self.add_current_button.clicked.connect(self._add_current_project) + controls.addWidget(self.add_current_button) + self.add_project_button = QPushButton("Add Projects...") + 
self.add_project_button.clicked.connect(self._choose_projects_to_add) + controls.addWidget(self.add_project_button) + controls.addStretch(1) + root.addLayout(controls) + + self.queue_list = QListWidget() + self.queue_list.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + self.queue_list.setDragDropMode( + QAbstractItemView.DragDropMode.InternalMove + ) + self.queue_list.setDefaultDropAction(Qt.DropAction.MoveAction) + self.queue_list.setAlternatingRowColors(True) + self.queue_list.setStyleSheet( + "QListWidget::item:selected { background: transparent; }" + "QListWidget::item:hover { background: transparent; }" + "QListWidget::item { margin: 3px; }" + ) + self.queue_list.model().rowsMoved.connect(self._refresh_order_labels) + self.queue_list.itemSelectionChanged.connect( + self._refresh_item_selection_styles + ) + root.addWidget(self.queue_list, stretch=1) + + run_group = QGroupBox("Execute Queue") + run_layout = QVBoxLayout(run_group) + run_buttons = QHBoxLayout() + self.run_button = QPushButton("Run Complete Queue") + self.run_button.clicked.connect(self._start_queue) + run_buttons.addWidget(self.run_button) + self.cancel_button = QPushButton("Stop Queue") + self.cancel_button.setEnabled(False) + self.cancel_button.clicked.connect(self._request_cancel) + run_buttons.addWidget(self.cancel_button) + run_buttons.addStretch(1) + run_layout.addLayout(run_buttons) + self.queue_status_label = QLabel("Queue idle") + run_layout.addWidget(self.queue_status_label) + self.console = QTextEdit() + self.console.setReadOnly(True) + self.console.setMinimumHeight(160) + run_layout.addWidget(self.console) + root.addWidget(run_group) + + self.setCentralWidget(central) + self.statusBar().showMessage("Ready") + + def _add_current_project(self) -> None: + if ( + self._initial_project_dir is None + and self._initial_input_path is None + ): + QMessageBox.information( + self, + "No active project", + "The main UI did not provide an active project reference.", + ) + 
return + item = ( + _queue_item_from_project_defaults( + self._initial_project_dir, + reference_library_dir=self._reference_library_dir, + ) + if self._initial_project_dir is not None + else XYZToPDBBatchItem( + item_id=_new_item_id(), + reference_library_dir=self._reference_library_dir, + ) + ) + item = replace( + item, + input_path=self._initial_input_path or item.input_path, + ) + self.add_queue_item(item, auto_analyze=item.input_path is not None) + + def _choose_projects_to_add(self) -> None: + selected_dirs = _choose_existing_directories( + self, + title="Select SAXSShell project folders", + start_dir=self._initial_project_dir or Path.home(), + ) + if not selected_dirs: + return + for project_dir in selected_dirs: + item = _queue_item_from_project_defaults( + project_dir, + reference_library_dir=self._reference_library_dir, + ) + self.add_queue_item(item, auto_analyze=item.input_path is not None) + + def _on_item_settings_changed(self, _item_id: str) -> None: + self._refresh_order_labels() + + def _refresh_order_labels(self, *_args) -> None: + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + item_id = str(list_item.data(Qt.ItemDataRole.UserRole)) + widget = self._widgets_by_id.get(item_id) + if widget is None: + continue + widget.title_label.setText( + f"{row + 1}. 
{widget.item().display_name()}" + ) + list_item.setSizeHint(widget.sizeHint()) + self._refresh_item_selection_styles() + + def _refresh_item_selection_styles(self) -> None: + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + item_id = str(list_item.data(Qt.ItemDataRole.UserRole)) + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_selected(list_item.isSelected()) + + def _remove_item(self, item_id: str) -> None: + if self._run_thread is not None and self._run_thread.isRunning(): + return + for row in range(self.queue_list.count()): + list_item = self.queue_list.item(row) + if str(list_item.data(Qt.ItemDataRole.UserRole)) == item_id: + self.queue_list.takeItem(row) + break + self._widgets_by_id.pop(item_id, None) + self._refresh_order_labels() + + def _duplicate_item(self, item_id: str) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is None: + return + try: + item = widget.collect_item() + except Exception: + item = widget.item() + self.add_queue_item(replace(item, item_id=_new_item_id())) + + def _set_running(self, running: bool) -> None: + self.add_current_button.setEnabled(not running) + self.add_project_button.setEnabled(not running) + self.run_button.setEnabled(not running) + self.cancel_button.setEnabled(running) + self.queue_list.setDragEnabled(not running) + self.queue_list.setAcceptDrops(not running) + for widget in self._widgets_by_id.values(): + widget.set_locked(running) + + def _start_queue(self) -> None: + if self.queue_list.count() == 0: + QMessageBox.information( + self, + "XYZ -> PDB batch queue", + "Add at least one project before running the queue.", + ) + return + try: + entries = self.queue_jobs_in_order() + except Exception as exc: + QMessageBox.warning( + self, + "Invalid XYZ -> PDB batch settings", + str(exc), + ) + return + + self.console.clear() + self._set_running(True) + self.queue_status_label.setText( + f"Running 0/{len(entries)} queued conversion(s)" + ) 
+ for widget in self._widgets_by_id.values(): + widget.set_progress(0, 1) + widget.set_status("Queued") + + self._run_thread = QThread(self) + self._run_worker = XYZToPDBBatchWorker(entries) + self._run_worker.moveToThread(self._run_thread) + self._run_thread.started.connect(self._run_worker.run) + self._run_worker.item_started.connect(self._on_item_started) + self._run_worker.item_progress.connect(self._on_item_progress) + self._run_worker.item_finished.connect(self._on_item_finished) + self._run_worker.item_failed.connect(self._on_item_failed) + self._run_worker.log.connect(self._append_log) + self._run_worker.status.connect(self._on_status) + self._run_worker.finished.connect(self._on_queue_finished) + self._run_worker.failed.connect(self._on_queue_failed) + self._run_worker.finished.connect(self._run_thread.quit) + self._run_worker.failed.connect(self._run_thread.quit) + self._run_thread.finished.connect(self._cleanup_run_thread) + self._run_thread.finished.connect(self._run_thread.deleteLater) + self._run_thread.start() + + def _request_cancel(self) -> None: + self.cancel_button.setEnabled(False) + self.queue_status_label.setText( + "Stopping queue after the active project finishes" + ) + self._append_log( + "Stop requested; the current project will finish before the " + "queue exits." 
+ ) + if self._run_worker is not None: + self._run_worker.request_cancel() + + def _append_log(self, message: str) -> None: + self.console.append(message) + + def _on_status(self, message: str) -> None: + self.statusBar().showMessage(message) + self.queue_status_label.setText(message) + + def _on_item_started( + self, + item_id: str, + index: int, + total: int, + ) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_status(f"Running {index}/{total}") + widget.set_progress(0, 1) + self.queue_status_label.setText( + f"Running {index}/{total} queued conversion(s)" + ) + + def _on_item_progress( + self, + item_id: str, + processed: int, + total: int, + message: str, + ) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_progress(processed, total) + widget.set_status(message) + + def _on_item_finished( + self, + item_id: str, + result: XYZToPDBBatchResult, + ) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is None: + return + widget.set_progress(result.written_count, max(result.written_count, 1)) + widget.set_status("Complete") + self.project_paths_registered.emit( + { + "project_dir": result.project_dir, + "pdb_frames_dir": result.output_dir, + } + ) + + def _on_item_failed(self, item_id: str, message: str) -> None: + widget = self._widgets_by_id.get(item_id) + if widget is not None: + widget.set_status("Failed") + self._append_log(message) + + def _on_queue_finished(self, results: object) -> None: + self._set_running(False) + result_count = len(results) if isinstance(results, list) else 0 + self.queue_status_label.setText( + f"Queue finished: {result_count} conversion(s) saved" + ) + self.statusBar().showMessage("XYZ -> PDB batch queue finished") + + def _on_queue_failed(self, item_id: str, message: str) -> None: + self._set_running(False) + self.queue_status_label.setText("Queue stopped after a failure") + self.statusBar().showMessage("XYZ -> PDB batch queue failed", 
5000) + QMessageBox.warning( + self, + "XYZ -> PDB batch queue failed", + f"Queue item {item_id} failed:\n{message}", + ) + + def _cleanup_run_thread(self) -> None: + self._run_thread = None + self._run_worker = None + + +def launch_xyz2pdb_batch_queue_ui( + initial_project_dir: str | Path | None = None, + *, + initial_input_path: str | Path | None = None, + reference_library_dir: str | Path | None = None, +) -> int: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication([]) + configure_saxshell_application(app) + window = XYZToPDBBatchQueueWindow( + initial_project_dir=initial_project_dir, + initial_input_path=initial_input_path, + reference_library_dir=reference_library_dir, + ) + window.show() + return int(app.exec()) + + +__all__ = [ + "XYZToPDBBatchItem", + "XYZToPDBBatchItemWidget", + "XYZToPDBBatchJob", + "XYZToPDBBatchQueueWindow", + "XYZToPDBBatchResult", + "XYZToPDBBatchWorker", + "launch_xyz2pdb_batch_queue_ui", +] diff --git a/src/saxshell/xyz2pdb/ui/run_file_window.py b/src/saxshell/xyz2pdb/ui/run_file_window.py new file mode 100644 index 0000000..b8259aa --- /dev/null +++ b/src/saxshell/xyz2pdb/ui/run_file_window.py @@ -0,0 +1,481 @@ +from __future__ import annotations + +import json +import sys +from pathlib import Path + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import ( + QApplication, + QCheckBox, + QFileDialog, + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QMainWindow, + QMessageBox, + QPlainTextEdit, + QPushButton, + QScrollArea, + QSpinBox, + QSplitter, + QVBoxLayout, + QWidget, +) + +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) +from saxshell.xyz2pdb.mapping_workflow import ( + XYZToPDBMappingWorkflow, + reference_bond_tolerances, +) +from saxshell.xyz2pdb.run_config import ( + build_xyz2pdb_run_config, + default_xyz2pdb_run_file_path, + 
save_xyz2pdb_run_config, +) +from saxshell.xyz2pdb.ui.input_panel import XYZToPDBInputPanel +from saxshell.xyz2pdb.ui.mapping_panel import XYZToPDBMappingPanel +from saxshell.xyz2pdb.ui.reference_panel import ReferenceLibraryPanel +from saxshell.xyz2pdb.workflow import ( + list_reference_library, + suggest_output_dir, +) + + +class XYZToPDBRunFileWindow(QMainWindow): + def __init__( + self, + *, + initial_project_dir: str | Path | None = None, + initial_input_path: str | Path | None = None, + ) -> None: + super().__init__() + self._browse_start_dir = Path.home() + self._last_suggested_output_dir: str | None = None + + project_dir = ( + None + if initial_project_dir is None + else Path(initial_project_dir).expanduser().resolve() + ) + input_path = ( + None + if initial_input_path is None + else Path(initial_input_path).expanduser().resolve() + ) + if project_dir is not None: + self._browse_start_dir = project_dir + if input_path is None: + input_path = self._project_frames_dir(project_dir) + + self.setWindowTitle("XYZ -> PDB CLI Setup") + self.setWindowIcon(load_saxshell_icon()) + self.resize(1180, 820) + self._build_ui() + + if project_dir is not None: + self.project_dir_edit.setText(str(project_dir)) + self._refresh_run_file_path() + if input_path is not None: + self.input_panel.input_edit.setText(str(input_path)) + self._browse_start_dir = ( + input_path.parent if input_path.is_file() else input_path + ) + self.refresh_reference_library() + self.inspect_input() + self._update_command_preview() + + def _build_ui(self) -> None: + central = QWidget(self) + root = QVBoxLayout(central) + root.setContentsMargins(10, 10, 10, 10) + root.setSpacing(8) + self.setCentralWidget(central) + + splitter = QSplitter(Qt.Orientation.Horizontal, self) + splitter.setChildrenCollapsible(False) + root.addWidget(splitter, stretch=1) + + left_scroll = QScrollArea(self) + left_scroll.setWidgetResizable(True) + left_panel = QWidget() + self.left_layout = QVBoxLayout(left_panel) + 
self.left_layout.setContentsMargins(10, 10, 10, 10) + self.left_layout.setSpacing(10) + left_scroll.setWidget(left_panel) + + right_scroll = QScrollArea(self) + right_scroll.setWidgetResizable(True) + right_panel = QWidget() + self.right_layout = QVBoxLayout(right_panel) + self.right_layout.setContentsMargins(10, 10, 10, 10) + self.right_layout.setSpacing(10) + right_scroll.setWidget(right_panel) + + splitter.addWidget(left_scroll) + splitter.addWidget(right_scroll) + splitter.setSizes([610, 570]) + + self.input_panel = XYZToPDBInputPanel() + self.reference_panel = ReferenceLibraryPanel() + self.mapping_panel = XYZToPDBMappingPanel() + self.input_panel.inspect_requested.connect(self.inspect_input) + self.input_panel.input_path_changed.connect( + lambda _path: self._refresh_suggested_output_dir() + ) + self.input_panel.settings_changed.connect(self._update_command_preview) + self.reference_panel.refresh_requested.connect( + self.refresh_reference_library + ) + self.reference_panel.library_dir_changed.connect( + lambda _path: self.refresh_reference_library() + ) + self.mapping_panel.settings_changed.connect( + self._update_command_preview + ) + + self.left_layout.addWidget(self._build_project_group()) + self.left_layout.addWidget(self.input_panel) + self.left_layout.addWidget(self.reference_panel) + self.left_layout.addWidget(self.mapping_panel) + self.left_layout.addWidget(self._build_options_group()) + self.left_layout.addWidget(self._build_save_group()) + self.left_layout.addStretch(1) + + self.right_layout.addWidget(self._build_command_group()) + self.right_layout.addStretch(1) + self.statusBar().showMessage("Ready") + + def _build_project_group(self) -> QGroupBox: + group = QGroupBox("Project") + form = QFormLayout(group) + project_row = QHBoxLayout() + self.project_dir_edit = QLineEdit() + self.project_dir_edit.editingFinished.connect( + self._on_project_dir_changed + ) + project_row.addWidget(self.project_dir_edit, stretch=1) + browse_button = 
QPushButton("Browse...") + browse_button.clicked.connect(self._browse_project_dir) + project_row.addWidget(browse_button) + project_widget = QWidget() + project_widget.setLayout(project_row) + form.addRow("Project folder", project_widget) + + self.run_file_edit = QLineEdit() + self.run_file_edit.setReadOnly(True) + form.addRow("Run file", self.run_file_edit) + return group + + def _build_options_group(self) -> QGroupBox: + group = QGroupBox("Run Options") + form = QFormLayout(group) + form.addRow("Output folder", self._output_row()) + + self.selected_solution_spin = QSpinBox() + self.selected_solution_spin.setRange(0, 63) + self.selected_solution_spin.valueChanged.connect( + self._update_command_preview + ) + form.addRow("Solution index", self.selected_solution_spin) + + self.assertion_mode_checkbox = QCheckBox("Assertion mode") + self.assertion_mode_checkbox.toggled.connect( + self._update_command_preview + ) + form.addRow("", self.assertion_mode_checkbox) + + self.pbc_params_edit = QPlainTextEdit() + self.pbc_params_edit.setPlaceholderText( + '{"a": 20.0, "b": 20.0, "c": 20.0}' + ) + self.pbc_params_edit.setMinimumHeight(80) + self.pbc_params_edit.textChanged.connect(self._update_command_preview) + form.addRow("PBC JSON", self.pbc_params_edit) + return group + + def _output_row(self) -> QWidget: + widget = QWidget() + row = QHBoxLayout(widget) + row.setContentsMargins(0, 0, 0, 0) + self.output_dir_edit = QLineEdit() + self.output_dir_edit.textChanged.connect(self._update_command_preview) + row.addWidget(self.output_dir_edit, stretch=1) + browse_button = QPushButton("Browse...") + browse_button.clicked.connect(self._browse_output_dir) + row.addWidget(browse_button) + return widget + + def _build_save_group(self) -> QGroupBox: + group = QGroupBox("Save") + layout = QHBoxLayout(group) + inspect_button = QPushButton("Analyze Input") + inspect_button.clicked.connect(self.inspect_input) + layout.addWidget(inspect_button) + save_button = QPushButton("Save Run File") + 
save_button.clicked.connect(self._save_run_file) + layout.addWidget(save_button) + layout.addStretch(1) + return group + + def _build_command_group(self) -> QGroupBox: + group = QGroupBox("CLI Command") + layout = QVBoxLayout(group) + self.command_box = QPlainTextEdit() + self.command_box.setReadOnly(True) + self.command_box.setMinimumHeight(140) + layout.addWidget(self.command_box) + + layout.addWidget(QLabel("Run File JSON Preview")) + self.json_preview_box = QPlainTextEdit() + self.json_preview_box.setReadOnly(True) + self.json_preview_box.setMinimumHeight(420) + layout.addWidget(self.json_preview_box) + return group + + def _browse_project_dir(self, *_args: object) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select project folder", + str(self._browse_start_dir), + ) + if not selected: + return + self.project_dir_edit.setText(selected) + self._on_project_dir_changed() + + def _browse_output_dir(self, *_args: object) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select xyz2pdb output folder", + self.output_dir_edit.text().strip() or str(self._browse_start_dir), + ) + if selected: + self.output_dir_edit.setText(selected) + self._update_command_preview() + + def _on_project_dir_changed(self, *_args: object) -> None: + project_dir = self._project_dir() + if project_dir is None: + return + self._browse_start_dir = project_dir + self._refresh_run_file_path() + if not self.input_panel.input_edit.text().strip(): + input_path = self._project_frames_dir(project_dir) + if input_path is not None and input_path.exists(): + self.input_panel.input_edit.setText(str(input_path)) + self._update_command_preview() + + def refresh_reference_library(self, *_args: object) -> None: + try: + library_dir = self.reference_panel.get_library_dir() + entries = list_reference_library(library_dir) + self.reference_panel.set_reference_entries(entries) + self.mapping_panel.set_reference_entries( + entries, + bond_defaults_by_name={ + entry.name: 
reference_bond_tolerances( + entry.name, + library_dir=library_dir, + ) + for entry in entries + }, + ) + self.statusBar().showMessage("Reference library refreshed") + except Exception as exc: + self.statusBar().showMessage("Reference refresh failed") + self.json_preview_box.setPlainText(str(exc)) + self._update_command_preview() + + def inspect_input(self, *_args: object) -> None: + input_path = self.input_panel.get_input_path() + if input_path is None: + self.input_panel.set_summary_text("No XYZ input selected.") + self.input_panel.set_input_mode(None) + self._update_command_preview() + return + try: + workflow = XYZToPDBMappingWorkflow( + input_path, + reference_library_dir=self.reference_panel.get_library_dir(), + output_dir=self._output_dir(), + ) + analysis = workflow.analyze_input() + self.input_panel.set_input_mode(analysis.inspection.input_mode) + self.input_panel.set_summary_text( + "\n".join( + [ + f"Input path: {analysis.inspection.input_path}", + f"XYZ files found: {analysis.inspection.total_files}", + f"Sample frame: {analysis.sample_file.name}", + f"Sample atoms: {analysis.total_atoms}", + "Element counts: " + + ", ".join( + f"{element} x{count}" + for element, count in sorted( + analysis.element_counts.items() + ) + ), + ] + ) + ) + self.mapping_panel.set_available_elements( + tuple(sorted(analysis.element_counts)) + ) + self._refresh_suggested_output_dir() + except Exception as exc: + self.input_panel.set_summary_text(str(exc)) + self.input_panel.set_input_mode(None) + self.statusBar().showMessage("Input analysis failed") + self._update_command_preview() + + def _save_run_file(self, *_args: object) -> None: + try: + project_dir = self._require_project_dir() + config = self._current_config(project_dir) + except Exception as exc: + QMessageBox.warning(self, "XYZ -> PDB CLI Setup", str(exc)) + return + run_file_path = default_xyz2pdb_run_file_path(project_dir) + save_xyz2pdb_run_config(run_file_path, config) + 
self.run_file_edit.setText(str(run_file_path)) + self.json_preview_box.setPlainText(save_preview_text(config.to_dict())) + self._update_command_preview() + self.statusBar().showMessage(f"Saved run file: {run_file_path}") + QMessageBox.information( + self, + "XYZ -> PDB CLI Setup", + f"Saved XYZ -> PDB CLI run file:\n{run_file_path}", + ) + + def _update_command_preview(self, *_args: object) -> None: + project_dir = self._project_dir() + if project_dir is None: + self.command_box.setPlainText( + "Select a project folder before saving the CLI run file." + ) + self.json_preview_box.clear() + return + run_file_path = default_xyz2pdb_run_file_path(project_dir) + self.run_file_edit.setText(str(run_file_path)) + self.command_box.setPlainText( + f'xyz2pdb run "{project_dir}"\n' + f'saxshell xyz2pdb run "{project_dir}"' + ) + try: + config = self._current_config(project_dir) + except Exception as exc: + self.json_preview_box.setPlainText(str(exc)) + return + self.json_preview_box.setPlainText(save_preview_text(config.to_dict())) + + def _current_config(self, project_dir: Path): + input_path = self.input_panel.get_input_path() + if input_path is None: + raise ValueError("Choose an XYZ input before saving.") + return build_xyz2pdb_run_config( + project_dir=project_dir, + input_path=input_path, + output_dir=self._output_dir(), + reference_library_dir=self.reference_panel.get_library_dir(), + molecule_inputs=tuple(self.mapping_panel.get_molecule_inputs()), + free_atom_inputs=tuple(self.mapping_panel.get_free_atom_inputs()), + hydrogen_mode=self.mapping_panel.hydrogen_mode(), + selected_solution_index=int(self.selected_solution_spin.value()), + assertion_mode=bool(self.assertion_mode_checkbox.isChecked()), + pbc_params=self._pbc_params(), + ) + + def _refresh_run_file_path(self) -> None: + project_dir = self._project_dir() + if project_dir is None: + self.run_file_edit.clear() + return + self.run_file_edit.setText( + str(default_xyz2pdb_run_file_path(project_dir)) + ) + + def 
_refresh_suggested_output_dir(self) -> None: + input_path = self.input_panel.get_input_path() + if input_path is None: + return + try: + suggested = suggest_output_dir(input_path) + except Exception: + return + current = self.output_dir_edit.text().strip() + if not current or current == self._last_suggested_output_dir: + self.output_dir_edit.setText(str(suggested)) + self._last_suggested_output_dir = str(suggested) + + def _project_dir(self) -> Path | None: + text = self.project_dir_edit.text().strip() + if not text: + return None + return Path(text).expanduser().resolve() + + def _require_project_dir(self) -> Path: + project_dir = self._project_dir() + if project_dir is None: + raise ValueError("Choose a project folder before saving.") + if not project_dir.is_dir(): + raise ValueError(f"Project folder does not exist: {project_dir}") + return project_dir + + @staticmethod + def _project_frames_dir(project_dir: Path) -> Path | None: + from saxshell.saxs.project_manager import SAXSProjectManager + + try: + settings = SAXSProjectManager().load_project(project_dir) + except Exception: + return None + return settings.resolved_frames_dir + + def _output_dir(self) -> Path | None: + text = self.output_dir_edit.text().strip() + return Path(text) if text else None + + def _pbc_params(self) -> dict[str, float | str]: + text = self.pbc_params_edit.toPlainText().strip() + if not text: + return {} + payload = json.loads(text) + if not isinstance(payload, dict): + raise ValueError("PBC JSON must be an object.") + return dict(payload) + + +def save_preview_text(payload: dict[str, object]) -> str: + return json.dumps(payload, indent=2) + + +def launch_xyz2pdb_run_file_ui( + *, + initial_project_dir: str | Path | None = None, + initial_input_path: str | Path | None = None, +) -> XYZToPDBRunFileWindow: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication(sys.argv) + configure_saxshell_application(app) + window = 
XYZToPDBRunFileWindow( + initial_project_dir=initial_project_dir, + initial_input_path=initial_input_path, + ) + window.show() + window.raise_() + return window + + +__all__ = [ + "XYZToPDBRunFileWindow", + "launch_xyz2pdb_run_file_ui", +] diff --git a/tests/template_candidates/valid_installable_model.py b/tests/template_candidates/valid_installable_model.py index eefc268..87dde9a 100644 --- a/tests/template_candidates/valid_installable_model.py +++ b/tests/template_candidates/valid_installable_model.py @@ -37,4 +37,8 @@ def log_likelihood_candidate(params): offset=offset, ) residuals = np.asarray(experimental_intensities, dtype=float) - model + if residuals.size == 0: + return -1.0 + if not np.all(np.isfinite(residuals)): + return -1.0e12 return float(-0.5 * np.mean(residuals**2)) diff --git a/tests/test_cluster_cli.py b/tests/test_cluster_cli.py index e8cdc5b..454345e 100644 --- a/tests/test_cluster_cli.py +++ b/tests/test_cluster_cli.py @@ -1,16 +1,44 @@ from __future__ import annotations +import json from pathlib import Path from saxshell.cluster import ( DEFAULT_SAVE_STATE_FREQUENCY, ClusterWorkflow, + build_cluster_run_config, + default_cluster_run_file_path, example_atom_type_definitions, example_pair_cutoff_definitions, + load_cluster_run_config, + resolve_run_config_path, + run_cluster_run_config, + save_cluster_run_config, ) from saxshell.cluster.cli import main as cluster_main +def _create_project(project_dir: Path) -> None: + project_dir.mkdir(parents=True, exist_ok=True) + (project_dir / "saxs_project.json").write_text( + json.dumps( + { + "project_name": project_dir.name, + "project_dir": str(project_dir.resolve()), + }, + indent=2, + ) + + "\n", + encoding="utf-8", + ) + + +def _read_project(project_dir: Path) -> dict[str, object]: + return json.loads( + (project_dir / "saxs_project.json").read_text(encoding="utf-8") + ) + + def _write_xyz_frame( path: Path, *, @@ -131,3 +159,123 @@ def test_clusters_cli_export_runs_complete_headless_workflow( 
"frame_0001_AAA.xyz", "frame_0001_AAB.xyz", ] + + +def test_cluster_run_config_round_trips_project_relative_paths(tmp_path): + project_dir = tmp_path / "project" + project_dir.mkdir() + frames_dir = project_dir / "frames" / "splitxyz0001" + frames_dir.mkdir(parents=True) + output_dir = project_dir / "clusters" / "splitxyz0001" + + config = build_cluster_run_config( + project_dir=project_dir, + frames_dir=frames_dir, + output_dir=output_dir, + atom_type_definitions=example_atom_type_definitions(), + pair_cutoff_definitions=example_pair_cutoff_definitions(), + use_pbc=True, + search_mode="vectorized", + save_state_frequency=250, + ) + run_file = default_cluster_run_file_path(project_dir) + save_cluster_run_config(run_file, config) + + loaded = load_cluster_run_config(run_file) + + assert loaded.frames_dir == "frames/splitxyz0001" + assert loaded.output_dir == "clusters/splitxyz0001" + assert loaded.use_pbc is True + assert loaded.search_mode == "vectorized" + assert loaded.save_state_frequency == 250 + assert ( + resolve_run_config_path(loaded.frames_dir, project_dir=project_dir) + == frames_dir.resolve() + ) + + +def test_cluster_project_backed_run_updates_project_clusters_dir( + tmp_path, +): + project_dir = tmp_path / "project" + _create_project(project_dir) + frames_dir = project_dir / "frames" / "splitxyz0001" + frames_dir.mkdir(parents=True) + _write_xyz_frame( + frames_dir / "frame_0000.xyz", pb1_x=0.0, i_x=1.0, pb2_x=2.0 + ) + project_payload = _read_project(project_dir) + project_payload["frames_dir"] = str(frames_dir) + (project_dir / "saxs_project.json").write_text( + json.dumps(project_payload, indent=2) + "\n", + encoding="utf-8", + ) + + output_dir = project_dir / "clusters_splitxyz0001" + config = build_cluster_run_config( + project_dir=project_dir, + frames_dir=frames_dir, + output_dir=output_dir, + atom_type_definitions=example_atom_type_definitions(), + pair_cutoff_definitions=example_pair_cutoff_definitions(), + smart_solvation_shells=False, + ) + 
save_cluster_run_config(default_cluster_run_file_path(project_dir), config) + + exit_code = cluster_main(["run", str(project_dir)]) + saved_settings = _read_project(project_dir) + + assert exit_code == 0 + assert ( + Path(str(saved_settings["frames_dir"])).resolve() + == frames_dir.resolve() + ) + assert ( + Path(str(saved_settings["clusters_dir"])).resolve() + == output_dir.resolve() + ) + assert saved_settings["clusters_dir_snapshot"] is not None + assert output_dir.exists() + assert sorted(path.name for path in output_dir.rglob("*.xyz")) == [ + "frame_0000_AAA.xyz", + ] + + +def test_run_cluster_run_config_preserves_existing_pdb_frames_field(tmp_path): + project_dir = tmp_path / "project" + _create_project(project_dir) + pdb_frames_dir = tmp_path / "pdb_frames" + pdb_frames_dir.mkdir() + frames_dir = project_dir / "frames" / "splitxyz0001" + frames_dir.mkdir(parents=True) + _write_xyz_frame( + frames_dir / "frame_0000.xyz", pb1_x=0.0, i_x=1.0, pb2_x=2.0 + ) + project_payload = _read_project(project_dir) + project_payload["pdb_frames_dir"] = str(pdb_frames_dir) + (project_dir / "saxs_project.json").write_text( + json.dumps(project_payload, indent=2) + "\n", + encoding="utf-8", + ) + + config = build_cluster_run_config( + project_dir=project_dir, + frames_dir=frames_dir, + output_dir=project_dir / "clusters_splitxyz0001", + atom_type_definitions=example_atom_type_definitions(), + pair_cutoff_definitions=example_pair_cutoff_definitions(), + smart_solvation_shells=False, + ) + + summary = run_cluster_run_config(project_dir, config) + saved_settings = _read_project(project_dir) + + assert summary.written_count == 1 + assert ( + Path(str(saved_settings["pdb_frames_dir"])).resolve() + == pdb_frames_dir.resolve() + ) + assert ( + Path(str(saved_settings["clusters_dir"])).resolve() + == summary.output_dir.resolve() + ) diff --git a/tests/test_cluster_ui.py b/tests/test_cluster_ui.py index 52dd8f9..be0e8b5 100644 --- a/tests/test_cluster_ui.py +++ 
b/tests/test_cluster_ui.py @@ -1,7 +1,9 @@ import json import os +from pathlib import Path import pytest +from PySide6.QtCore import Qt from PySide6.QtWidgets import QApplication, QComboBox, QTableWidgetItem import saxshell.cluster.cli as cluster_cli_module @@ -13,6 +15,13 @@ example_atom_type_definitions, example_pair_cutoff_definitions, ) +from saxshell.cluster.ui.batch_queue_window import ( + ClusterBatchItem, + ClusterBatchJob, + ClusterBatchQueueWindow, + ClusterBatchResult, + ClusterBatchWorker, +) from saxshell.cluster.ui.definitions_panel import ClusterDefinitionsPanel from saxshell.cluster.ui.export_panel import ClusterExportPanel from saxshell.cluster.ui.main_window import ( @@ -24,6 +33,7 @@ estimate_selection, suggest_cluster_output_dir, ) +from saxshell.cluster.ui.run_file_window import ClusterRunFileWindow from saxshell.saxs.project_manager import SAXSProjectManager @@ -246,7 +256,7 @@ def test_cluster_main_window_preview_includes_output_details( assert "Search mode: KDTree" in text assert ( "Save-state frequency: every " - f"{DEFAULT_SAVE_STATE_FREQUENCY} frames" in text + f"{DEFAULT_SAVE_STATE_FREQUENCY} frame(s)" in text ) assert "Stoichiometry bins: solute only" in text assert "Frames selected: 10" in text @@ -340,6 +350,170 @@ def test_cluster_main_window_can_toggle_between_project_xyz_and_pdb_folders( window.close() +def test_cluster_batch_queue_prefills_project_pdb_frames_and_default_preset( + qapp, + tmp_path, +): + del qapp + manager = SAXSProjectManager() + project_dir = tmp_path / "saxs_project" + settings = manager.create_project(project_dir) + xyz_frames_dir = tmp_path / "splitxyz0001" + xyz_frames_dir.mkdir() + (xyz_frames_dir / "frame_0000.xyz").write_text( + "2\nframe_0000\nPb 0.0 0.0 0.0\nI 1.0 0.0 0.0\n", + encoding="utf-8", + ) + pdb_frames_dir = tmp_path / "xyz2pdb_splitxyz0001" + pdb_frames_dir.mkdir() + (pdb_frames_dir / "frame_0000.pdb").write_text( + "MODEL 1\n" + "ATOM 1 PB1 SOL X 1 0.000 0.000 0.000" + " 1.00 0.00 PB\n" + 
"ATOM 2 I1 SOL X 1 1.000 0.000 0.000" + " 1.00 0.00 I\n" + "ENDMDL\n", + encoding="utf-8", + ) + settings.frames_dir = str(xyz_frames_dir) + settings.pdb_frames_dir = str(pdb_frames_dir) + manager.save_project(settings) + + window = ClusterBatchQueueWindow(initial_project_dir=project_dir) + + assert window.queue_list.count() == 1 + list_item = window.queue_list.item(0) + item_id = str(list_item.data(Qt.ItemDataRole.UserRole)) + widget = window._widgets_by_id[item_id] + assert widget.frames_dir_edit.text() == str(pdb_frames_dir.resolve()) + assert widget.item().frames_source_kind == "pdb" + assert widget.definitions_panel.atom_type_definitions() == { + "node": [("Pb", None)], + "linker": [("I", None)], + "shell": [("O", None)], + } + assert "Mode: PDB frames" in widget.summary_box.toPlainText() + assert Path(widget.output_dir_edit.text()).name == ( + "clusters_xyz2pdb_splitxyz0001" + ) + window.close() + + +def test_cluster_batch_worker_exports_and_registers_clusters_folder( + qapp, + tmp_path, +): + del qapp + manager = SAXSProjectManager() + project_dir = tmp_path / "saxs_project" + settings = manager.create_project(project_dir) + pdb_frames_dir = tmp_path / "xyz2pdb_splitxyz0001" + pdb_frames_dir.mkdir() + (pdb_frames_dir / "frame_0000.pdb").write_text( + "MODEL 1\n" + "ATOM 1 PB1 SOL X 1 0.000 0.000 0.000" + " 1.00 0.00 PB\n" + "ATOM 2 I1 SOL X 1 1.000 0.000 0.000" + " 1.00 0.00 I\n" + "ENDMDL\n", + encoding="utf-8", + ) + settings.pdb_frames_dir = str(pdb_frames_dir) + manager.save_project(settings) + output_dir = tmp_path / "clusters_xyz2pdb_splitxyz0001" + job = ClusterBatchJob( + project_dir=project_dir, + frames_dir=pdb_frames_dir, + frames_source_kind="pdb", + config=ClusterJobConfig( + frames_dir=pdb_frames_dir, + atom_type_definitions=example_atom_type_definitions(), + pair_cutoff_definitions=example_pair_cutoff_definitions(), + box_dimensions=None, + use_pbc=False, + search_mode="kdtree", + save_state_frequency=250, + default_cutoff=None, + 
shell_levels=(), + include_shell_levels=(0,), + shared_shells=False, + smart_solvation_shells=True, + include_shell_atoms_in_stoichiometry=False, + output_dir=output_dir, + ), + ) + worker = ClusterBatchWorker([("job-1", job)]) + failures = [] + finished_items = [] + finished_batches = [] + worker.failed.connect( + lambda item_id, message: failures.append((item_id, message)) + ) + worker.item_finished.connect( + lambda item_id, result: finished_items.append((item_id, result)) + ) + worker.finished.connect(finished_batches.append) + + worker.run() + + assert failures == [] + assert len(finished_items) == 1 + item_id, result = finished_items[0] + assert item_id == "job-1" + assert result.output_dir == output_dir.resolve() + assert result.written_count >= 1 + assert output_dir.is_dir() + saved_settings = manager.load_project(project_dir) + assert saved_settings.resolved_pdb_frames_dir == pdb_frames_dir.resolve() + assert saved_settings.resolved_clusters_dir == output_dir.resolve() + assert saved_settings.clusters_dir_snapshot is not None + assert finished_batches == [[result]] + + +def test_cluster_batch_queue_emits_registered_clusters_folder(qapp, tmp_path): + del qapp + project_dir = tmp_path / "saxs_project" + SAXSProjectManager().create_project(project_dir) + frames_dir = tmp_path / "xyz2pdb_splitxyz0001" + frames_dir.mkdir() + output_dir = tmp_path / "clusters_xyz2pdb_splitxyz0001" + output_dir.mkdir() + + window = ClusterBatchQueueWindow() + widget = window.add_queue_item( + ClusterBatchItem( + item_id="job-1", + project_dir=project_dir, + frames_dir=frames_dir, + output_dir=output_dir, + ) + ) + updates = [] + window.project_paths_registered.connect(updates.append) + + window._on_item_finished( + widget.item_id, + ClusterBatchResult( + project_dir=project_dir.resolve(), + frames_dir=frames_dir.resolve(), + frames_source_kind="pdb", + output_dir=output_dir.resolve(), + analyzed_frames=1, + total_clusters=1, + written_count=1, + ), + ) + + assert updates == [ + { 
+ "project_dir": project_dir.resolve(), + "clusters_dir": output_dir.resolve(), + } + ] + assert widget.status_label.text() == "Complete" + window.close() + + def test_cluster_main_window_switches_to_xyz_mode(qapp, tmp_path): source_dir = tmp_path / "cluster_run" source_dir.mkdir() @@ -603,3 +777,39 @@ def fake_cluster_main(argv=None): assert exit_code == 9 assert captured["argv"] == ["ui", "traj.pdb"] + + +def test_cluster_run_file_window_builds_project_relative_config( + qapp, + tmp_path, +): + del qapp + project_dir = tmp_path / "project" + SAXSProjectManager().create_project(project_dir) + frames_dir = project_dir / "frames" / "splitxyz0001" + frames_dir.mkdir(parents=True) + (frames_dir / "frame_0000.xyz").write_text( + "2\nframe_0000\nPb 0.0 0.0 0.0\nI 1.0 0.0 0.0\n", + encoding="utf-8", + ) + + window = ClusterRunFileWindow( + initial_project_dir=project_dir, + initial_frames_dir=frames_dir, + ) + output_dir = project_dir / "clusters_splitxyz0001" + window.output_dir_edit.setText(str(output_dir)) + window.definitions_panel.set_use_pbc(True) + window.definitions_panel.set_search_mode("vectorized") + window.definitions_panel.set_save_state_frequency(250) + + config = window._current_config(project_dir) + + assert config.frames_dir == "frames/splitxyz0001" + assert config.output_dir == "clusters_splitxyz0001" + assert config.use_pbc is True + assert config.search_mode == "vectorized" + assert config.save_state_frequency == 250 + assert config.atom_type_definitions["node"] == [("Pb", None)] + assert config.pair_cutoff_definitions[("Pb", "I")] == {0: 3.36} + window.close() diff --git a/tests/test_clusterdynamics.py b/tests/test_clusterdynamics.py index 31f4467..df49fab 100644 --- a/tests/test_clusterdynamics.py +++ b/tests/test_clusterdynamics.py @@ -13,11 +13,20 @@ from saxshell import saxshell as saxshell_module from saxshell.clusterdynamics import ( ClusterDynamicsWorkflow, + build_clusterdynamics_run_config, + default_clusterdynamics_run_file_path, 
load_cluster_dynamics_dataset, + load_clusterdynamics_run_config, + resolve_run_config_path, + run_clusterdynamics_run_config, save_cluster_dynamics_dataset, + save_clusterdynamics_run_config, ) from saxshell.clusterdynamics.ui.main_window import ClusterDynamicsMainWindow from saxshell.clusterdynamics.ui.plot_panel import ClusterDynamicsPlotPanel +from saxshell.clusterdynamics.ui.run_file_window import ( + ClusterDynamicsRunFileWindow, +) from saxshell.plotting import ( igor_inline_to_mathtext, load_pickled_plot_figure, @@ -361,6 +370,176 @@ def test_cluster_dynamics_dataset_round_trip(tmp_path): assert loaded.result.energy_data is not None +def test_cluster_dynamics_run_config_round_trips_project_relative_paths( + tmp_path, +): + project_dir = tmp_path / "project" + project_dir.mkdir() + frames_dir = project_dir / "frames" / "splitxyz0001" + frames_dir.mkdir(parents=True) + energy_file = project_dir / "traj.ener" + energy_file.write_text("0 0.0 1.0 300.0 -10.0\n", encoding="utf-8") + output_file = ( + project_dir + / "exported_results" + / "data" + / "clusterdynamics" + / "splitxyz0001_cluster_dynamics.json" + ) + + config = build_clusterdynamics_run_config( + project_dir=project_dir, + frames_dir=frames_dir, + output_file=output_file, + energy_file=energy_file, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + shell_levels=(1,), + frame_timestep_fs=10.0, + frames_per_colormap_timestep=3, + search_mode="vectorized", + ) + run_file = default_clusterdynamics_run_file_path(project_dir) + save_clusterdynamics_run_config(run_file, config) + + loaded = load_clusterdynamics_run_config(run_file) + + assert loaded.frames_dir == "frames/splitxyz0001" + assert loaded.energy_file == "traj.ener" + assert loaded.output_file == ( + "exported_results/data/clusterdynamics/" + "splitxyz0001_cluster_dynamics.json" + ) + assert loaded.shell_levels == (1,) + assert loaded.frame_timestep_fs == 10.0 + assert loaded.frames_per_colormap_timestep == 3 
+ assert loaded.search_mode == "vectorized" + assert ( + resolve_run_config_path(loaded.frames_dir, project_dir=project_dir) + == frames_dir.resolve() + ) + + +def test_cluster_dynamics_project_run_saves_dataset_and_updates_project( + tmp_path, + capsys, +): + manager = SAXSProjectManager() + project_dir = tmp_path / "project" + manager.create_project(project_dir) + frames_dir = project_dir / "frames" / "splitxyz0001" + frames_dir.mkdir(parents=True) + for index, content in enumerate( + ( + _disconnected_xyz_lines(), + _connected_xyz_lines(), + _connected_xyz_lines(), + _disconnected_xyz_lines(), + _connected_xyz_lines(), + _disconnected_xyz_lines(), + ) + ): + (frames_dir / f"frame_{index:04d}.xyz").write_text( + content, + encoding="utf-8", + ) + energy_file = _write_energy_file(project_dir) + output_file = ( + project_dir + / "exported_results" + / "data" + / "clusterdynamics" + / "splitxyz0001_cluster_dynamics.json" + ) + config = build_clusterdynamics_run_config( + project_dir=project_dir, + frames_dir=frames_dir, + output_file=output_file, + energy_file=energy_file, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + shell_levels=(1,), + frame_timestep_fs=10.0, + frames_per_colormap_timestep=3, + ) + run_file = default_clusterdynamics_run_file_path(project_dir) + save_clusterdynamics_run_config(run_file, config) + + summary = run_clusterdynamics_run_config( + project_dir, + load_clusterdynamics_run_config(run_file), + run_file_path=run_file, + ) + saved_settings = manager.load_project(project_dir) + + assert summary.result.analyzed_frames == 6 + assert summary.output_file == output_file.resolve() + assert output_file.is_file() + assert output_file.with_name( + "splitxyz0001_cluster_dynamics_lifetime.csv" + ).is_file() + assert saved_settings.resolved_frames_dir == frames_dir.resolve() + assert saved_settings.resolved_energy_file == energy_file.resolve() + assert saved_settings.frames_dir_snapshot is not None + assert 
saved_settings.energy_file_snapshot is not None + + exit_code = clusterdynamics_cli_module.main(["run", str(project_dir)]) + captured = capsys.readouterr() + + assert exit_code == 0 + assert "Cluster dynamics CLI run complete" in captured.out + assert "Lifetime rows:" in captured.out + assert "Files written:" in captured.out + + +def test_cluster_dynamics_run_file_window_builds_project_config( + qapp, + tmp_path, +): + del qapp + project_dir = tmp_path / "project" + SAXSProjectManager().create_project(project_dir) + frames_dir = project_dir / "frames" / "splitxyz0001" + frames_dir.mkdir(parents=True) + (frames_dir / "frame_0000.xyz").write_text( + "2\nframe_0000\nPb 0.0 0.0 0.0\nI 1.0 0.0 0.0\n", + encoding="utf-8", + ) + energy_file = project_dir / "traj.ener" + energy_file.write_text("0 0.0 1.0 300.0 -10.0\n", encoding="utf-8") + + window = ClusterDynamicsRunFileWindow( + initial_project_dir=project_dir, + initial_frames_dir=frames_dir, + initial_energy_file=energy_file, + ) + output_file = ( + project_dir + / "exported_results" + / "data" + / "clusterdynamics" + / "splitxyz0001_cluster_dynamics.json" + ) + window.output_file_edit.setText(str(output_file)) + window.time_panel.set_frame_timestep_fs(10.0) + window.time_panel.set_frames_per_colormap_timestep(3) + window.definitions_panel.set_search_mode("vectorized") + + config = window._current_config(project_dir) + + assert config.frames_dir == "frames/splitxyz0001" + assert config.energy_file == "traj.ener" + assert config.output_file == ( + "exported_results/data/clusterdynamics/" + "splitxyz0001_cluster_dynamics.json" + ) + assert config.frame_timestep_fs == 10.0 + assert config.frames_per_colormap_timestep == 3 + assert config.search_mode == "vectorized" + assert "clusterdynamics run" in window.command_box.toPlainText() + window.close() + + def test_cluster_dynamics_main_window_updates_preview_for_xyz_frames( qapp, tmp_path, diff --git a/tests/test_clusterdynamicsml.py b/tests/test_clusterdynamicsml.py index 
1c1b067..7ddc98f 100644 --- a/tests/test_clusterdynamicsml.py +++ b/tests/test_clusterdynamicsml.py @@ -22,8 +22,12 @@ from saxshell.cluster import PDBShellReferenceDefinition from saxshell.clusterdynamicsml import ( ClusterDynamicsMLWorkflow, + build_clusterdynamicsml_run_config, + default_clusterdynamicsml_run_file_path, load_cluster_dynamicsai_dataset, + load_clusterdynamicsml_run_config, save_cluster_dynamicsai_dataset, + save_clusterdynamicsml_run_config, ) from saxshell.clusterdynamicsml.ui.main_window import ( _UI_REFRESH_DELAY_MS, @@ -34,6 +38,10 @@ from saxshell.clusterdynamicsml.ui.plot_panel import ( _build_population_histogram_payload, _distribution_entries, + build_cluster_lifetime_distributions, +) +from saxshell.clusterdynamicsml.ui.run_file_window import ( + ClusterDynamicsMLRunFileWindow, ) from saxshell.saxs.debye import ( compute_debye_intensity, @@ -715,6 +723,143 @@ def test_clusterdynamicsml_workflow_predicts_larger_clusters(tmp_path): assert len(result.saxs_comparison.component_weights) >= 3 +def test_clusterdynamicsml_project_run_saves_dataset_and_updates_project( + tmp_path, + capsys, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + energy_file = _write_energy_file(tmp_path) + project_dir = _build_project_dir( + tmp_path, + frames_dir=frames_dir, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + energy_file=energy_file, + ) + output_file = ( + project_dir + / "exported_results" + / "data" + / "clusterdynamicsml" + / "splitxyz_f0fs_clusterdynamicsml.json" + ) + config = build_clusterdynamicsml_run_config( + project_dir=project_dir, + frames_dir=frames_dir, + output_file=output_file, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + energy_file=energy_file, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + frame_timestep_fs=10.0, + 
frames_per_colormap_timestep=1, + target_node_counts=(4,), + candidates_per_size=1, + q_points=60, + ) + run_file = default_clusterdynamicsml_run_file_path(project_dir) + save_clusterdynamicsml_run_config(run_file, config) + loaded = load_clusterdynamicsml_run_config(run_file) + + assert loaded.frames_dir == str(frames_dir.resolve()) + assert loaded.output_file == ( + "exported_results/data/clusterdynamicsml/" + "splitxyz_f0fs_clusterdynamicsml.json" + ) + assert loaded.target_node_counts == (4,) + assert loaded.candidates_per_size == 1 + assert loaded.q_points == 60 + + exit_code = clusterdynamicsml_cli_module.main(["run", str(project_dir)]) + captured = capsys.readouterr() + saved_settings = SAXSProjectManager().load_project(project_dir) + + assert exit_code == 0 + assert "Cluster dynamics ML CLI run complete" in captured.out + assert "Predictions:" in captured.out + assert output_file.is_file() + assert output_file.with_name( + "splitxyz_f0fs_clusterdynamicsml_predictions.csv" + ).is_file() + loaded_dataset = load_cluster_dynamicsai_dataset(output_file) + assert loaded_dataset.result.predictions + assert saved_settings.resolved_frames_dir == frames_dir.resolve() + assert saved_settings.resolved_clusters_dir == clusters_dir.resolve() + assert saved_settings.resolved_energy_file == energy_file.resolve() + + +def test_clusterdynamicsml_run_file_window_builds_project_config( + qapp, + tmp_path, +): + del qapp + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + energy_file = _write_energy_file(tmp_path) + project_dir = _build_project_dir( + tmp_path, + frames_dir=frames_dir, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + energy_file=energy_file, + ) + + window = ClusterDynamicsMLRunFileWindow( + initial_project_dir=project_dir, + initial_frames_dir=frames_dir, + initial_energy_file=energy_file, + 
initial_clusters_dir=clusters_dir, + initial_experimental_data_file=experimental_data_file, + ) + output_file = ( + project_dir + / "exported_results" + / "data" + / "clusterdynamicsml" + / "splitxyz_f0fs_clusterdynamicsml.json" + ) + window.output_file_edit.setText(str(output_file)) + window.definitions_panel.load_atom_type_definitions( + ATOM_TYPE_DEFINITIONS, + emit_signal=False, + ) + window.definitions_panel.load_pair_cutoff_definitions( + PAIR_CUTOFFS, + emit_signal=False, + ) + window.time_panel.set_frame_timestep_fs(10.0) + window.time_panel.set_frames_per_colormap_timestep(1) + window.prediction_panel.set_target_node_counts((4,)) + window.prediction_panel.set_candidates_per_size(1) + window.prediction_panel.set_q_settings( + q_min=0.05, + q_max=1.0, + q_points=60, + ) + + config = window._current_config(project_dir) + + assert config.frames_dir == str(frames_dir.resolve()) + assert config.clusters_dir == str(clusters_dir.resolve()) + assert config.experimental_data_file == str( + experimental_data_file.resolve() + ) + assert config.energy_file == str(energy_file.resolve()) + assert config.output_file == ( + "exported_results/data/clusterdynamicsml/" + "splitxyz_f0fs_clusterdynamicsml.json" + ) + assert config.target_node_counts == (4,) + assert config.candidates_per_size == 1 + assert config.q_points == 60 + assert "clusterdynamicsml run" in window.command_box.toPlainText() + window.close() + + def test_clusterdynamicsml_estimates_debye_waller_pairs_and_uses_them_for_predicted_traces( tmp_path, ): @@ -2861,6 +3006,56 @@ def test_clusterdynamicsml_window_shows_observed_lifetime_tab( window.close() +def test_clusterdynamicsml_window_opens_lifetime_distribution_plots( + qapp, + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + 
frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + distributions = build_cluster_lifetime_distributions(result) + + window = ClusterDynamicsMLMainWindow(initial_frames_dir=frames_dir) + window._on_run_finished(result) + window.lifetime_distribution_button.click() + qapp.processEvents() + + distribution_window = window._lifetime_distribution_window + + assert distributions + assert any( + distribution.completed_lifetime_count > 0 + for distribution in distributions + ) + assert window.lifetime_distribution_button.isEnabled() + assert distribution_window is not None + assert distribution_window.isVisible() + assert "Lorentzian" in distribution_window.panel.summary_box.toPlainText() + assert any( + axis.patches + for axis in distribution_window.panel.figure.axes + if axis.get_visible() + ) + + distribution_window.panel.include_truncated_checkbox.setChecked(True) + qapp.processEvents() + + assert "window-truncated" in ( + distribution_window.panel.summary_box.toPlainText() + ) + distribution_window.close() + window.close() + + def test_clusterdynamicsml_window_progress_messages_show_current_ml_step( qapp, ): diff --git a/tests/test_fullrmc_cli.py b/tests/test_fullrmc_cli.py index ebf9112..42f13b6 100644 --- a/tests/test_fullrmc_cli.py +++ b/tests/test_fullrmc_cli.py @@ -16,6 +16,7 @@ ) import saxshell.fullrmc.cli as fullrmc_cli +import saxshell.fullrmc.packmol_setup as packmol_setup_module import saxshell.fullrmc.solvent_shell_builder as solvent_shell_builder_module import saxshell.fullrmc.ui.main_window as fullrmc_ui_module from saxshell.fullrmc import ( @@ -28,6 +29,7 @@ PackmolDockerValidationResult, PackmolPlanningSettings, PackmolSetupSettings, + PackmolSupplementalComponentSettings, RepresentativeSelectionSettings, SolutionPropertiesSettings, SolventHandlingSettings, @@ -56,6 +58,9 @@ select_first_file_representatives, ) from saxshell.fullrmc.cli import main as fullrmc_main +from 
saxshell.fullrmc.solvent_handling import ( + analyze_representative_solvent_distribution, +) from saxshell.fullrmc.ui.main_window import RMCSetupMainWindow from saxshell.fullrmc.ui.solvent_shell_builder_window import ( SolventShellBuilderMainWindow, @@ -1909,6 +1914,48 @@ def test_build_representative_solvent_outputs_preserves_single_atom_sources( assert [atom.element for atom in completed_structure.atoms] == ["Zn"] +def test_representative_solvent_distribution_ignores_single_atom_status( + tmp_path, +): + project_dir, _paths, _single_atom_path = ( + _build_sample_saxs_project_with_single_atom_model(tmp_path) + ) + reference_path = _write_custom_solvent_pdb(tmp_path) + complete_solvent_path = _write_test_solvent_shell_pdb( + tmp_path, + reference_path=reference_path, + ) + state = load_rmc_project_source(project_dir) + selection = state.favorite_selection + assert selection is not None + state.representative_selection = select_first_file_representatives( + state, + selection, + ) + for entry in state.representative_selection.representative_entries: + if entry.atom_count > 1: + entry.source_file = str(complete_solvent_path) + entry.source_file_name = complete_solvent_path.name + + analysis = analyze_representative_solvent_distribution( + state, + _integrated_solvent_handling_settings( + reference_source="custom", + custom_reference_path=str(reference_path), + director_atom_name="O1", + ), + ) + + assert analysis.distribution_status == "complete_solvent" + assert analysis.distribution_status_entry_count == 2 + assert analysis.ignored_distribution_status_entry_count == 1 + assert "Ignored 1 single-atom representative" in ( + analysis.distribution_note + ) + assert "Zn1: No solvent molecules detected" in analysis.summary_text() + assert "ignored for distribution state" in analysis.summary_text() + + def test_build_packmol_plan_writes_metadata_and_reports(tmp_path): project_dir, _paths = _build_sample_saxs_project(tmp_path) state = load_rmc_project_source(project_dir) 
@@ -2134,6 +2181,107 @@ def test_build_packmol_plan_tracks_solvent_allocation(tmp_path): assert "Cluster solvent molecules:" in metadata.summary_text() +def test_build_packmol_plan_allocates_missing_solute_components(tmp_path): + project_dir, _paths = _build_sample_saxs_project(tmp_path) + state = load_rmc_project_source(project_dir) + selection = state.favorite_selection + assert selection is not None + state.representative_selection = select_first_file_representatives( + state, + selection, + ) + solution_settings = SolutionPropertiesSettings( + mode="mass", + solution_density=1.05, + solute_stoich="C1H6N1Pb1I2", + solvent_stoich="C3H7NO", + molar_mass_solute=493.0, + molar_mass_solvent=73.09, + mass_solute=4.93, + mass_solvent=95.07, + ) + state.solution_properties = save_solution_properties_metadata( + state.rmcsetup_paths.solution_properties_path, + settings=solution_settings, + result=calculate_solution_properties(solution_settings), + ) + + metadata = build_packmol_plan( + state, + PackmolPlanningSettings( + planning_mode="per_element", + box_side_length_a=80.0, + supplemental_components=( + PackmolSupplementalComponentSettings( + role="solute", + reference="ma", + residue_name="MAI", + ), + ), + ), + ) + + allocation = metadata.supplemental_allocation + + assert allocation is not None + assert allocation.target_solute_formula_units > 0 + assert allocation.unfilled_solute_element_totals == {} + assert allocation.entries[0].planned_count == ( + allocation.target_solute_formula_units + ) + assert allocation.entries[0].element_counts == { + "C": 1, + "H": 6, + "N": 1, + } + assert metadata.achieved_element_number_density_a3["C"] > 0 + assert "Supplemental solute accounting:" in ( + state.rmcsetup_paths.packmol_plan_report_path.read_text( + encoding="utf-8" + ) + ) + + +def test_build_packmol_plan_requires_components_for_absent_solute_species( + tmp_path, +): + project_dir, _paths = _build_sample_saxs_project(tmp_path) + state = 
load_rmc_project_source(project_dir) + selection = state.favorite_selection + assert selection is not None + state.representative_selection = select_first_file_representatives( + state, + selection, + ) + solution_settings = SolutionPropertiesSettings( + mode="mass", + solution_density=1.05, + solute_stoich="C1H6N1Pb1I2", + solvent_stoich="C3H7NO", + molar_mass_solute=493.0, + molar_mass_solvent=73.09, + mass_solute=4.93, + mass_solvent=95.07, + ) + state.solution_properties = save_solution_properties_metadata( + state.rmcsetup_paths.solution_properties_path, + settings=solution_settings, + result=calculate_solution_properties(solution_settings), + ) + + with pytest.raises( + ValueError, + match="Supplemental solute components are required", + ): + build_packmol_plan( + state, + PackmolPlanningSettings( + planning_mode="per_element", + box_side_length_a=80.0, + ), + ) + + def test_build_packmol_setup_writes_input_files_and_audit(tmp_path): project_dir, _paths = _build_sample_saxs_project(tmp_path) state = load_rmc_project_source(project_dir) @@ -2221,6 +2369,247 @@ def test_build_packmol_setup_writes_input_files_and_audit(tmp_path): assert "# Packmol Build Audit" in audit_text assert "Cluster solvent molecules:" in audit_text assert "Count-normalized weights" in audit_text + assert Path(metadata.build_report_path).is_file() + build_report_text = Path(metadata.build_report_path).read_text( + encoding="utf-8" + ) + assert "Source input information" in build_report_text + assert "Computed solvent molecules:" in build_report_text + assert "Cluster solvent molecules:" in build_report_text + assert "Free solvent molecules:" in build_report_text + assert "Target total number density:" in build_report_text + + solvated_entry = next( + entry for entry in metadata.entries if entry.solvent_atom_count > 0 + ) + solvated_structure = PDBStructure.from_file(solvated_entry.packmol_pdb) + solute_atoms = [ + atom + for atom in solvated_structure.atoms + if atom.residue_name == 
solvated_entry.residue_name + ] + solvent_atoms = [ + atom + for atom in solvated_structure.atoms + if atom.residue_name != solvated_entry.residue_name + ] + + assert solute_atoms + assert solvent_atoms + assert len(solute_atoms) == solvated_entry.solute_atom_count + assert len(solvent_atoms) == solvated_entry.solvent_atom_count + assert {atom.residue_number for atom in solute_atoms} == {1} + assert {atom.residue_name for atom in solvent_atoms} == {"DMF"} + assert min(atom.residue_number for atom in solvent_atoms) >= 2 + assert len({atom.residue_number for atom in solvent_atoms}) == ( + solvated_entry.solvent_residue_count + ) + + +def test_packmol_preparation_keeps_solvent_residue_when_solute_is_last(): + source_structure = PDBStructure( + atoms=[ + PDBAtom( + atom_id=229, + atom_name="O1", + residue_name="DMF", + residue_number=20, + coordinates=np.asarray([10.073, 14.645, 2.873]), + element="O", + ), + PDBAtom( + atom_id=230, + atom_name="N1", + residue_name="DMF", + residue_number=20, + coordinates=np.asarray([8.539, 14.935, 4.538]), + element="N", + ), + PDBAtom( + atom_id=639, + atom_name="PB3", + residue_name="PBI", + residue_number=56, + coordinates=np.asarray([13.198, 14.311, 3.729]), + element="Pb", + ), + PDBAtom( + atom_id=642, + atom_name="I3", + residue_name="PBI", + residue_number=59, + coordinates=np.asarray([16.059, 14.562, 4.307]), + element="I", + ), + PDBAtom( + atom_id=643, + atom_name="I4", + residue_name="PBI", + residue_number=60, + coordinates=np.asarray([12.481, 15.296, 6.518]), + element="I", + ), + ], + source_name="solvent_first_cluster", + ) + + prepared = packmol_setup_module._prepare_packmol_structure( + source_structure, + residue_name="CAH", + solvent_residue_names=frozenset({"DMF"}), + expected_solute_element_counts={"Pb": 1, "I": 2}, + solute_atom_count=2, + ) + + residue_names = [atom.residue_name for atom in prepared.structure.atoms] + residue_numbers = [ + atom.residue_number for atom in prepared.structure.atoms + ] + + assert 
residue_names == ["DMF", "DMF", "CAH", "CAH", "CAH"] + assert residue_numbers == [2, 2, 1, 1, 1] + assert prepared.solute_atom_count == 3 + assert prepared.solvent_atom_count == 2 + assert prepared.solvent_residue_names == ("DMF",) + + +def test_packmol_preparation_can_identify_formula_solute_without_metadata(): + source_structure = PDBStructure( + atoms=[ + PDBAtom( + atom_id=1, + atom_name="O1", + residue_name="DMF", + residue_number=20, + coordinates=np.asarray([0.0, 0.0, 0.0]), + element="O", + ), + PDBAtom( + atom_id=2, + atom_name="C1", + residue_name="DMF", + residue_number=20, + coordinates=np.asarray([1.0, 0.0, 0.0]), + element="C", + ), + PDBAtom( + atom_id=3, + atom_name="PB1", + residue_name="PBI", + residue_number=56, + coordinates=np.asarray([2.0, 0.0, 0.0]), + element="Pb", + ), + PDBAtom( + atom_id=4, + atom_name="I1", + residue_name="PBI", + residue_number=59, + coordinates=np.asarray([3.0, 0.0, 0.0]), + element="I", + ), + PDBAtom( + atom_id=5, + atom_name="I2", + residue_name="PBI", + residue_number=60, + coordinates=np.asarray([4.0, 0.0, 0.0]), + element="I", + ), + ], + source_name="formula_fallback_cluster", + ) + + prepared = packmol_setup_module._prepare_packmol_structure( + source_structure, + residue_name="CAH", + expected_solute_element_counts={"Pb": 1, "I": 2}, + ) + + assert [atom.residue_name for atom in prepared.structure.atoms] == [ + "DMF", + "DMF", + "CAH", + "CAH", + "CAH", + ] + assert prepared.solute_atom_count == 3 + assert prepared.solvent_residue_names == ("DMF",) + + +def test_build_packmol_setup_writes_supplemental_solute_components(tmp_path): + project_dir, _paths = _build_sample_saxs_project(tmp_path) + state = load_rmc_project_source(project_dir) + selection = state.favorite_selection + assert selection is not None + state.representative_selection = select_first_file_representatives( + state, + selection, + ) + solution_settings = SolutionPropertiesSettings( + mode="mass", + solution_density=1.05, + 
solute_stoich="C1H6N1Pb1I2", + solvent_stoich="C3H7NO", + molar_mass_solute=493.0, + molar_mass_solvent=73.09, + mass_solute=4.93, + mass_solvent=95.07, + ) + state.solution_properties = save_solution_properties_metadata( + state.rmcsetup_paths.solution_properties_path, + settings=solution_settings, + result=calculate_solution_properties(solution_settings), + ) + state.packmol_planning = build_packmol_plan( + state, + PackmolPlanningSettings( + planning_mode="per_element", + box_side_length_a=80.0, + free_solvent_reference="dmf", + supplemental_components=( + PackmolSupplementalComponentSettings( + role="solute", + reference="ma", + residue_name="MAI", + ), + ), + ), + ) + + metadata = build_packmol_setup( + state, + PackmolSetupSettings( + tolerance_angstrom=2.2, + free_solvent_reference="dmf", + ), + ) + + assert metadata.supplemental_entries + supplemental_entry = metadata.supplemental_entries[0] + supplemental_structure = PDBStructure.from_file( + supplemental_entry.packmol_pdb + ) + packmol_text = Path(metadata.packmol_input_path).read_text( + encoding="utf-8" + ) + build_report_text = Path(metadata.build_report_path).read_text( + encoding="utf-8" + ) + + assert supplemental_entry.planned_count == ( + state.packmol_planning.supplemental_allocation.target_solute_formula_units + ) + assert {atom.residue_name for atom in supplemental_structure.atoms} == { + "MAI" + } + assert supplemental_entry.atom_count == 8 + assert f"structure {Path(supplemental_entry.packmol_pdb).name}" in ( + packmol_text + ) + assert f" number {supplemental_entry.planned_count}" in packmol_text + assert "Supplemental solute accounting" in build_report_text + assert "Supplemental Packmol components" in build_report_text def test_build_packmol_setup_requires_all_positive_weight_representatives( @@ -2406,6 +2795,7 @@ def test_build_constraint_generation_writes_per_structure_and_merged_files( ) assert "BOND_ANGLE_CONSTRAINTS" in merged_text assert "BOND_LENGTH_CONSTRAINTS" in merged_text + 
assert "DMF" not in merged_text assert any(entry.bond_length_count > 0 for entry in metadata.entries) assert any(entry.bond_angle_count > 0 for entry in metadata.entries) @@ -3309,6 +3699,103 @@ def test_rmcsetup_solvent_handling_ui_builds_and_reloads(tmp_path): ) +def test_rmcsetup_reload_maps_representatives_to_current_dream_weights( + tmp_path, +): + qapp() + project_dir, _paths = _build_sample_saxs_project(tmp_path) + state = load_rmc_project_source(project_dir) + selection = state.favorite_selection + assert selection is not None + metadata = select_first_file_representatives(state, selection) + for entry in metadata.representative_entries: + entry.param = entry.structure + entry.selected_weight = 0.5 + for entry in metadata.distribution_selection.entries: + entry.param = entry.structure + entry.selected_weight = 0.5 + save_representative_selection_metadata( + state.rmcsetup_paths.representative_selection_path, + metadata, + ) + + window = RMCSetupMainWindow(initial_project_dir=project_dir) + + headers = [ + window.generated_pdb_table.horizontalHeaderItem(column).text() + for column in range(window.generated_pdb_table.columnCount()) + ] + assert "DREAM Weight" in headers + assert "DREAM Value" in headers + weight_column = headers.index("DREAM Weight") + value_column = headers.index("DREAM Value") + mapped = { + window.generated_pdb_table.item(row, 0).text(): ( + window.generated_pdb_table.item(row, weight_column).text(), + window.generated_pdb_table.item(row, value_column).text(), + ) + for row in range(window.generated_pdb_table.rowCount()) + } + + assert mapped["PbI2"] == ("w0", "0.25") + assert mapped["PbI2O/motif_1"] == ("w1", "0.75") + reloaded = load_representative_selection_metadata( + state.rmcsetup_paths.representative_selection_path + ) + assert reloaded is not None + assert { + (entry.structure, entry.motif): (entry.param, entry.selected_weight) + for entry in reloaded.representative_entries + } == { + ("PbI2", "no_motif"): ("w0", 
pytest.approx(0.25)), + ("PbI2O", "motif_1"): ("w1", pytest.approx(0.75)), + } + + +def test_rmcsetup_representative_reset_keeps_saved_representative_sources( + tmp_path, +): + qapp() + project_dir, _paths = _build_sample_saxs_project(tmp_path) + state = load_rmc_project_source(project_dir) + selection = state.favorite_selection + assert selection is not None + state.representative_selection = select_first_file_representatives( + state, + selection, + ) + state.solvent_handling = build_representative_solvent_outputs( + state, + _integrated_solvent_handling_settings(), + ) + source_paths = [ + Path(entry.source_file) + for entry in state.representative_selection.representative_entries + ] + tracked_outputs = [ + Path(entry.completed_pdb) for entry in state.solvent_handling.entries + ] + window = RMCSetupMainWindow(initial_project_dir=project_dir) + + reset = window._reset_representative_dependent_state( + confirm=False, + refresh=False, + clear_reason="test", + ) + + assert reset is True + assert window._project_source_state is not None + assert window._project_source_state.representative_selection is not None + assert all(path.is_file() for path in source_paths) + assert not any(path.exists() for path in tracked_outputs) + assert ( + load_solvent_handling_metadata( + state.rmcsetup_paths.solvent_handling_path + ) + is None + ) + + def test_rmcsetup_imported_full_solvent_representatives_mark_solvent_ready( tmp_path, ): @@ -4188,6 +4675,39 @@ def test_rmcsetup_ui_can_compute_packmol_plan_and_reload(tmp_path): ) +def test_rmcsetup_packmol_single_atom_uses_periodic_table_picker( + tmp_path, + monkeypatch, +): + qapp() + project_dir, _paths = _build_sample_saxs_project(tmp_path) + window = RMCSetupMainWindow(initial_project_dir=project_dir) + monkeypatch.setattr( + fullrmc_ui_module.PeriodicTableElementDialog, + "get_element_symbol", + lambda **_kwargs: "Cs", + ) + monkeypatch.setattr( + fullrmc_ui_module.QInputDialog, + "getItem", + lambda *_args, **_kwargs: ("solute", 
True), + ) + monkeypatch.setattr( + fullrmc_ui_module.QInputDialog, + "getText", + lambda *_args, **_kwargs: ("CES", True), + ) + + window._add_packmol_supplemental_atom_component() + + components = window._current_packmol_supplemental_components() + assert len(components) == 1 + assert components[0].element == "Cs" + assert components[0].residue_name == "CES" + assert window.packmol_supplemental_table.item(0, 2).text() == "Cs" + window.close() + + def test_rmcsetup_ui_packmol_preview_includes_single_atom_model_sources( tmp_path, ): diff --git a/tests/test_mdtrajectory_cli.py b/tests/test_mdtrajectory_cli.py index c772626..c283cf0 100644 --- a/tests/test_mdtrajectory_cli.py +++ b/tests/test_mdtrajectory_cli.py @@ -61,6 +61,24 @@ def _write_sample_ener(path: Path) -> None: ) +def _write_restart_overlap_xyz(path: Path) -> None: + path.write_text( + "1\n" + "i = 0, time = 0.0, E = -1.0\n" + "H 0.0 0.0 0.0\n" + "1\n" + "i = 1, time = 0.5, E = -1.0\n" + "H 1.0 0.0 0.0\n" + "1\n" + "i = 1, time = 0.5, E = -1.0\n" + "H 9.0 0.0 0.0\n" + "1\n" + "i = 2, time = 1.0, E = -1.0\n" + "H 2.0 0.0 0.0\n", + encoding="utf-8", + ) + + def test_workflow_supports_notebook_style_end_to_end_usage(tmp_path): trajectory_file = tmp_path / "traj.xyz" energy_file = tmp_path / "traj.ener" @@ -198,6 +216,265 @@ def test_mdtrajectory_cli_export_runs_complete_headless_workflow( ] +def test_mdtrajectory_cli_export_can_include_restart_duplicates( + tmp_path, + capsys, +): + trajectory_file = tmp_path / "traj.xyz" + _write_restart_overlap_xyz(trajectory_file) + + exit_code = mdtrajectory_main( + [ + "export", + str(trajectory_file), + "--include-restart-duplicates", + ] + ) + + captured = capsys.readouterr() + output_dir = tmp_path / "splitxyz_f0_t0fs" + metadata_payload = json.loads( + (output_dir / "mdtrajectory_export.json").read_text() + ) + + assert exit_code == 0 + assert "Restart duplicate frames: included" in captured.out + assert sorted(path.name for path in output_dir.glob("*.xyz")) == [ + 
"frame_0000.xyz", + "frame_0001.xyz", + "frame_0001_duplicate0001.xyz", + "frame_0002.xyz", + ] + assert ( + "H 1.0" + in (output_dir / "frame_0001_duplicate0001.xyz").read_text() + ) + assert "H 9.0" in (output_dir / "frame_0001.xyz").read_text() + assert metadata_payload["selection"]["include_restart_duplicates"] is True + + +def test_workflow_can_validate_exported_xyz_frame_mapping(tmp_path): + trajectory_file = tmp_path / "traj.xyz" + _write_sample_xyz(trajectory_file) + + workflow = MDTrajectoryWorkflow(trajectory_file=trajectory_file) + export = workflow.export_frames(use_cutoff=True, cutoff_fs=50.0) + + result = workflow.validate_export( + export.output_dir, + expect_contiguous=True, + ) + + assert result.passed + assert result.exported_files == 3 + assert result.validated_files == 3 + assert result.filename_index_min == 1 + assert result.filename_index_max == 3 + assert result.header_index_min == 1 + assert result.header_index_max == 3 + assert result.filename_header_offsets == {0: 3} + assert result.issue_counts == {} + + +def test_workflow_validation_accepts_purged_source_duplicate_conflicts( + tmp_path, +): + trajectory_file = tmp_path / "traj.xyz" + trajectory_file.write_text( + "1\n" + "i = 0, time = 0.0, E = -1.0\n" + "H 0.0 0.0 0.0\n" + "1\n" + "i = 1, time = 0.5, E = -1.0\n" + "H 1.0 0.0 0.0\n" + "1\n" + "i = 1, time = 0.5, E = -1.0\n" + "H 9.0 0.0 0.0\n" + "1\n" + "i = 2, time = 1.0, E = -1.0\n" + "H 2.0 0.0 0.0\n", + encoding="utf-8", + ) + + workflow = MDTrajectoryWorkflow(trajectory_file=trajectory_file) + export = workflow.export_frames() + + result = workflow.validate_export(export.output_dir) + strict_result = workflow.validate_export( + export.output_dir, + strict_source_duplicates=True, + ) + + assert result.passed + assert result.source_duplicate_indices == 1 + assert result.source_duplicate_conflicts == 1 + assert result.issue_counts == {} + assert strict_result.failure_count == 1 + assert not strict_result.passed + assert "H 9.0" in 
(export.output_dir / "frame_0001.xyz").read_text() + + +def test_workflow_validation_allows_identical_source_duplicates_by_default( + tmp_path, +): + trajectory_file = tmp_path / "traj.xyz" + trajectory_file.write_text( + "1\n" + "i = 0, time = 0.0, E = -1.0\n" + "H 0.0 0.0 0.0\n" + "1\n" + "i = 1, time = 0.5, E = -1.0\n" + "H 1.0 0.0 0.0\n" + "1\n" + "i = 1, time = 0.5, E = -1.0\n" + "H 1.0 0.0 0.0\n" + "1\n" + "i = 2, time = 1.0, E = -1.0\n" + "H 2.0 0.0 0.0\n", + encoding="utf-8", + ) + + workflow = MDTrajectoryWorkflow(trajectory_file=trajectory_file) + export = workflow.export_frames() + + result = workflow.validate_export(export.output_dir) + strict_result = workflow.validate_export( + export.output_dir, + strict_source_duplicates=True, + ) + + assert result.passed + assert result.source_duplicate_indices == 1 + assert result.source_duplicate_conflicts == 0 + assert result.issue_counts == {} + assert strict_result.failure_count == 1 + assert not strict_result.passed + + +def test_workflow_validation_rejects_export_that_keeps_earlier_overlap( + tmp_path, +): + trajectory_file = tmp_path / "traj.xyz" + trajectory_file.write_text( + "1\n" + "i = 0, time = 0.0, E = -1.0\n" + "H 0.0 0.0 0.0\n" + "1\n" + "i = 1, time = 0.5, E = -1.0\n" + "H 1.0 0.0 0.0\n" + "1\n" + "i = 1, time = 0.5, E = -1.0\n" + "H 9.0 0.0 0.0\n" + "1\n" + "i = 2, time = 1.0, E = -1.0\n" + "H 2.0 0.0 0.0\n", + encoding="utf-8", + ) + bad_export_dir = tmp_path / "bad_frames" + bad_export_dir.mkdir() + (bad_export_dir / "frame_0001.xyz").write_text( + "1\n" "i = 1, time = 0.5, E = -1.0\n" "H 1.0 0.0 0.0\n", + encoding="utf-8", + ) + + workflow = MDTrajectoryWorkflow(trajectory_file=trajectory_file) + result = workflow.validate_export(bad_export_dir) + + assert not result.passed + assert result.source_duplicate_conflicts == 1 + assert result.issue_counts == {"coordinate_mismatch": 1} + + +def test_mdtrajectory_cli_validate_export_reports_mapping_failures( + tmp_path, + capsys, +): + trajectory_file 
= tmp_path / "traj.xyz" + _write_sample_xyz(trajectory_file) + frame_dir = tmp_path / "frames" + frame_dir.mkdir() + (frame_dir / "frame_0001.xyz").write_text( + "2\n" + "i = 1, time = 50.0, E = -1.0\n" + "H 0.0 0.1 0.0\n" + "O 1.0 0.1 0.0\n", + encoding="utf-8", + ) + (frame_dir / "frame_0002.xyz").write_text( + "2\n" + "i = 1, time = 50.0, E = -1.0\n" + "H 0.0 0.1 0.0\n" + "O 1.0 0.1 0.0\n", + encoding="utf-8", + ) + + exit_code = mdtrajectory_main( + [ + "validate-export", + str(trajectory_file), + str(frame_dir), + ] + ) + + captured = capsys.readouterr() + + assert exit_code == 1 + assert "Export validation failed." in captured.out + assert "- filename_header_offset: 1" in captured.out + assert "- coordinate_mismatch: 1" in captured.out + assert "- duplicate_export_header_index: 1" in captured.out + + +def test_mdtrajectory_cli_validate_export_fails_empty_frame_directory( + tmp_path, + capsys, +): + trajectory_file = tmp_path / "traj.xyz" + _write_sample_xyz(trajectory_file) + frame_dir = tmp_path / "empty_frames" + frame_dir.mkdir() + + exit_code = mdtrajectory_main( + [ + "validate-export", + str(trajectory_file), + str(frame_dir), + ] + ) + + captured = capsys.readouterr() + + assert exit_code == 1 + assert "- no_exported_xyz_files: 1" in captured.out + + +def test_mdtrajectory_cli_suggest_cutoff_defaults_to_window_two( + tmp_path, + capsys, +): + trajectory_file = tmp_path / "traj.xyz" + energy_file = tmp_path / "traj.ener" + _write_sample_xyz(trajectory_file) + _write_sample_ener(energy_file) + + exit_code = mdtrajectory_main( + [ + "suggest-cutoff", + str(trajectory_file), + "--energy-file", + str(energy_file), + "--temp-target-k", + "300.0", + ] + ) + + captured = capsys.readouterr() + + assert exit_code == 0 + assert "Suggested cutoff: 50.000 fs" in captured.out + assert "Window: 2" in captured.out + + def test_saxshell_cli_forwards_to_mdtrajectory_subcommand( tmp_path, capsys, diff --git a/tests/test_mdtrajectory_cluster.py 
b/tests/test_mdtrajectory_cluster.py index 30d6bea..5365251 100644 --- a/tests/test_mdtrajectory_cluster.py +++ b/tests/test_mdtrajectory_cluster.py @@ -425,6 +425,135 @@ def test_smart_solvation_shells_keep_union_across_contiguous_pdb_frames( } == {10, 11} +def test_smart_solvation_shell_updates_scale_linearly_for_contiguous_runs( + tmp_path, + monkeypatch, +): + frames_dir = tmp_path / "splitpdb0001" + frames_dir.mkdir() + frame_count = 12 + for frame_index in range(frame_count): + if frame_index % 2 == 0: + residue10_y = 1.0 + residue11_y = 4.0 + else: + residue10_y = 4.0 + residue11_y = 1.0 + (frames_dir / f"frame_{frame_index:04d}.pdb").write_text( + "".join( + _smart_shell_frame_lines( + residue10_y=residue10_y, + residue11_y=residue11_y, + ) + ) + ) + + touched_frame_refs = 0 + original_apply = ( + cluster_module.ExtractedFrameFolderClusterAnalyzer._apply_smart_shell_union_to_run + ) + + def counted_apply(self, run, frame_entries, *, elements): + nonlocal touched_frame_refs + touched_frame_refs += len(run.frame_refs) + return original_apply( + self, + run, + frame_entries, + elements=elements, + ) + + monkeypatch.setattr( + cluster_module.ExtractedFrameFolderClusterAnalyzer, + "_apply_smart_shell_union_to_run", + counted_apply, + ) + analyzer = ExtractedFrameFolderClusterAnalyzer( + frames_dir=frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoffs_def=PAIR_CUTOFFS, + smart_solvation_shells=True, + ) + + export = analyzer.export_cluster_pdbs( + tmp_path / "clusters_from_folder", + shell_levels=(1,), + include_shell_levels=(0, 1), + ) + + assert touched_frame_refs == frame_count + assert len(export.frame_results) == frame_count + for path in export.written_files: + structure = PDBStructure.from_file(path) + assert { + atom.residue_number + for atom in structure.atoms + if atom.residue_name == "WAT" + } == {10, 11} + + +def test_smart_solvation_shell_resume_preserves_deferred_unions(tmp_path): + frames_dir = tmp_path / "splitpdb0001" + 
frames_dir.mkdir() + for frame_index in range(4): + if frame_index % 2 == 0: + residue10_y = 1.0 + residue11_y = 4.0 + else: + residue10_y = 4.0 + residue11_y = 1.0 + (frames_dir / f"frame_{frame_index:04d}.pdb").write_text( + "".join( + _smart_shell_frame_lines( + residue10_y=residue10_y, + residue11_y=residue11_y, + ) + ) + ) + output_dir = tmp_path / "clusters_from_folder" + analyzer = ExtractedFrameFolderClusterAnalyzer( + frames_dir=frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoffs_def=PAIR_CUTOFFS, + smart_solvation_shells=True, + ) + + def stop_after_second_frame(processed, total, frame_label): + if processed >= 2 and frame_label != "resume": + raise RuntimeError("stop after two frames") + + with pytest.raises(RuntimeError, match="stop after two frames"): + analyzer.export_cluster_files( + output_dir, + shell_levels=(1,), + include_shell_levels=(0, 1), + progress_callback=stop_after_second_frame, + ) + + interrupted = json.loads( + (output_dir / "cluster_extraction_metadata.json").read_text() + ) + assert interrupted["state"] == "failed" + assert interrupted["progress"]["completed_frames"] == 2 + + resumed = analyzer.export_cluster_files( + output_dir, + shell_levels=(1,), + include_shell_levels=(0, 1), + ) + + assert resumed.resumed + assert resumed.previously_completed_frames == 2 + assert resumed.newly_processed_frames == 2 + for path in resumed.written_files: + structure = PDBStructure.from_file(path) + assert { + atom.residue_number + for atom in structure.atoms + if atom.residue_name == "WAT" + } == {10, 11} + + def test_legacy_solvation_shells_preserve_per_frame_pdb_cutoffs(tmp_path): frames_dir = tmp_path / "splitpdb0001" frames_dir.mkdir() diff --git a/tests/test_mdtrajectory_manager.py b/tests/test_mdtrajectory_manager.py index 32ce947..93482b9 100644 --- a/tests/test_mdtrajectory_manager.py +++ b/tests/test_mdtrajectory_manager.py @@ -1,5 +1,7 @@ import pytest +from saxshell.mdtrajectory.frame.base import FrameRecord +from 
saxshell.mdtrajectory.frame.exporters import export_xyz_frames from saxshell.mdtrajectory.frame.manager import TrajectoryManager @@ -238,3 +240,170 @@ def test_export_frames_keeps_exact_half_fs_cutoff_boundary(tmp_path): assert preview.first_time_fs == pytest.approx(497.5) assert written_files[0].name == "frame_0995.xyz" assert written_files[-1].name == "frame_1004.xyz" + + +def test_cp2k_restart_overlap_frames_keep_later_source_index_occurrence( + tmp_path, +): + trajectory_file = tmp_path / "traj.xyz" + trajectory_file.write_text( + "1\n" + "i = 0, time = 0.0, E = -1.0\n" + "H 0.0 0.0 0.0\n" + "1\n" + "i = 1, time = 0.5, E = -1.0\n" + "H 1.0 0.0 0.0\n" + "1\n" + "i = 2, time = 1.0, E = -1.0\n" + "H 2.0 0.0 0.0\n" + "1\n" + "i = 1, time = 0.5, E = -1.0\n" + "H 9.0 0.0 0.0\n" + "1\n" + "i = 2, time = 1.0, E = -1.0\n" + "H 9.0 0.0 0.0\n" + "1\n" + "i = 3, time = 1.5, E = -1.0\n" + "H 3.0 0.0 0.0\n", + encoding="utf-8", + ) + + manager = TrajectoryManager(input_file=trajectory_file) + summary = manager.inspect() + preview = manager.preview_selection(min_time_fs=0.5) + written_files = manager.export_frames( + output_dir=tmp_path / "frames", + min_time_fs=0.5, + ) + + assert summary["n_frames"] == 4 + assert summary["raw_frames"] == 6 + assert summary["duplicate_source_frames"] == 2 + assert preview.total_frames == 4 + assert preview.selected_frames == 3 + assert preview.first_frame_index == 1 + assert preview.last_frame_index == 3 + assert [path.name for path in written_files] == [ + "frame_0001.xyz", + "frame_0002.xyz", + "frame_0003.xyz", + ] + assert "H 9.0" in written_files[0].read_text() + assert "H 9.0" in written_files[1].read_text() + assert "H 3.0" in written_files[2].read_text() + + +def test_cp2k_restart_overlap_frames_can_include_duplicate_occurrences( + tmp_path, +): + trajectory_file = tmp_path / "traj.xyz" + trajectory_file.write_text( + "1\n" + "i = 0, time = 0.0, E = -1.0\n" + "H 0.0 0.0 0.0\n" + "1\n" + "i = 1, time = 0.5, E = -1.0\n" + "H 1.0 0.0 
0.0\n" + "1\n" + "i = 2, time = 1.0, E = -1.0\n" + "H 2.0 0.0 0.0\n" + "1\n" + "i = 1, time = 0.5, E = -1.0\n" + "H 9.1 0.0 0.0\n" + "1\n" + "i = 2, time = 1.0, E = -1.0\n" + "H 9.2 0.0 0.0\n" + "1\n" + "i = 3, time = 1.5, E = -1.0\n" + "H 3.0 0.0 0.0\n", + encoding="utf-8", + ) + + manager = TrajectoryManager( + input_file=trajectory_file, + include_restart_duplicates=True, + ) + summary = manager.inspect() + preview = manager.preview_selection(min_time_fs=0.5) + written_files = manager.export_frames( + output_dir=tmp_path / "frames", + min_time_fs=0.5, + ) + + assert summary["n_frames"] == 6 + assert summary["raw_frames"] == 6 + assert summary["duplicate_source_frames"] == 2 + assert summary["include_restart_duplicates"] is True + assert preview.total_frames == 6 + assert preview.selected_frames == 5 + assert [path.name for path in written_files] == [ + "frame_0001_duplicate0001.xyz", + "frame_0002_duplicate0001.xyz", + "frame_0001.xyz", + "frame_0002.xyz", + "frame_0003.xyz", + ] + assert ( + "H 1.0" + in (tmp_path / "frames" / "frame_0001_duplicate0001.xyz").read_text() + ) + assert ( + "H 9.1" in (tmp_path / "frames" / "frame_0001.xyz").read_text() + ) + assert ( + "H 2.0" + in (tmp_path / "frames" / "frame_0002_duplicate0001.xyz").read_text() + ) + assert ( + "H 9.2" in (tmp_path / "frames" / "frame_0002.xyz").read_text() + ) + + +def test_export_rejects_xyz_header_index_mismatch(tmp_path): + frame = FrameRecord( + frame_index=2559, + file_type="xyz", + atom_count=1, + lines=[ + "i = 2501, time = 1250.5, E = -1.0\n", + "H 0.0 0.0 0.0\n", + ], + time_fs=1250.5, + ) + + with pytest.raises( + ValueError, + match="header reports i = 2501", + ): + export_xyz_frames([frame], tmp_path / "frames") + + +def test_export_rejects_duplicate_xyz_output_names(tmp_path): + frames = [ + FrameRecord( + frame_index=7, + file_type="xyz", + atom_count=1, + lines=[ + "i = 7, time = 3.5, E = -1.0\n", + "H 0.0 0.0 0.0\n", + ], + time_fs=3.5, + ), + FrameRecord( + frame_index=7, + 
file_type="xyz", + atom_count=1, + lines=[ + "i = 7, time = 3.5, E = -1.0\n", + "H 1.0 0.0 0.0\n", + ], + time_fs=3.5, + ), + ] + + with pytest.raises( + ValueError, + match="same output file", + ): + export_xyz_frames(frames, tmp_path / "frames") diff --git a/tests/test_mdtrajectory_ui.py b/tests/test_mdtrajectory_ui.py index 2944ad3..f9d61c2 100644 --- a/tests/test_mdtrajectory_ui.py +++ b/tests/test_mdtrajectory_ui.py @@ -4,10 +4,19 @@ import numpy as np import pytest from matplotlib.backends.backend_qtagg import NavigationToolbar2QT +from PySide6.QtCore import Qt from PySide6.QtWidgets import QApplication +import saxshell.mdtrajectory.ui.batch_queue_window as md_batch_queue_module from saxshell.mdtrajectory.frame.cp2k_ener import CP2KEnergyData from saxshell.mdtrajectory.frame.manager import FrameSelectionPreview +from saxshell.mdtrajectory.ui.batch_queue_window import ( + DEFAULT_TIME_CUTOFF_FS, + MDTrajectoryBatchJob, + MDTrajectoryBatchQueueWindow, + MDTrajectoryBatchResult, + MDTrajectoryBatchWorker, +) from saxshell.mdtrajectory.ui.cutoff_panel import CutoffPanel from saxshell.mdtrajectory.ui.export_panel import ExportPanel from saxshell.mdtrajectory.ui.main_window import ( @@ -26,6 +35,83 @@ def qapp(): yield app +def _write_batch_xyz(path: Path) -> None: + path.write_text( + "1\n" + "i = 0, time = 0.0, E = -1.0\n" + "Pb 0.0 0.0 0.0\n" + "1\n" + "i = 1, time = 500.0, E = -1.0\n" + "Pb 0.1 0.0 0.0\n" + "1\n" + "i = 2, time = 1500.0, E = -1.0\n" + "Pb 0.2 0.0 0.0\n", + encoding="utf-8", + ) + + +def _write_batch_restart_overlap_xyz(path: Path) -> None: + path.write_text( + "1\n" + "i = 0, time = 0.0, E = -1.0\n" + "Pb 0.0 0.0 0.0\n" + "1\n" + "i = 1, time = 500.0, E = -1.0\n" + "Pb 0.1 0.0 0.0\n" + "1\n" + "i = 2, time = 1500.0, E = -1.0\n" + "Pb 0.2 0.0 0.0\n" + "1\n" + "i = 1, time = 500.0, E = -1.0\n" + "Pb 9.1 0.0 0.0\n" + "1\n" + "i = 2, time = 1500.0, E = -1.0\n" + "Pb 9.2 0.0 0.0\n" + "1\n" + "i = 3, time = 2000.0, E = -1.0\n" + "Pb 0.3 0.0 0.0\n", + 
encoding="utf-8", + ) + + +def _write_batch_energy(path: Path) -> None: + path.write_text( + "# step time kinetic temperature potential\n" + "1 0.0 1.0 300.0 -10.0\n" + "2 500.0 1.0 301.0 -10.0\n" + "3 1500.0 1.0 300.5 -10.0\n", + encoding="utf-8", + ) + + +def _create_mdtrajectory_batch_project( + tmp_path: Path, + name: str, +) -> tuple[Path, Path, Path]: + manager = SAXSProjectManager() + project_dir = tmp_path / name + settings = manager.create_project(project_dir) + trajectory_file = project_dir / "traj.xyz" + energy_file = project_dir / "traj.ener" + _write_batch_xyz(trajectory_file) + _write_batch_energy(energy_file) + settings.trajectory_file = str(trajectory_file) + settings.energy_file = str(energy_file) + manager.save_project(settings) + return project_dir, trajectory_file, energy_file + + +def _create_mdtrajectory_overlap_batch_project( + tmp_path: Path, + name: str, +) -> tuple[Path, Path, Path]: + project_dir, trajectory_file, energy_file = ( + _create_mdtrajectory_batch_project(tmp_path, name) + ) + _write_batch_restart_overlap_xyz(trajectory_file) + return project_dir, trajectory_file, energy_file + + def test_export_panel_suggest_output_dir_keeps_manual_override( qapp, tmp_path, @@ -65,6 +151,17 @@ def test_export_panel_post_cutoff_stride_controls_follow_cutoff_toggle(qapp): assert panel.get_post_cutoff_stride() == 3 +def test_export_panel_restart_duplicate_option_defaults_off(qapp): + del qapp + panel = ExportPanel() + + assert not panel.include_restart_duplicates() + + panel.include_restart_duplicates_box.setChecked(True) + + assert panel.include_restart_duplicates() + + def test_export_panel_progress_methods_update_ui(qapp): del qapp panel = ExportPanel() @@ -119,6 +216,7 @@ def test_cutoff_panel_load_energy_draws_target_temperature_line( assert horizontal_lines assert panel.temp_target_spin.toolTip() assert panel.cutoff_spin.toolTip() + assert panel.window_spin.value() == 2 def test_cutoff_panel_uses_matplotlib_navigation_toolbar( @@ -435,3 
+533,383 @@ def fake_start_export_worker(**kwargs): window.export_panel.log_box.toPlainText() ) window.close() + + +def test_mdtrajectory_batch_queue_prefills_current_project_defaults( + qapp, + tmp_path, +): + del qapp + project_dir, trajectory_file, energy_file = ( + _create_mdtrajectory_batch_project(tmp_path, "project_a") + ) + + window = MDTrajectoryBatchQueueWindow(initial_project_dir=project_dir) + + assert window.queue_list.count() == 1 + widget = next(iter(window._widgets_by_id.values())) + assert widget.project_dir_edit.text() == str(project_dir.resolve()) + assert widget.trajectory_file_edit.text() == str(trajectory_file.resolve()) + assert widget.energy_file_edit.text() == str(energy_file.resolve()) + assert widget.output_dir_edit.text() == "" + assert widget.cutoff_spin.value() == pytest.approx(DEFAULT_TIME_CUTOFF_FS) + assert not widget.include_restart_duplicates_box.isChecked() + window.close() + + +def test_mdtrajectory_batch_queue_exposes_and_uses_editable_output_folder( + qapp, + tmp_path, +): + del qapp + project_dir, _trajectory_file, _energy_file = ( + _create_mdtrajectory_batch_project(tmp_path, "project_a") + ) + custom_output_dir = tmp_path / "custom_splitxyz" + window = MDTrajectoryBatchQueueWindow(initial_project_dir=project_dir) + widget = next(iter(window._widgets_by_id.values())) + + widget.preview_selection() + + assert widget.output_dir_edit.text().endswith("splitxyz_f2_t1500fs") + + widget.output_dir_edit.setText(str(custom_output_dir)) + jobs = [job for _item_id, job in window.queue_jobs_in_order()] + + assert jobs[0].output_dir == custom_output_dir.resolve() + window.close() + + +def test_mdtrajectory_batch_queue_exposes_restart_duplicate_option( + qapp, + tmp_path, +): + del qapp + project_dir, _trajectory_file, _energy_file = ( + _create_mdtrajectory_batch_project(tmp_path, "project_a") + ) + window = MDTrajectoryBatchQueueWindow(initial_project_dir=project_dir) + widget = next(iter(window._widgets_by_id.values())) + + 
widget.include_restart_duplicates_box.setChecked(True) + jobs = [job for _item_id, job in window.queue_jobs_in_order()] + + assert jobs[0].include_restart_duplicates + window.close() + + +def test_mdtrajectory_batch_queue_adds_multiple_selected_projects( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_a, trajectory_a, energy_a = _create_mdtrajectory_batch_project( + tmp_path, + "project_a", + ) + project_b, trajectory_b, energy_b = _create_mdtrajectory_batch_project( + tmp_path, + "project_b", + ) + monkeypatch.setattr( + md_batch_queue_module, + "_choose_existing_directories", + lambda *_args, **_kwargs: (project_a, project_b), + ) + window = MDTrajectoryBatchQueueWindow() + + window._choose_projects_to_add() + + assert window.queue_list.count() == 2 + jobs = [job for _item_id, job in window.queue_jobs_in_order()] + assert [job.project_dir for job in jobs] == [ + project_a.resolve(), + project_b.resolve(), + ] + assert [job.trajectory_file for job in jobs] == [ + trajectory_a.resolve(), + trajectory_b.resolve(), + ] + assert [job.energy_file for job in jobs] == [ + energy_a.resolve(), + energy_b.resolve(), + ] + assert [job.cutoff_fs for job in jobs] == [ + DEFAULT_TIME_CUTOFF_FS, + DEFAULT_TIME_CUTOFF_FS, + ] + window.close() + + +def test_mdtrajectory_batch_worker_exports_and_registers_each_project( + qapp, + tmp_path, +): + del qapp + manager = SAXSProjectManager() + project_a, trajectory_a, energy_a = _create_mdtrajectory_batch_project( + tmp_path, + "project_a", + ) + project_b, trajectory_b, energy_b = _create_mdtrajectory_batch_project( + tmp_path, + "project_b", + ) + worker = MDTrajectoryBatchWorker( + [ + ( + "item-a", + MDTrajectoryBatchJob( + project_dir=project_a, + trajectory_file=trajectory_a, + topology_file=None, + energy_file=energy_a, + cutoff_fs=DEFAULT_TIME_CUTOFF_FS, + ), + ), + ( + "item-b", + MDTrajectoryBatchJob( + project_dir=project_b, + trajectory_file=trajectory_b, + topology_file=None, + energy_file=energy_b, + 
cutoff_fs=DEFAULT_TIME_CUTOFF_FS, + ), + ), + ] + ) + finished: list[list[object]] = [] + failures: list[tuple[str, str]] = [] + worker.finished.connect(finished.append) + worker.failed.connect( + lambda item_id, message: failures.append((item_id, message)) + ) + + worker.run() + + assert failures == [] + assert len(finished) == 1 + results = finished[0] + assert len(results) == 2 + for result, project_dir, trajectory_file, energy_file in zip( + results, + (project_a, project_b), + (trajectory_a, trajectory_b), + (energy_a, energy_b), + strict=True, + ): + saved_settings = manager.load_project(project_dir) + assert saved_settings.resolved_frames_dir == ( + result.output_dir.resolve() + ) + assert saved_settings.resolved_trajectory_file == ( + trajectory_file.resolve() + ) + assert saved_settings.resolved_energy_file == energy_file.resolve() + assert saved_settings.frames_dir_snapshot is not None + assert result.written_count == 1 + assert result.selected_frames == 1 + assert result.output_dir.name == "splitxyz_f2_t1500fs" + assert result.metadata_file is not None + assert result.metadata_file.is_file() + assert (result.output_dir / "frame_0002.xyz").is_file() + + +def test_mdtrajectory_batch_worker_honors_custom_output_folder( + qapp, + tmp_path, +): + del qapp + manager = SAXSProjectManager() + project_dir, trajectory_file, energy_file = ( + _create_mdtrajectory_batch_project(tmp_path, "project_custom_output") + ) + custom_output_dir = tmp_path / "queued_xyz_output" + worker = MDTrajectoryBatchWorker( + [ + ( + "item-custom", + MDTrajectoryBatchJob( + project_dir=project_dir, + trajectory_file=trajectory_file, + topology_file=None, + energy_file=energy_file, + output_dir=custom_output_dir, + cutoff_fs=DEFAULT_TIME_CUTOFF_FS, + ), + ) + ] + ) + finished: list[list[object]] = [] + failures: list[tuple[str, str]] = [] + worker.finished.connect(finished.append) + worker.failed.connect( + lambda item_id, message: failures.append((item_id, message)) + ) + + worker.run() 
+ + assert failures == [] + result = finished[0][0] + assert result.output_dir == custom_output_dir + assert (custom_output_dir / "frame_0002.xyz").is_file() + saved_settings = manager.load_project(project_dir) + assert saved_settings.resolved_frames_dir == custom_output_dir.resolve() + + +def test_mdtrajectory_batch_worker_uses_source_indices_without_validation_pass( + qapp, + tmp_path, + monkeypatch, +): + del qapp + manager = SAXSProjectManager() + project_dir, trajectory_file, energy_file = ( + _create_mdtrajectory_overlap_batch_project(tmp_path, "project_overlap") + ) + + def fail_validation(*_args, **_kwargs): + raise AssertionError("Batch queue must not run export assertions") + + monkeypatch.setattr( + md_batch_queue_module.MDTrajectoryWorkflow, + "validate_export", + fail_validation, + ) + worker = MDTrajectoryBatchWorker( + [ + ( + "item-overlap", + MDTrajectoryBatchJob( + project_dir=project_dir, + trajectory_file=trajectory_file, + topology_file=None, + energy_file=energy_file, + cutoff_fs=DEFAULT_TIME_CUTOFF_FS, + ), + ) + ] + ) + finished: list[list[object]] = [] + failures: list[tuple[str, str]] = [] + log_messages: list[str] = [] + worker.finished.connect(finished.append) + worker.failed.connect( + lambda item_id, message: failures.append((item_id, message)) + ) + worker.log.connect(log_messages.append) + + worker.run() + + assert failures == [] + assert len(finished) == 1 + result = finished[0][0] + assert result.written_count == 2 + assert result.selected_frames == 2 + assert result.output_dir.name == "splitxyz_f2_t1500fs" + assert (result.output_dir / "frame_0002.xyz").is_file() + assert (result.output_dir / "frame_0003.xyz").is_file() + assert not (result.output_dir / "frame_0004.xyz").exists() + assert "i = 2" in (result.output_dir / "frame_0002.xyz").read_text() + assert "i = 3" in (result.output_dir / "frame_0003.xyz").read_text() + assert "9.2" in (result.output_dir / "frame_0002.xyz").read_text() + assert any( + "Skipped 2 duplicate source 
frame(s)" in message + for message in log_messages + ) + saved_settings = manager.load_project(project_dir) + assert saved_settings.resolved_frames_dir == result.output_dir.resolve() + + +def test_mdtrajectory_batch_worker_can_include_restart_duplicates( + qapp, + tmp_path, +): + del qapp + project_dir, trajectory_file, energy_file = ( + _create_mdtrajectory_overlap_batch_project( + tmp_path, + "project_overlap_include", + ) + ) + worker = MDTrajectoryBatchWorker( + [ + ( + "item-overlap", + MDTrajectoryBatchJob( + project_dir=project_dir, + trajectory_file=trajectory_file, + topology_file=None, + energy_file=energy_file, + cutoff_fs=DEFAULT_TIME_CUTOFF_FS, + include_restart_duplicates=True, + ), + ) + ] + ) + finished: list[list[object]] = [] + failures: list[tuple[str, str]] = [] + log_messages: list[str] = [] + worker.finished.connect(finished.append) + worker.failed.connect( + lambda item_id, message: failures.append((item_id, message)) + ) + worker.log.connect(log_messages.append) + + worker.run() + + assert failures == [] + result = finished[0][0] + assert result.include_restart_duplicates + assert result.written_count == 3 + assert result.selected_frames == 3 + assert (result.output_dir / "frame_0002_duplicate0001.xyz").is_file() + assert (result.output_dir / "frame_0002.xyz").is_file() + assert (result.output_dir / "frame_0003.xyz").is_file() + assert ( + "0.2" + in (result.output_dir / "frame_0002_duplicate0001.xyz").read_text() + ) + assert "9.2" in (result.output_dir / "frame_0002.xyz").read_text() + assert any( + "Included 2 duplicate source frame(s)" in message + for message in log_messages + ) + + +def test_mdtrajectory_batch_queue_emits_registered_frames_folder( + qapp, + tmp_path, +): + del qapp + project_dir, _trajectory_file, _energy_file = ( + _create_mdtrajectory_batch_project(tmp_path, "project_a") + ) + output_dir = tmp_path / "splitxyz_f2_t1500fs" + output_dir.mkdir() + window = MDTrajectoryBatchQueueWindow(initial_project_dir=project_dir) + 
updates: list[dict[str, object]] = [] + window.project_paths_registered.connect(updates.append) + item_id = str(window.queue_list.item(0).data(Qt.ItemDataRole.UserRole)) + + window._on_item_finished( + item_id, + MDTrajectoryBatchResult( + project_dir=project_dir.resolve(), + output_dir=output_dir.resolve(), + written_count=1, + selected_frames=1, + cutoff_fs=DEFAULT_TIME_CUTOFF_FS, + metadata_file=None, + ), + ) + + assert updates == [ + { + "project_dir": project_dir.resolve(), + "frames_dir": output_dir.resolve(), + } + ] + window.close() diff --git a/tests/test_pdfsetup.py b/tests/test_pdfsetup.py index e648127..dc7ef8d 100644 --- a/tests/test_pdfsetup.py +++ b/tests/test_pdfsetup.py @@ -7,19 +7,38 @@ import numpy as np import pytest -from PySide6.QtWidgets import QApplication +from PySide6.QtCore import Qt +from PySide6.QtWidgets import QAbstractItemView, QApplication, QWidget +import saxshell.pdf.debyer.ui.batch_queue_window as batch_queue_module +import saxshell.pdf.debyer.workflow as debyer_workflow import saxshell.pdfsetup as pdfsetup_module from saxshell import saxshell as saxshell_module -from saxshell.pdf.debyer.ui.main_window import DebyerPDFMainWindow +from saxshell.pdf.debyer.ui.batch_queue_window import ( + DebyerPDFBatchItem, + DebyerPDFBatchItemWidget, + DebyerPDFBatchQueueWindow, + DebyerPDFBatchWorker, + DebyerPDFExistingPartialsJob, + DebyerPDFExistingPartialsWorker, +) +from saxshell.pdf.debyer.ui.main_window import ( + DebyerPDFMainWindow, + DebyerPDFWorker, +) from saxshell.pdf.debyer.workflow import ( + DebyerPDFCalculation, DebyerPDFSettings, DebyerPDFWorkflow, calculate_number_density, + compute_experimental_fit_metrics, convert_distribution_values, + fit_coordination_peak_from_r, list_saved_debyer_calculations, load_debyer_calculation, + parse_debyer_output_file, ) +from saxshell.saxs.project_manager import SAXSProjectManager @pytest.fixture(scope="module") @@ -31,10 +50,11 @@ def qapp(): yield app -def _write_fake_debyer(path: Path) -> 
Path: +def _write_fake_debyer(path: Path, *, sleep_seconds: float = 0.0) -> Path: script = """#!/usr/bin/env python3 import re import sys +import time from pathlib import Path import numpy as np @@ -63,6 +83,9 @@ def main(argv): digits = re.findall(r"(\\d+)", input_file.stem) frame_index = int(digits[-1]) if digits else 0 scale = 1.0 + frame_index + sleep_seconds = __SLEEP_SECONDS__ + if sleep_seconds > 0.0: + time.sleep(sleep_seconds) r_values = np.array([0.5, 1.0, 1.5], dtype=float) partial_pbpb = scale * np.array([0.10, 0.12, 0.14], dtype=float) partial_pbi = scale * np.array([0.55, 0.60, 0.65], dtype=float) @@ -92,25 +115,82 @@ def main(argv): if __name__ == "__main__": raise SystemExit(main(sys.argv)) """ + script = script.replace("__SLEEP_SECONDS__", repr(float(sleep_seconds))) path.write_text(script, encoding="utf-8") path.chmod(path.stat().st_mode | stat.S_IXUSR) return path -def _build_frames_dir(tmp_path: Path) -> Path: +def _build_frames_dir( + tmp_path: Path, + *, + frame_count: int = 2, + xyz: str | None = None, +) -> Path: frames_dir = tmp_path / "splitxyz_f0_t0fs" - frames_dir.mkdir() - xyz = ( - "3\n" "frame\n" "Pb 0.0 0.0 0.0\n" "I 2.0 0.0 0.0\n" "I 0.0 2.0 0.0\n" + frames_dir.mkdir(parents=True) + xyz_text = ( + xyz + or "3\n" + "frame\n" + "Pb 0.0 0.0 0.0\n" + "I 2.0 0.0 0.0\n" + "I 0.0 2.0 0.0\n" ) - for index in range(2): + for index in range(frame_count): (frames_dir / f"frame_{index:04d}.xyz").write_text( - xyz, + xyz_text, + encoding="utf-8", + ) + return frames_dir + + +def _pdb_atom_line( + serial: int, + atom_name: str, + residue_name: str, + residue_id: int, + element: str, + x: float, + y: float, + z: float, +) -> str: + return ( + f"ATOM {serial:5d} {atom_name:<4s} {residue_name:>3s} " + f"A{residue_id:4d} {x:8.3f}{y:8.3f}{z:8.3f}" + f" 1.00 0.00 {element:>2s}\n" + ) + + +def _build_pdb_frames_dir(tmp_path: Path, *, frame_count: int = 2) -> Path: + frames_dir = tmp_path / "splitpdb_f0_t0fs" + frames_dir.mkdir(parents=True) + pdb_text = 
"".join( + [ + _pdb_atom_line(1, "PB", "PER", 1, "Pb", 0.0, 0.0, 0.0), + _pdb_atom_line(2, "I1", "PER", 1, "I", 2.0, 0.0, 0.0), + _pdb_atom_line(3, "O1", "DMS", 2, "O", 0.0, 2.0, 0.0), + _pdb_atom_line(4, "C1", "DMS", 2, "C", 0.0, 0.0, 2.0), + "END\n", + ] + ) + for index in range(frame_count): + (frames_dir / f"frame_{index:04d}.pdb").write_text( + pdb_text, encoding="utf-8", ) return frames_dir +def _write_pbc_source_file(frames_dir: Path, token: str = "12x10x8") -> Path: + source = frames_dir.parent / f"sample_pbc_{token}-pos-1.xyz" + source.write_text( + "1\nsource\nPb 0.0 0.0 0.0\n", + encoding="utf-8", + ) + return source + + def test_convert_distribution_values_from_pdf_mode(): r_values = np.array([0.5, 1.0, 1.5], dtype=float) g_values = np.array([1.0, 1.1, 1.2], dtype=float) @@ -135,6 +215,313 @@ def test_convert_distribution_values_from_pdf_mode(): assert np.allclose(converted_g, expected_g) +def test_infers_default_solute_elements_from_frame_elements(): + assert debyer_workflow.infer_default_solute_elements( + {"Pb": 1, "I": 2} + ) == ("Pb", "I") + assert debyer_workflow.infer_default_solute_elements( + {"Cs": 1, "Pb": 1, "I": 3} + ) == ("Cs", "Pb", "I") + assert ( + debyer_workflow.infer_default_solute_elements({"Na": 1, "Cl": 1}) == () + ) + + +def test_pdfsetup_rejects_pdb_frame_folders(tmp_path): + frames_dir = _build_pdb_frames_dir(tmp_path) + + with pytest.raises(ValueError, match="require XYZ frame files"): + debyer_workflow.inspect_frames_dir(frames_dir) + + +def test_pdf_batch_settings_module_inspects_xyz_defaults(qapp, tmp_path): + project_dir = tmp_path / "project" + frames_dir = _build_frames_dir( + tmp_path, + xyz=( + "3\n" + "frame\n" + "Cs 0.0 0.0 0.0\n" + "Pb 2.0 0.0 0.0\n" + "I 0.0 2.0 0.0\n" + ), + ) + _write_pbc_source_file(frames_dir) + widget = DebyerPDFBatchItemWidget( + DebyerPDFBatchItem( + item_id="batch-item", + project_dir=project_dir, + frames_dir=frames_dir, + ) + ) + widget.to_edit.setText("25") + + widget.inspect_frames() + 
settings = widget.settings() + + assert settings.project_dir == project_dir.resolve() + assert settings.frames_dir == frames_dir.resolve() + assert settings.filename_prefix == frames_dir.name + assert settings.atom_count == 3 + assert settings.box_dimensions == pytest.approx((12.0, 10.0, 8.0)) + assert settings.to_value == pytest.approx(4.0) + assert settings.solute_elements == ("Cs", "Pb", "I") + assert "Detected XYZ frames" in widget.inspection_summary_label.text() + + +def test_pdf_batch_queue_window_keeps_collapsible_reorderable_items( + qapp, + tmp_path, +): + first_frames_dir = _build_frames_dir(tmp_path / "first") + second_frames_dir = _build_frames_dir(tmp_path / "second") + window = DebyerPDFBatchQueueWindow() + + first = window.add_queue_item( + DebyerPDFBatchItem( + item_id="first", + project_dir=tmp_path / "first_project", + frames_dir=first_frames_dir, + filename_prefix="first_pdf", + box_dimensions=(10.0, 10.0, 10.0), + atom_count=3, + ) + ) + second = window.add_queue_item( + DebyerPDFBatchItem( + item_id="second", + project_dir=tmp_path / "second_project", + frames_dir=second_frames_dir, + filename_prefix="second_pdf", + box_dimensions=(11.0, 11.0, 11.0), + atom_count=3, + ) + ) + + assert ( + window.queue_list.dragDropMode() + == QAbstractItemView.DragDropMode.InternalMove + ) + assert first.settings_group.isHidden() + assert second.settings_group.isHidden() + first._set_settings_visible(True) + assert not first.settings_group.isHidden() + first._set_settings_visible(False) + assert first.settings_group.isHidden() + window.queue_list.setCurrentItem(window.queue_list.item(0)) + window._refresh_item_selection_styles() + assert first.header_frame.property("selected") is True + assert second.header_frame.property("selected") is False + assert [ + item_id for item_id, _settings in window.queue_settings_in_order() + ] == [ + "first", + "second", + ] + assert second.item().display_name() == "second_project" + + +def 
test_pdf_batch_queue_adds_multiple_selected_project_folders( + qapp, + tmp_path, + monkeypatch, +): + project_a = tmp_path / "project_a" + project_b = tmp_path / "project_b" + project_a.mkdir() + project_b.mkdir() + window = DebyerPDFBatchQueueWindow() + monkeypatch.setattr( + batch_queue_module, + "_choose_existing_directories", + lambda *_args, **_kwargs: (project_a.resolve(), project_b.resolve()), + ) + + window._choose_project_to_add() + + assert window.queue_list.count() == 2 + project_dirs = [] + for row in range(window.queue_list.count()): + item_id = str( + window.queue_list.item(row).data(Qt.ItemDataRole.UserRole) + ) + project_dirs.append(window._widgets_by_id[item_id].item().project_dir) + assert project_dirs == [project_a.resolve(), project_b.resolve()] + + +def test_pdf_batch_queue_prefills_project_debyer_defaults( + qapp, + tmp_path, + monkeypatch, +): + frames_dir = _build_frames_dir( + tmp_path / "frames", + xyz=( + "3\n" + "frame\n" + "Cs 0.0 0.0 0.0\n" + "Pb 2.0 0.0 0.0\n" + "I 0.0 2.0 0.0\n" + ), + ) + project_dir = tmp_path / "project_defaults" + manager = SAXSProjectManager() + settings = manager.create_project(project_dir) + settings.frames_dir = str(frames_dir) + project_file = manager.save_project(settings) + payload = json.loads(project_file.read_text(encoding="utf-8")) + payload["debyer_pdf_settings"] = { + "filename_prefix": "stored_pdf", + "mode": "PDF", + "from_value": 0.7, + "to_value": 3.5, + "step_value": 0.02, + "box_dimensions": [12.0, 10.0, 8.0], + "atom_count": 3, + "solute_elements": ["Cs", "Pb", "I"], + "parallel_jobs": 2, + } + project_file.write_text(json.dumps(payload), encoding="utf-8") + window = DebyerPDFBatchQueueWindow() + monkeypatch.setattr( + batch_queue_module, + "_choose_existing_directories", + lambda *_args, **_kwargs: (project_dir.resolve(),), + ) + + window._choose_project_to_add() + + assert window.queue_list.count() == 1 + item_id = str(window.queue_list.item(0).data(Qt.ItemDataRole.UserRole)) + widget = 
window._widgets_by_id[item_id] + queued_settings = widget.settings() + assert queued_settings.project_dir == project_dir.resolve() + assert queued_settings.frames_dir == frames_dir.resolve() + assert queued_settings.filename_prefix == "stored_pdf" + assert queued_settings.from_value == pytest.approx(0.7) + assert queued_settings.to_value == pytest.approx(3.5) + assert queued_settings.step_value == pytest.approx(0.02) + assert queued_settings.box_dimensions == pytest.approx((12.0, 10.0, 8.0)) + assert queued_settings.atom_count == 3 + assert queued_settings.solute_elements == ("Cs", "Pb", "I") + assert queued_settings.max_parallel_jobs == 2 + + +def test_pdf_batch_queue_append_mode_uses_project_solute_jobs( + qapp, + tmp_path, +): + project_dir = tmp_path / "project" + window = DebyerPDFBatchQueueWindow() + item_widget = window.add_queue_item( + DebyerPDFBatchItem( + item_id="existing", + project_dir=project_dir, + solute_elements=("Pb",), + ) + ) + append_index = window.queue_mode_combo.findData("append_grouped") + + window.queue_mode_combo.setCurrentIndex(append_index) + jobs = window.existing_partials_jobs_in_order() + + assert window.run_button.text() == "Append Grouped Partial Columns" + assert window.add_frames_button.isEnabled() is False + assert item_widget.frames_dir_edit.isEnabled() is False + assert item_widget.solute_elements_edit.isEnabled() is True + assert jobs == [ + ( + "existing", + DebyerPDFExistingPartialsJob( + project_dir=project_dir.resolve(), + solute_elements=("Pb",), + ), + ) + ] + + item_widget.solute_elements_edit.setText("Cs, Pb, I") + item_widget._on_editor_changed() + jobs = window.existing_partials_jobs_in_order() + + assert jobs[0][1].solute_elements == ("Cs", "Pb", "I") + + +def test_compute_experimental_fit_metrics_interpolates_model_gr(): + metrics = compute_experimental_fit_metrics( + model_r_values=np.array([0.0, 1.0, 2.0], dtype=float), + model_g_values=np.array([1.0, 2.0, 3.0], dtype=float), + 
experimental_r_values=np.array([0.5, 1.5], dtype=float), + experimental_g_values=np.array([1.5, 2.5], dtype=float), + ) + + assert metrics is not None + assert metrics.r_squared == pytest.approx(1.0) + assert metrics.rmse == pytest.approx(0.0) + assert metrics.mae == pytest.approx(0.0) + assert metrics.point_count == 2 + assert metrics.r_min == pytest.approx(0.5) + assert metrics.r_max == pytest.approx(1.5) + + +def test_fit_coordination_peak_from_r_recovers_gaussian_area(): + r_values = np.linspace(1.5, 3.5, 121) + expected_cn = 4.25 + center = 2.62 + sigma = 0.11 + baseline = 0.35 + 0.08 * (r_values - center) + r_values_distribution = baseline + ( + expected_cn + / (sigma * np.sqrt(2.0 * np.pi)) + * np.exp(-0.5 * ((r_values - center) / sigma) ** 2) + ) + + result = fit_coordination_peak_from_r( + r_values=r_values, + r_distribution_values=r_values_distribution, + r_min=2.2, + r_max=3.05, + initial_center=2.6, + initial_sigma=0.12, + ) + + assert result.coordination_number == pytest.approx(expected_cn, rel=0.03) + assert result.center == pytest.approx(center, abs=0.02) + assert result.sigma == pytest.approx(sigma, abs=0.02) + assert result.r_squared > 0.99 + + +def test_running_debyer_average_matches_batch_average_without_frame_cache(): + r_values = np.linspace(0.1, 6.0, 240, dtype=float) + batch_outputs = [] + running_average = debyer_workflow._RunningDebyerAverage() + + for index in range(250): + scale = 1.0 + float(index) + columns = { + "sum": scale * np.sin(r_values), + "Pb-I": scale * np.cos(r_values), + } + if index >= 25: + columns["I-I"] = scale * (r_values**2) + batch_outputs.append((r_values, columns)) + running_average.add_frame(r_values, columns) + + batch_r, batch_columns, batch_values = ( + debyer_workflow._average_frame_outputs(batch_outputs) + ) + running_r, running_columns, running_values = running_average.average() + raw_cache_bytes = sum( + r_array.nbytes + sum(values.nbytes for values in columns.values()) + for r_array, columns in 
batch_outputs + ) + + assert np.allclose(running_r, batch_r) + assert running_columns == batch_columns + for key in batch_columns: + assert np.allclose(running_values[key], batch_values[key]) + assert running_average.memory_bytes < raw_cache_bytes / 100 + + def test_debyer_workflow_averages_and_persists_calculation(tmp_path): fake_debyer = _write_fake_debyer(tmp_path / "debyer") frames_dir = _build_frames_dir(tmp_path) @@ -186,6 +573,24 @@ def test_debyer_workflow_averages_and_persists_calculation(tmp_path): assert "# processed_frames:" in averaged_text assert "# total_frames:" in averaged_text assert "# columns: sum" in averaged_text + _output_r, output_values = parse_debyer_output_file( + result.averaged_output_file + ) + assert "solute-solute" in output_values + assert "solute-solvent" in output_values + assert "solvent-solvent" in output_values + assert np.allclose( + output_values["solute-solute"], + result.partial_values["Pb-Pb"], + ) + assert np.allclose( + output_values["solute-solvent"], + result.partial_values["Pb-I"], + ) + assert np.allclose( + output_values["solvent-solvent"], + result.partial_values["I-I"], + ) summaries = list_saved_debyer_calculations(project_dir) assert len(summaries) == 1 @@ -193,12 +598,582 @@ def test_debyer_workflow_averages_and_persists_calculation(tmp_path): assert loaded.filename_prefix == "demo_pdf" assert np.allclose(loaded.total_values, result.total_values) assert loaded.solute_elements == ("Pb",) + assert sorted(loaded.partial_values) == ["I-I", "Pb-I", "Pb-Pb"] assert loaded.processed_frame_count == loaded.frame_count assert "Pb-I" in loaded.partial_peak_markers assert loaded.partial_peak_markers["Pb-I"] assert loaded.peak_finder_settings.max_peak_count >= 0 +def test_pdf_batch_worker_runs_projects_in_sequence(qapp, tmp_path): + fake_debyer = _write_fake_debyer(tmp_path / "debyer") + first_frames_dir = _build_frames_dir(tmp_path / "first") + second_frames_dir = _build_frames_dir(tmp_path / "second") + entries = [ + ( + 
"first", + DebyerPDFSettings( + project_dir=tmp_path / "first_project", + frames_dir=first_frames_dir, + filename_prefix="first_pdf", + mode="PDF", + from_value=0.5, + to_value=5.0, + step_value=0.01, + box_dimensions=(10.0, 10.0, 10.0), + atom_count=3, + solute_elements=("Pb", "I"), + ), + ), + ( + "second", + DebyerPDFSettings( + project_dir=tmp_path / "second_project", + frames_dir=second_frames_dir, + filename_prefix="second_pdf", + mode="PDF", + from_value=0.5, + to_value=5.0, + step_value=0.01, + box_dimensions=(11.0, 11.0, 11.0), + atom_count=3, + solute_elements=("Pb", "I"), + ), + ), + ] + worker = DebyerPDFBatchWorker(entries, debyer_executable=fake_debyer) + started_items: list[str] = [] + finished_results: list[DebyerPDFCalculation] = [] + worker.item_started.connect( + lambda item_id, _index, _total: started_items.append(item_id) + ) + worker.finished.connect(lambda results: finished_results.extend(results)) + + worker.run() + + assert started_items == ["first", "second"] + assert [result.filename_prefix for result in finished_results] == [ + "first_pdf", + "second_pdf", + ] + assert all( + result.averaged_output_file.is_file() for result in finished_results + ) + assert all(not result.is_partial_average for result in finished_results) + for (_item_id, settings), result in zip(entries, finished_results): + summaries = list_saved_debyer_calculations(settings.project_dir) + assert len(summaries) == 1 + assert summaries[0].filename_prefix == result.filename_prefix + loaded = load_debyer_calculation(summaries[0].calculation_dir) + assert loaded.project_dir == settings.project_dir.resolve() + assert loaded.frames_dir == settings.frames_dir.resolve() + assert loaded.filename_prefix == result.filename_prefix + assert loaded.mode == result.mode + assert loaded.frame_count == result.frame_count + assert loaded.processed_frame_count == loaded.frame_count + assert np.allclose(loaded.total_values, result.total_values) + assert loaded.solute_elements == ("Pb", "I") 
+ _output_r, output_values = parse_debyer_output_file( + loaded.averaged_output_file + ) + assert "solute-solute" in output_values + assert "solute-solvent" not in output_values + assert "solvent-solvent" not in output_values + assert np.allclose( + output_values["solute-solute"], + sum(loaded.partial_values.values()), + ) + assert (loaded.calculation_dir / "calculation.json").is_file() + assert loaded.averaged_output_file.is_file() + + +def test_pdf_existing_partials_worker_updates_saved_calculations( + qapp, + tmp_path, +): + fake_debyer = _write_fake_debyer(tmp_path / "debyer") + frames_dir = _build_frames_dir( + tmp_path, + xyz=("2\n" "frame\n" "Na 0.0 0.0 0.0\n" "Cl 2.0 0.0 0.0\n"), + ) + project_dir = tmp_path / "project" + result = DebyerPDFWorkflow( + DebyerPDFSettings( + project_dir=project_dir, + frames_dir=frames_dir, + filename_prefix="existing_average", + mode="PDF", + from_value=0.5, + to_value=5.0, + step_value=0.01, + box_dimensions=(10.0, 10.0, 10.0), + atom_count=2, + solute_elements=(), + ), + debyer_executable=fake_debyer, + ).run() + _r_values, before_values = parse_debyer_output_file( + result.averaged_output_file + ) + assert "solute-solute" not in before_values + worker = DebyerPDFExistingPartialsWorker( + [ + ( + "project", + DebyerPDFExistingPartialsJob( + project_dir=project_dir, + solute_elements=("Pb",), + ), + ) + ] + ) + updated_results: list[DebyerPDFCalculation] = [] + progress_messages: list[str] = [] + worker.item_progress.connect( + lambda _item_id, _processed, _total, message: progress_messages.append( + message + ) + ) + worker.finished.connect(lambda results: updated_results.extend(results)) + + worker.run() + + assert len(updated_results) == 1 + assert progress_messages + _r_values, after_values = parse_debyer_output_file( + result.averaged_output_file + ) + assert "solute-solute" in after_values + assert "solute-solvent" in after_values + assert "solvent-solvent" in after_values + loaded = 
load_debyer_calculation(result.calculation_dir) + assert loaded.solute_elements == ("Pb",) + assert sorted(loaded.partial_values) == ["I-I", "Pb-I", "Pb-Pb"] + + +def test_debyer_workflow_parallel_jobs_match_serial_average(tmp_path): + fake_debyer = _write_fake_debyer(tmp_path / "debyer") + frames_dir = _build_frames_dir(tmp_path, frame_count=8) + + serial = DebyerPDFWorkflow( + DebyerPDFSettings( + project_dir=tmp_path / "serial_project", + frames_dir=frames_dir, + filename_prefix="serial_pdf", + mode="PDF", + from_value=0.5, + to_value=15.0, + step_value=0.01, + box_dimensions=(10.0, 10.0, 10.0), + atom_count=3, + store_frame_outputs=False, + solute_elements=("Pb",), + max_parallel_jobs=1, + ), + debyer_executable=fake_debyer, + ).run() + parallel = DebyerPDFWorkflow( + DebyerPDFSettings( + project_dir=tmp_path / "parallel_project", + frames_dir=frames_dir, + filename_prefix="parallel_pdf", + mode="PDF", + from_value=0.5, + to_value=15.0, + step_value=0.01, + box_dimensions=(10.0, 10.0, 10.0), + atom_count=3, + store_frame_outputs=False, + solute_elements=("Pb",), + max_parallel_jobs=4, + ), + debyer_executable=fake_debyer, + ).run() + + assert parallel.parallel_jobs == 4 + assert parallel.processed_frame_count == serial.processed_frame_count == 8 + assert np.allclose(parallel.total_values, serial.total_values) + assert parallel.partial_values.keys() == serial.partial_values.keys() + for key in parallel.partial_values: + assert np.allclose( + parallel.partial_values[key], + serial.partial_values[key], + ) + loaded = load_debyer_calculation(parallel.calculation_dir) + assert loaded.parallel_jobs == 4 + averaged_text = parallel.averaged_output_file.read_text(encoding="utf-8") + assert "# parallel_jobs: 4" in averaged_text + + +def test_debyer_workflow_defaults_solute_elements_when_omitted(tmp_path): + fake_debyer = _write_fake_debyer(tmp_path / "debyer") + frames_dir = _build_frames_dir(tmp_path) + result = DebyerPDFWorkflow( + DebyerPDFSettings( + 
project_dir=tmp_path / "project", + frames_dir=frames_dir, + filename_prefix="default_solute_pdf", + mode="PDF", + from_value=0.5, + to_value=15.0, + step_value=0.01, + box_dimensions=(10.0, 10.0, 10.0), + atom_count=3, + store_frame_outputs=False, + solute_elements=(), + ), + debyer_executable=fake_debyer, + ).run() + + assert result.solute_elements == ("Pb", "I") + loaded = load_debyer_calculation(result.calculation_dir) + assert loaded.solute_elements == ("Pb", "I") + + +def test_debyer_workflow_checkpoints_sparse_running_averages(tmp_path): + fake_debyer = _write_fake_debyer(tmp_path / "debyer") + frames_dir = _build_frames_dir(tmp_path, frame_count=12) + project_dir = tmp_path / "project" + workflow = DebyerPDFWorkflow( + DebyerPDFSettings( + project_dir=project_dir, + frames_dir=frames_dir, + filename_prefix="sparse_preview_pdf", + mode="PDF", + from_value=0.5, + to_value=15.0, + step_value=0.01, + box_dimensions=(10.0, 10.0, 10.0), + atom_count=3, + store_frame_outputs=False, + solute_elements=("Pb",), + ), + debyer_executable=fake_debyer, + ) + preview_counts: list[int | None] = [] + + result = workflow.run( + preview_callback=lambda calculation: preview_counts.append( + calculation.processed_frame_count + ) + ) + + assert preview_counts == [12] + assert result.processed_frame_count == 12 + assert result.is_partial_average is False + + +def test_debyer_workflow_respects_live_preview_decision(tmp_path): + fake_debyer = _write_fake_debyer(tmp_path / "debyer") + frames_dir = _build_frames_dir(tmp_path, frame_count=12) + project_dir = tmp_path / "project" + workflow = DebyerPDFWorkflow( + DebyerPDFSettings( + project_dir=project_dir, + frames_dir=frames_dir, + filename_prefix="live_preview_pdf", + mode="PDF", + from_value=0.5, + to_value=15.0, + step_value=0.01, + box_dimensions=(10.0, 10.0, 10.0), + atom_count=3, + store_frame_outputs=False, + solute_elements=("Pb",), + ), + debyer_executable=fake_debyer, + ) + previews = [] + preview_enabled = {"value": 
False} + + def update_preview_toggle(processed, _total, _message): + if processed == 6: + preview_enabled["value"] = True + elif processed == 7: + preview_enabled["value"] = False + + def should_preview(processed, _total, checkpoint_due): + if processed == 6: + assert checkpoint_due is False + return preview_enabled["value"] + + result = workflow.run( + progress_callback=update_preview_toggle, + preview_callback=previews.append, + preview_decision_callback=should_preview, + ) + + assert [preview.processed_frame_count for preview in previews] == [6] + assert result.processed_frame_count == 12 + assert result.is_partial_average is False + loaded = load_debyer_calculation(result.calculation_dir) + assert loaded.processed_frame_count == 12 + + +def test_debyer_worker_preview_toggle_requests_next_average(tmp_path): + settings = DebyerPDFSettings( + project_dir=tmp_path / "project", + frames_dir=tmp_path / "frames", + filename_prefix="worker_preview", + ) + worker = DebyerPDFWorker(settings, preview_enabled=False) + + assert worker._should_emit_preview(1, 10, True) is False + + worker.set_preview_enabled(True) + assert worker._should_emit_preview(2, 10, False) is True + worker._emit_preview(object()) + assert worker._should_emit_preview(3, 10, False) is False + assert worker._should_emit_preview(10, 10, True) is True + + worker.set_preview_enabled(False) + assert worker._should_emit_preview(10, 10, True) is False + + +def test_debyer_window_clamps_rejected_r_range_maximum_to_half_min_box( + qapp, + tmp_path, +): + del qapp + frames_dir = _build_frames_dir(tmp_path) + project_dir = tmp_path / "project" + window = DebyerPDFMainWindow() + window.project_dir_edit.setText(str(project_dir)) + window.frames_dir_edit.setText(str(frames_dir)) + window.filename_prefix_edit.setText("box_limited_pdf") + window.from_edit.setText("0.5") + window.to_edit.setText("15.0") + window.step_edit.setText("0.01") + window.box_a_edit.setText("20.0") + window.box_b_edit.setText("8.0") + 
window.box_c_edit.setText("12.0") + window.atom_count_edit.setText("3") + + settings = window._build_settings() + + assert settings.to_value == pytest.approx(4.0) + assert window.to_edit.text() == "4" + assert ( + "half of the minimum box dimension" + in window.statusBar().currentMessage() + ) + window.close() + + +def test_debyer_window_defaults_solutes_from_inspected_frames(qapp, tmp_path): + del qapp + pb_i_frames = _build_frames_dir(tmp_path / "pb_i") + cs_pb_i_frames = _build_frames_dir( + tmp_path / "cs_pb_i", + xyz=( + "4\n" + "frame\n" + "Cs 0.0 0.0 0.0\n" + "Pb 2.0 0.0 0.0\n" + "I 0.0 2.0 0.0\n" + "I 0.0 0.0 2.0\n" + ), + ) + window = DebyerPDFMainWindow() + + window.frames_dir_edit.setText(str(pb_i_frames)) + window._inspect_frames_dir() + assert window.solute_elements_edit.text() == "Pb, I" + assert "Default solutes: Pb, I" in window.frames_summary_label.text() + + window.solute_elements_edit.clear() + window.frames_dir_edit.setText(str(cs_pb_i_frames)) + window._inspect_frames_dir() + assert window.solute_elements_edit.text() == "Cs, Pb, I" + assert "Default solutes: Cs, Pb, I" in window.frames_summary_label.text() + window.close() + + +def test_debyer_window_splitter_handle_is_grabbable(qapp): + del qapp + window = DebyerPDFMainWindow() + + tab_names = [ + window.result_tabs.tabText(index) + for index in range(window.result_tabs.count()) + ] + assert "Shape Function Analysis" in tab_names + assert "Fit" in tab_names + assert tab_names.index("Shape Function Analysis") < tab_names.index("Fit") + assert window._main_splitter.handleWidth() >= 12 + assert window._main_splitter.handle(1).toolTip() + assert window._main_splitter.widget(1).minimumSizeHint().width() < 800 + assert ( + window.findChild(QWidget, "pdfPlotControls").minimumSizeHint().width() + < 800 + ) + window.close() + + +def test_debyer_window_fits_coordination_number_from_r_trace(qapp, tmp_path): + r_values = np.linspace(1.5, 3.5, 121) + expected_cn = 3.75 + center = 2.55 + sigma = 0.13 + 
r_distribution = 0.2 + ( + expected_cn + / (sigma * np.sqrt(2.0 * np.pi)) + * np.exp(-0.5 * ((r_values - center) / sigma) ** 2) + ) + calculation = DebyerPDFCalculation( + calculation_id="fit_demo", + calculation_dir=tmp_path, + created_at="2026-05-13T00:00:00", + project_dir=tmp_path, + frames_dir=tmp_path, + frame_format="xyz", + frame_count=1, + filename_prefix="fit_demo", + mode="RDF", + from_value=1.5, + to_value=3.5, + step_value=float(r_values[1] - r_values[0]), + box_dimensions=(20.0, 20.0, 20.0), + box_source=None, + box_source_kind=None, + atom_count=2, + rho0=1.0, + store_frame_outputs=False, + frame_output_dir=None, + averaged_output_file=tmp_path / "averaged_raw.txt", + solute_elements=("Pb",), + parallel_jobs=1, + r_values=r_values, + total_values=r_distribution, + partial_values={"Pb-I": r_distribution}, + ) + window = DebyerPDFMainWindow() + window._apply_loaded_calculation(calculation) + qapp.processEvents() + + trace_index = window.coordination_fit_trace_combo.findData("partial:Pb-I") + assert trace_index >= 0 + window.coordination_fit_trace_combo.setCurrentIndex(trace_index) + window.coordination_fit_r_min_spin.setValue(2.1) + window.coordination_fit_r_max_spin.setValue(3.0) + window.coordination_fit_center_spin.setValue(2.5) + window.coordination_fit_sigma_spin.setValue(0.12) + window._fit_coordination_number() + qapp.processEvents() + + assert window.representation_combo.currentText() == "R(r)" + assert window.coordination_fit_results_table.rowCount() == 1 + fitted_cn = float(window.coordination_fit_results_table.item(0, 5).text()) + assert fitted_cn == pytest.approx(expected_cn, rel=0.05) + assert "CN =" in window.coordination_fit_status_label.text() + window.close() + + +def test_debyer_window_applies_solute_groups_after_calculation( + qapp, + tmp_path, +): + fake_debyer = _write_fake_debyer(tmp_path / "debyer") + frames_dir = _build_frames_dir( + tmp_path, + xyz=("2\n" "frame\n" "Na 0.0 0.0 0.0\n" "Cl 2.0 0.0 0.0\n"), + ) + project_dir = 
tmp_path / "project" + result = DebyerPDFWorkflow( + DebyerPDFSettings( + project_dir=project_dir, + frames_dir=frames_dir, + filename_prefix="forgot_solutes", + mode="PDF", + from_value=0.5, + to_value=15.0, + step_value=0.01, + box_dimensions=(10.0, 10.0, 10.0), + atom_count=2, + store_frame_outputs=False, + solute_elements=(), + ), + debyer_executable=fake_debyer, + ).run() + assert result.solute_elements == () + + window = DebyerPDFMainWindow(initial_project_dir=project_dir) + qapp.processEvents() + assert not any( + window.trace_table.item(row, 3) is not None + and window.trace_table.item(row, 3).text() == "Group" + for row in range(window.trace_table.rowCount()) + ) + + window.solute_elements_edit.setText("Pb") + window._apply_solute_groups_from_ui() + qapp.processEvents() + + assert window._current_calculation is not None + assert window._current_calculation.solute_elements == ("Pb",) + assert any( + window.trace_table.item(row, 3) is not None + and window.trace_table.item(row, 3).text() == "Group" + and window.trace_table.item(row, 2) is not None + and window.trace_table.item(row, 2).text() == "solute-solvent" + for row in range(window.trace_table.rowCount()) + ) + loaded = load_debyer_calculation(result.calculation_dir) + assert loaded.solute_elements == ("Pb",) + _output_r, output_values = parse_debyer_output_file( + loaded.averaged_output_file + ) + assert "solute-solute" in output_values + assert "solute-solvent" in output_values + assert "solvent-solvent" in output_values + assert sorted(loaded.partial_values) == ["I-I", "Pb-I", "Pb-Pb"] + assert "Solute elements: Pb" in window.calculation_info_label.text() + window.close() + + +def test_debyer_workflow_cancellation_saves_partial_average(tmp_path): + fake_debyer = _write_fake_debyer(tmp_path / "debyer") + frames_dir = _build_frames_dir(tmp_path) + project_dir = tmp_path / "project" + stop_requested = {"value": False} + workflow = DebyerPDFWorkflow( + DebyerPDFSettings( + project_dir=project_dir, + 
frames_dir=frames_dir, + filename_prefix="cancelled_pdf", + mode="PDF", + from_value=0.5, + to_value=15.0, + step_value=0.01, + box_dimensions=(10.0, 10.0, 10.0), + atom_count=3, + store_frame_outputs=False, + solute_elements=("Pb",), + ), + debyer_executable=fake_debyer, + ) + + def request_stop_after_first_frame(processed, _total, _message): + if processed >= 1: + stop_requested["value"] = True + + result = workflow.run( + progress_callback=request_stop_after_first_frame, + cancel_callback=lambda: stop_requested["value"], + ) + + assert result.processed_frame_count == 1 + assert result.frame_count == 2 + assert result.is_partial_average is True + assert np.allclose(result.total_values, np.array([1.0, 1.1, 1.2])) + + loaded = load_debyer_calculation(result.calculation_dir) + assert loaded.processed_frame_count == 1 + assert loaded.frame_count == 2 + assert loaded.is_partial_average is True + assert np.allclose(loaded.total_values, result.total_values) + + def test_debyer_load_backfills_missing_peak_metadata(tmp_path): fake_debyer = _write_fake_debyer(tmp_path / "debyer") frames_dir = _build_frames_dir(tmp_path) @@ -288,6 +1263,20 @@ def test_debyer_window_loads_saved_calculation(qapp, tmp_path, monkeypatch): assert average_row is not None assert grouped_row is not None assert partial_row is not None + group_colors = { + key: window._trace_colors[key] + for key in ( + "group:solute-solute", + "group:solute-solvent", + "group:solvent-solvent", + ) + } + assert group_colors == { + "group:solute-solute": "#cc79a7", + "group:solute-solvent": "#e69f00", + "group:solvent-solvent": "#009e73", + } + assert len(set(group_colors.values())) == 3 average_tag_box = window.trace_table.cellWidget(average_row, 1) grouped_tag_box = window.trace_table.cellWidget(grouped_row, 1) assert average_tag_box is not None @@ -434,6 +1423,74 @@ def __init__(self, xdata, ydata, inaxes, button=1): window.close() +def test_debyer_window_loads_experimental_gr_trace(qapp, tmp_path): + fake_debyer = 
_write_fake_debyer(tmp_path / "debyer") + frames_dir = _build_frames_dir(tmp_path) + project_dir = tmp_path / "project" + workflow = DebyerPDFWorkflow( + DebyerPDFSettings( + project_dir=project_dir, + frames_dir=frames_dir, + filename_prefix="experimental_demo", + mode="PDF", + from_value=0.5, + to_value=15.0, + step_value=0.01, + box_dimensions=(12.0, 12.0, 12.0), + atom_count=3, + store_frame_outputs=False, + solute_elements=("Pb",), + ), + debyer_executable=fake_debyer, + ) + workflow.run() + experimental_path = tmp_path / "experimental_gr.txt" + experimental_path.write_text( + "r(A) g(r)\n" "0.5 1.5\n" "1.0 1.65\n" "1.5 1.8\n", + encoding="utf-8", + ) + + window = DebyerPDFMainWindow(initial_project_dir=project_dir) + qapp.processEvents() + window._load_experimental_file(experimental_path) + qapp.processEvents() + + experimental_row = None + for row in range(window.trace_table.rowCount()): + kind_item = window.trace_table.item(row, 3) + if kind_item is not None and kind_item.text() == "Experimental": + experimental_row = row + break + assert experimental_row is not None + visible_box = window.trace_table.cellWidget(experimental_row, 0) + assert visible_box is not None + assert visible_box.isChecked() is True + experimental_line = next( + line + for line in window.figure.axes[0].get_lines() + if line.get_label().startswith("Experimental g(r)") + ) + assert experimental_line.get_linestyle() == "--" + assert "R^2 = 1.0000" in window._experimental_fit_metrics_text() + assert any( + "R^2 = 1.0000" in text_artist.get_text() + for text_artist in window.figure.axes[0].texts + ) + + window._toggle_experimental_trace() + qapp.processEvents() + assert window._trace_visibility["experimental"] is False + assert window.experimental_toggle_button.text() == "Show Experimental" + + window.fit_box_checkbox.setChecked(False) + qapp.processEvents() + assert not any( + "R^2 =" in text_artist.get_text() + for text_artist in window.figure.axes[0].texts + ) + window.close() + + def 
test_debyer_window_can_export_active_traces(qapp, tmp_path, monkeypatch): fake_debyer = _write_fake_debyer(tmp_path / "debyer") frames_dir = _build_frames_dir(tmp_path) diff --git a/tests/test_periodic_table_ui.py b/tests/test_periodic_table_ui.py new file mode 100644 index 0000000..39c1951 --- /dev/null +++ b/tests/test_periodic_table_ui.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import os + +from PySide6.QtWidgets import QApplication + +from saxshell.ui.periodic_table import ( + PERIODIC_TABLE_ELEMENTS, + PeriodicTableWidget, + element_by_symbol, +) + + +def qapp(): + os.environ.setdefault("QT_QPA_PLATFORM", "offscreen") + app = QApplication.instance() + if app is None: + app = QApplication([]) + return app + + +def test_periodic_table_widget_selects_element_symbol(): + qapp() + widget = PeriodicTableWidget() + selected: list[str] = [] + widget.element_selected.connect(selected.append) + + widget.select_element("cs") + + assert widget.selected_symbol() == "Cs" + assert selected == ["Cs"] + assert element_by_symbol("pb").name == "Lead" + assert len(PERIODIC_TABLE_ELEMENTS) == 118 diff --git a/tests/test_representativefinder.py b/tests/test_representativefinder.py index 8890d46..8ec977a 100644 --- a/tests/test_representativefinder.py +++ b/tests/test_representativefinder.py @@ -46,6 +46,11 @@ from saxshell.representativefinder.cli import ( main as representativefinder_cli_main, ) +from saxshell.representativefinder.ui.batch_queue_window import ( + RepresentativeFinderBatchJob, + RepresentativeFinderBatchQueueWindow, + RepresentativeFinderBatchWorker, +) from saxshell.representativefinder.ui.main_window import ( RepresentativeStructureFinderMainWindow, ) @@ -391,11 +396,26 @@ def test_representativefinder_result_json_preserves_analysis_details(tmp_path): assert loaded.candidates[0].score_total is not None -def test_representativefinder_single_atom_shortcuts_full_analysis(tmp_path): +def test_representativefinder_single_atom_shortcuts_full_analysis( + 
tmp_path, + monkeypatch, +): stoich_dir = _build_single_atom_test_folder(tmp_path) progress_events: list[tuple[int, int, str]] = [] log_messages: list[str] = [] + def fail_measurement(*_args, **_kwargs): + pytest.fail( + "Single-atom representative selection should not run full " + "bond/angle measurement." + ) + + monkeypatch.setattr( + "saxshell.representativefinder.workflow." + "BondAnalyzer.measure_structure_data", + fail_measurement, + ) + result = analyze_representative_structure_folder( stoich_dir, settings=RepresentativeFinderSettings(), @@ -424,6 +444,9 @@ def test_representativefinder_single_atom_shortcuts_full_analysis(tmp_path): "Aggregating bond and angle distributions" in message for _processed, _total, message in progress_events ) + assert not any( + "Measuring " in message for _p, _t, message in progress_events + ) assert not any( "Scoring " in message for _p, _t, message in progress_events ) @@ -1266,6 +1289,118 @@ def test_representativefinder_input_inspection_discovers_stoichiometry_subfolder ] +def test_representativefinder_batch_queue_prefills_project_clusters_and_all_mode( + qapp, + tmp_path, + monkeypatch, +): + del qapp + monkeypatch.setenv( + "SAXSHELL_BONDANALYSIS_PRESETS_PATH", + str(tmp_path / "bondanalysis_presets.json"), + ) + project_dir = tmp_path / "project" + root_dir, _pb_dir, _sn_dir = _build_multi_stoichiometry_root(tmp_path) + manager = SAXSProjectManager() + settings = manager.create_project(project_dir) + settings.clusters_dir = str(root_dir.resolve()) + manager.save_project(settings) + + window = RepresentativeFinderBatchQueueWindow( + initial_project_dir=project_dir, + ) + + assert window.queue_list.count() == 1 + widget = window.queue_list.itemWidget(window.queue_list.item(0)) + assert widget.project_dir_edit.text() == str(project_dir.resolve()) + assert widget.clusters_dir_edit.text() == str(root_dir.resolve()) + assert "representativefinder_batch_cluster_root" in ( + widget.output_dir_edit.text() + ) + assert ( + 
widget.analysis_mode_label.text() == "All Discovered Stoichiometries" + ) + assert window.preset_combo.count() >= 1 + + window.close() + + +def test_representativefinder_batch_worker_publishes_project_results_and_restores_ui( + qapp, + tmp_path, + monkeypatch, +): + del qapp + monkeypatch.setenv( + "SAXSHELL_BONDANALYSIS_PRESETS_PATH", + str(tmp_path / "bondanalysis_presets.json"), + ) + project_dir = tmp_path / "project" + root_dir, _pb_dir, _sn_dir = _build_multi_stoichiometry_root(tmp_path) + manager = SAXSProjectManager() + project_settings = manager.create_project(project_dir) + project_settings.clusters_dir = str(root_dir.resolve()) + manager.save_project(project_settings) + settings = RepresentativeFinderSettings( + bond_pairs=( + BondPairDefinition("Pb", "I", 3.2), + BondPairDefinition("Sn", "Br", 3.2), + ), + angle_triplets=( + AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2), + AngleTripletDefinition("Sn", "Br", "Br", 3.2, 3.2), + ), + solvent_weight=0.0, + parallel_workers=1, + ) + output_dir = project_dir / "representative_finder" / "batch_run" + config = build_representativefinder_run_config( + project_dir=project_dir, + input_dir=root_dir, + output_dir=output_dir, + analysis_mode="all", + settings=settings, + overwrite_existing=False, + ) + job = RepresentativeFinderBatchJob( + project_dir=project_dir.resolve(), + clusters_dir=root_dir.resolve(), + output_dir=output_dir.resolve(), + config=config, + ) + worker = RepresentativeFinderBatchWorker([("job-1", job)]) + finished_results: list[object] = [] + failed_items: list[tuple[str, str]] = [] + changed_projects: list[str] = [] + worker.finished.connect(finished_results.append) + worker.failed.connect( + lambda item_id, message: failed_items.append((item_id, message)) + ) + worker.project_results_changed.connect(changed_projects.append) + + worker.run() + + assert failed_items == [] + assert changed_projects == [str(project_dir.resolve())] + assert len(finished_results) == 1 + assert 
finished_results[0][0].completed_count == 2 + state = manager.inspect_representative_structures(project_dir) + assert state.representative_count == 2 + + window = RepresentativeStructureFinderMainWindow( + initial_project_dir=project_dir, + initial_input_path=root_dir, + ) + assert window.stoichiometry_table.rowCount() == 2 + assert window.stoichiometry_table.item(0, 3).text() == "Complete" + assert window.stoichiometry_table.item(1, 3).text() == "Complete" + assert ( + window.run_status_label.text() + == "Representative selection: restored from saved project analysis" + ) + window.close() + + def test_representativefinder_window_builds_split_scrollable_layout( qapp, tmp_path, diff --git a/tests/test_saxs_prefit.py b/tests/test_saxs_prefit.py index 25dbac5..4fe72e4 100644 --- a/tests/test_saxs_prefit.py +++ b/tests/test_saxs_prefit.py @@ -59,6 +59,9 @@ SCALED_SOLVENT_MONOSQ_TEMPLATE = ( "template_pydream_monosq_normalized_scaled_solvent" ) +MODEL_SCALED_SOLVENT_MONOSQ_TEMPLATE = ( + "template_pydream_monosq_normalized_scaled_solvent_model_scale" +) def _write_component_file(path, q_values, intensities): @@ -1655,6 +1658,88 @@ def test_scaled_solvent_monosq_template_scales_solvent_with_global_scale(): ) +def test_model_scaled_solvent_monosq_template_transforms_model_curve(): + template_module = load_template_module( + MODEL_SCALED_SOLVENT_MONOSQ_TEMPLATE + ) + q_values = np.linspace(0.05, 0.3, 8) + component = np.linspace(10.0, 17.0, 8) + solvent = np.linspace(1.5, 2.2, 8) + + raw_model = template_module.raw_monosq_scaled_solvent_profile( + q_values, + solvent, + [component], + [0.6], + solv_w=0.5, + eff_r=9.0, + vol_frac=0.0, + ) + model = template_module.lmfit_model_profile( + q_values, + solvent, + [component], + w0=0.6, + solv_w=0.5, + offset=0.05, + eff_r=9.0, + vol_frac=0.0, + scale=2e-3, + ) + + assert np.allclose(raw_model, (0.6 * component) + (0.5 * solvent)) + assert np.allclose(model, (2e-3 * raw_model) + 0.05) + + +def 
test_model_scaled_solvent_monosq_likelihood_uses_unmodified_experiment(): + template_module = load_template_module( + MODEL_SCALED_SOLVENT_MONOSQ_TEMPLATE + ) + q_values = np.linspace(0.05, 0.3, 8) + component = np.linspace(10.0, 17.0, 8) + solvent = np.linspace(1.5, 2.2, 8) + params = np.asarray([0.6, 0.5, 0.05, 9.0, 0.0, 2e-3], dtype=float) + expected_model = template_module.lmfit_model_profile( + q_values, + solvent, + [component], + w0=0.6, + solv_w=0.5, + offset=0.05, + eff_r=9.0, + vol_frac=0.0, + scale=2e-3, + ) + raw_model = template_module.raw_monosq_scaled_solvent_profile( + q_values, + solvent, + [component], + [0.6], + solv_w=0.5, + eff_r=9.0, + vol_frac=0.0, + ) + + template_module.q_values = q_values + template_module.theoretical_intensities = [component] + template_module.solvent_intensities = solvent + template_module.experimental_intensities = expected_model + matching_log_likelihood = ( + template_module.log_likelihood_monosq_scaled_solvent_model_scale( + params + ) + ) + + template_module.experimental_intensities = raw_model + raw_model_log_likelihood = ( + template_module.log_likelihood_monosq_scaled_solvent_model_scale( + params + ) + ) + + assert matching_log_likelihood > raw_model_log_likelihood + + def test_scaled_solvent_monosq_prefit_evaluates_scaled_solvent_contribution( tmp_path, ): @@ -1940,6 +2025,55 @@ def test_scaled_solvent_monosq_prefit_exposes_physical_vol_frac_target( assert workflow.solvent_contribution_is_scaled_by_global_scale() +def test_model_scaled_solvent_monosq_prefit_metadata_targets_model_scale( + tmp_path, +): + spec = load_template_spec(MODEL_SCALED_SOLVENT_MONOSQ_TEMPLATE) + + assert spec.display_name == ( + "pyDREAM MonoSQ Normalized " + "(Scaled Solvent Weight, Model Scale/Offset)" + ) + assert ( + spec.solution_scattering_support.volume_fraction_parameter + == "vol_frac" + ) + assert spec.solution_scattering_support.volume_fraction_kind == "solute" + assert ( + 
spec.solution_scattering_support.volume_fraction_source == "physical" + ) + assert ( + spec.solution_scattering_support.solvent_contribution_scale_mode + == "global_scale" + ) + assert spec.prefit_support.auto_apply_autoscale_on_load + assert spec.prefit_support.autoscale_bounds_mode == "adaptive" + scale_entry = next( + parameter for parameter in spec.parameters if parameter.name == "scale" + ) + assert scale_entry.vary is True + + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.selected_model_template = MODEL_SCALED_SOLVENT_MONOSQ_TEMPLATE + manager.save_project(settings) + workflow = SAXSPrefitWorkflow(project_dir) + + assert workflow.supports_volume_fraction_estimator() + assert workflow.volume_fraction_estimator_target() == ( + "vol_frac", + "solute", + ) + assert workflow.solution_scattering_volume_fraction_target() == ( + "vol_frac", + "solute", + "physical", + ) + assert workflow.solvent_weight_estimator_target() == "solv_w" + assert workflow.solvent_contribution_is_scaled_by_global_scale() + + def test_poly_lma_prefit_workflow_exposes_solvent_weight_target(tmp_path): project_dir, _paths, _radius = _build_poly_lma_geometry_project(tmp_path) workflow = SAXSPrefitWorkflow(project_dir) diff --git a/tests/test_saxs_template_installation.py b/tests/test_saxs_template_installation.py index 8edfb23..68c0b6e 100644 --- a/tests/test_saxs_template_installation.py +++ b/tests/test_saxs_template_installation.py @@ -3,10 +3,12 @@ import textwrap from pathlib import Path +import numpy as np import pytest from saxshell.saxs._model_templates import ( list_template_specs, + load_template_module, load_template_spec, ) from saxshell.saxs.template_installation import ( @@ -15,6 +17,9 @@ ) TEMPLATE_CANDIDATE_DIR = Path("tests/template_candidates") +CHARGED_MONOSQ_TEMPLATE = ( + "template_pydream_charged_monosq_normalized_scaled_solvent" +) def _write_template(path: Path, body: 
str) -> Path: @@ -177,12 +182,121 @@ def test_template_listing_hides_deprecated_by_default(): spec.name for spec in list_template_specs(include_deprecated=True) } + assert CHARGED_MONOSQ_TEMPLATE in visible_names assert "template_pd_likelihood_monosq_decoupled" not in visible_names assert "template_pd_likelihood_monosq_decoupled" in all_names assert "template_pydream_poly_lma_hs_legacy" not in visible_names assert "template_pydream_poly_lma_hs_legacy" in all_names +def test_charged_monosq_template_spec_exposes_scaled_solvent_metadata(): + spec = load_template_spec(CHARGED_MONOSQ_TEMPLATE) + + assert spec.display_name == ( + "pyDREAM Charged MonoSQ Normalized (Scaled Solvent Weight)" + ) + assert [parameter.name for parameter in spec.parameters] == [ + "solv_w", + "offset", + "eff_r", + "vol_frac", + "charge", + "temperature", + "concentration_salt", + "dielectconst", + "scale", + ] + assert ( + spec.solution_scattering_support.volume_fraction_parameter + == "vol_frac" + ) + assert spec.solution_scattering_support.volume_fraction_kind == "solute" + assert ( + spec.solution_scattering_support.volume_fraction_source == "physical" + ) + assert ( + spec.solution_scattering_support.solvent_contribution_scale_mode + == "global_scale" + ) + assert spec.prefit_support.auto_apply_autoscale_on_load + assert spec.prefit_support.autoscale_bounds_mode == "adaptive" + + charge_entry = next( + parameter + for parameter in spec.parameters + if parameter.name == "charge" + ) + assert charge_entry.minimum == pytest.approx(1e-6) + assert charge_entry.maximum == pytest.approx(200.0) + + +def test_charged_monosq_hayter_msa_matches_sasview_reference_values(): + module = load_template_module(CHARGED_MONOSQ_TEMPLATE) + + q_values = np.asarray([0.00001, 0.0010, 0.01, 0.075]) + expected = np.asarray([0.0711646, 0.0712928, 0.0847006, 1.07150]) + + actual = module.calc_hayter_msa_sq( + 20.75, + 0.0192, + 19.0, + 298.0, + 0.0, + 78.0, + q_values, + ) + + np.testing.assert_allclose(actual, 
expected, rtol=2e-5, atol=1e-7) + + +def test_charged_monosq_lmfit_model_uses_scaled_solvent_convention(): + module = load_template_module(CHARGED_MONOSQ_TEMPLATE) + q_values = np.linspace(0.01, 0.08, 5) + solvent = np.linspace(1.0, 2.0, 5) + components = [ + np.linspace(10.0, 14.0, 5), + np.linspace(4.0, 8.0, 5), + ] + params = { + "w0": 0.25, + "w1": 0.75, + "solv_w": 0.5, + "offset": 0.2, + "eff_r": 20.75, + "vol_frac": 0.0192, + "charge": 19.0, + "temperature": 298.0, + "concentration_salt": 0.0, + "dielectconst": 78.0, + "scale": 2.0e-4, + } + + structure_factor = module.calc_hayter_msa_sq( + params["eff_r"], + params["vol_frac"], + params["charge"], + params["temperature"], + params["concentration_salt"], + params["dielectconst"], + q_values, + ) + mixture = params["w0"] * components[0] + params["w1"] * components[1] + expected = ( + params["scale"] + * (mixture * structure_factor + params["solv_w"] * solvent) + + params["offset"] + ) + + actual = module.lmfit_model_profile( + q_values, + solvent, + components, + **params, + ) + + np.testing.assert_allclose(actual, expected) + + def test_validate_template_candidate_passes_for_sphere_only_geometry_constraints( tmp_path, ): @@ -420,7 +534,8 @@ def log_likelihood_candidate(params): def test_install_template_candidate_copies_files_and_loads_spec(tmp_path): source_template = Path( - "src/saxshell/saxs/_model_templates/template_pd_likelihood_monosq.py" + "src/saxshell/saxs/_model_templates/_deprecated/" + "template_pd_likelihood_monosq.py" ) source_metadata = source_template.with_suffix(".json") candidate_template = tmp_path / source_template.name diff --git a/tests/test_saxs_ui.py b/tests/test_saxs_ui.py index c19c535..f6da8b6 100644 --- a/tests/test_saxs_ui.py +++ b/tests/test_saxs_ui.py @@ -25,6 +25,7 @@ QAbstractScrollArea, QApplication, QCheckBox, + QComboBox, QDialog, QFileDialog, QFormLayout, @@ -141,6 +142,9 @@ from saxshell.saxs.ui.experimental_data_loader import ( ExperimentalDataHeaderDialog, ) +from 
saxshell.saxs.ui.experimental_overlay_window import ( + ExperimentalDataOverlayWindow, +) from saxshell.saxs.ui.main_window import ( AUTO_SNAP_PANES_KEY, PACKMOL_DOCKER_PRESETS_KEY, @@ -2878,6 +2882,7 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): assert window.tools_menu.title() == "Tools" assert window.md_extraction_menu.title() == "MD Extraction" assert window.structure_analysis_menu.title() == "Structure Analysis" + assert window.batch_processing_menu.title() == "Batch Processing" assert window.visualization_menu.title() == "Visualization" assert window.cli_setup_menu.title() == "CLI Setup" assert window.beta_menu.title() == "(beta)" @@ -2898,6 +2903,7 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): "Structure Analysis", "Cluster Dynamics", "PDF", + "Batch Processing", "Visualization", "SAXS Calculation Preview", "X-ray Toolkit", @@ -2917,12 +2923,21 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): "Open Bond Analysis", "Open Representative Structures", ] + visualization_actions = [ + action.text() for action in window.visualization_menu.actions() + ] + assert visualization_actions == [ + "Structure Viewer", + "Experimental Data Overlay", + "Open Blender XYZ Renderer", + ] window._build_menu_bar() assert [action.text() for action in window.tools_menu.actions()] == [ "MD Extraction", "Structure Analysis", "Cluster Dynamics", "PDF", + "Batch Processing", "Visualization", "SAXS Calculation Preview", "X-ray Toolkit", @@ -2935,6 +2950,14 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): "Open Bond Analysis", "Open Representative Structures", ] + visualization_actions = [ + action.text() for action in window.visualization_menu.actions() + ] + assert visualization_actions == [ + "Structure Viewer", + "Experimental Data Overlay", + "Open Blender XYZ Renderer", + ] assert ( window.clusterdynamics_action.text() == "Open Cluster Dynamics (only)" ) @@ -2944,8 
+2967,40 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): assert [ action.text() for action in window.cluster_dynamics_menu.actions() ] == ["Open Cluster Dynamics (ML)"] + assert [action.text() for action in window.pdf_menu.actions()] == [ + "Open PDF Calculation", + "Open RMC Setup (fullrmc)", + ] + assert ( + window.mdtrajectory_batch_queue_action.text() + == "Open MD Trajectory Batch Queue" + ) + assert ( + window.xyz2pdb_batch_queue_action.text() + == "Open XYZ -> PDB Batch Queue" + ) + assert ( + window.cluster_batch_queue_action.text() + == "Open Cluster Extraction Batch Queue" + ) + assert ( + window.representativefinder_batch_queue_action.text() + == "Open Representative Structures Batch Queue" + ) + assert [ + action.text() for action in window.batch_processing_menu.actions() + ] == [ + "Open MD Trajectory Batch Queue", + "Open XYZ -> PDB Batch Queue", + "Open Cluster Extraction Batch Queue", + "Open Representative Structures Batch Queue", + "Open PDF Batch Queue", + ] assert window.fullrmc_action.text() == "Open RMC Setup (fullrmc)" assert window.structure_viewer_action.text() == "Structure Viewer" + assert window.experimental_overlay_action.text() == ( + "Experimental Data Overlay" + ) assert window.blenderxyz_action.text() == "Open Blender XYZ Renderer" assert ( window.representative_finder_action.text() @@ -2977,11 +3032,31 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): window.solvent_shell_builder_action.text() == "Open Solvent Shell Builder (Beta)" ) + assert ( + window.xyz2pdb_cli_setup_action.text() + == "Open XYZ -> PDB CLI Setup (Beta)" + ) + assert ( + window.cluster_cli_setup_action.text() + == "Open Cluster Extraction CLI Setup (Beta)" + ) + assert ( + window.clusterdynamics_cli_setup_action.text() + == "Open Cluster Dynamics CLI Setup (Beta)" + ) + assert ( + window.clusterdynamicsml_cli_setup_action.text() + == "Open Cluster Dynamics ML CLI Setup (Beta)" + ) assert ( 
window.representative_cli_setup_action.text() == "Open Representative CLI Setup (Beta)" ) assert [action.text() for action in window.cli_setup_menu.actions()] == [ + "Open XYZ -> PDB CLI Setup (Beta)", + "Open Cluster Extraction CLI Setup (Beta)", + "Open Cluster Dynamics CLI Setup (Beta)", + "Open Cluster Dynamics ML CLI Setup (Beta)", "Open Representative CLI Setup (Beta)", ] assert [action.text() for action in window.beta_menu.actions()] == [ @@ -3482,6 +3557,123 @@ def fake_launch_mdtrajectory_app(**kwargs): window.close() +def test_mdtrajectory_batch_queue_uses_active_project_references( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + trajectory_file = tmp_path / "traj.xyz" + topology_file = tmp_path / "topology.pdb" + energy_file = tmp_path / "traj.ener" + trajectory_file.write_text("1\nframe\nPb 0.0 0.0 0.0\n", encoding="utf-8") + topology_file.write_text("MODEL 1\nENDMDL\n", encoding="utf-8") + energy_file.write_text( + "# step time kinetic temperature potential\n" + "1 0.0 1.0 300.0 -10.0\n", + encoding="utf-8", + ) + window.current_settings.trajectory_file = str(trajectory_file) + window.current_settings.topology_file = str(topology_file) + window.current_settings.energy_file = str(energy_file) + window.project_manager.save_project(window.current_settings) + launched: dict[str, object] = {} + + class FakeMDTrajectoryBatchQueueWindow(QWidget): + def __init__(self, **kwargs): + super().__init__() + launched.update(kwargs) + launched["instance"] = self + + def show(self): + launched["shown"] = True + + def raise_(self): + launched["raised"] = True + + monkeypatch.setattr( + "saxshell.mdtrajectory.ui.batch_queue_window." 
+ "MDTrajectoryBatchQueueWindow", + FakeMDTrajectoryBatchQueueWindow, + ) + + window._open_mdtrajectory_batch_queue_tool() + + assert launched["initial_project_dir"] == Path(project_dir).resolve() + assert launched["initial_trajectory_file"] == trajectory_file.resolve() + assert launched["initial_topology_file"] == topology_file.resolve() + assert launched["initial_energy_file"] == energy_file.resolve() + assert launched["shown"] is True + assert launched["raised"] is True + assert launched["instance"] in window._child_tool_windows + window.close() + + +def test_mdtrajectory_batch_queue_updates_main_project_frames_dir_from_child( + qapp, + tmp_path, + monkeypatch, +): + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + exported_frames_dir = tmp_path / "splitxyz_f12_t1000fs" + exported_frames_dir.mkdir() + + saved_settings = window.project_manager.load_project(project_dir) + saved_settings.frames_dir = str(exported_frames_dir.resolve()) + window.project_manager.save_project(saved_settings) + + class FakeMDTrajectoryBatchQueueWindow(QWidget): + project_paths_registered = Signal(object) + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def show(self): + return None + + def raise_(self): + return None + + fake_window: FakeMDTrajectoryBatchQueueWindow | None = None + + def fake_batch_queue_window(**kwargs): + nonlocal fake_window + fake_window = FakeMDTrajectoryBatchQueueWindow(**kwargs) + return fake_window + + monkeypatch.setattr( + "saxshell.mdtrajectory.ui.batch_queue_window." 
+ "MDTrajectoryBatchQueueWindow", + fake_batch_queue_window, + ) + + window._open_mdtrajectory_batch_queue_tool() + assert fake_window is not None + fake_window.project_paths_registered.emit( + { + "project_dir": Path(project_dir).resolve(), + "frames_dir": exported_frames_dir.resolve(), + } + ) + qapp.processEvents() + + assert ( + window.project_setup_tab.frames_dir() == exported_frames_dir.resolve() + ) + assert window.current_settings is not None + assert window.current_settings.resolved_frames_dir == ( + exported_frames_dir.resolve() + ) + window.save_project_state() + saved_after_save = window.project_manager.load_project(project_dir) + assert saved_after_save.resolved_frames_dir == ( + exported_frames_dir.resolve() + ) + window.close() + + def test_debye_waller_tool_uses_active_project_clusters_dir( qapp, tmp_path, @@ -4055,6 +4247,103 @@ def raise_(self): window.close() +def test_cluster_batch_queue_uses_active_project_pdb_frames_dir_and_project_dir( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + frames_dir = tmp_path / "splitxyz_f0fs" + frames_dir.mkdir() + pdb_frames_dir = tmp_path / "xyz2pdb_splitxyz_f0fs" + pdb_frames_dir.mkdir() + window.current_settings.frames_dir = str(frames_dir) + window.current_settings.pdb_frames_dir = str(pdb_frames_dir) + window.project_manager.save_project(window.current_settings) + launched: dict[str, object] = {} + + class FakeClusterBatchWindow(QWidget): + project_paths_registered = Signal(object) + + def __init__(self, **kwargs): + super().__init__() + launched.update(kwargs) + launched["instance"] = self + + def show(self): + launched["shown"] = True + + def raise_(self): + launched["raised"] = True + + monkeypatch.setattr( + "saxshell.cluster.ui.batch_queue_window.ClusterBatchQueueWindow", + FakeClusterBatchWindow, + ) + + window._open_cluster_batch_queue_tool() + + assert 
launched["initial_frames_dir"] == pdb_frames_dir.resolve() + assert launched["initial_project_dir"] == Path(project_dir).resolve() + assert launched["shown"] is True + assert launched["raised"] is True + assert launched["instance"] in window._child_tool_windows + window.close() + + +def test_representativefinder_batch_queue_uses_active_project_clusters_dir( + qapp, + tmp_path, + monkeypatch, +): + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + clusters_dir = tmp_path / "clusters_xyz2pdb_splitxyz_f0fs" + clusters_dir.mkdir() + window.current_settings.clusters_dir = str(clusters_dir) + window.project_manager.save_project(window.current_settings) + launched: dict[str, object] = {} + refreshed_projects: list[str] = [] + window._handle_representative_structure_results_changed = ( + refreshed_projects.append + ) + + class FakeRepresentativeBatchWindow(QWidget): + project_results_changed = Signal(str) + + def __init__(self, **kwargs): + super().__init__() + launched.update(kwargs) + launched["instance"] = self + + def show(self): + launched["shown"] = True + + def raise_(self): + launched["raised"] = True + + monkeypatch.setattr( + "saxshell.representativefinder.ui.batch_queue_window." 
+ "RepresentativeFinderBatchQueueWindow", + FakeRepresentativeBatchWindow, + ) + + window._open_representative_batch_queue_tool() + project_key = str(Path(project_dir).resolve()) + launched["instance"].project_results_changed.emit(project_key) + qapp.processEvents() + + assert launched["initial_clusters_dir"] == clusters_dir.resolve() + assert launched["initial_project_dir"] == Path(project_dir).resolve() + assert launched["shown"] is True + assert launched["raised"] is True + assert launched["instance"] in window._child_tool_windows + assert refreshed_projects == [project_key] + window.close() + + def test_xyz2pdb_tool_uses_active_project_frames_dir_and_project_dir( qapp, tmp_path, monkeypatch ): @@ -4088,6 +4377,49 @@ def fake_launch_xyz2pdb_ui(**kwargs): window.close() +def test_xyz2pdb_batch_queue_uses_active_project_frames_dir_and_project_dir( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + frames_dir = tmp_path / "splitxyz_f0fs" + frames_dir.mkdir() + window.current_settings.frames_dir = str(frames_dir) + window.project_manager.save_project(window.current_settings) + launched: dict[str, object] = {} + + class FakeXYZToPDBBatchWindow(QWidget): + project_paths_registered = Signal(object) + + def __init__(self, **kwargs): + super().__init__() + launched.update(kwargs) + launched["instance"] = self + + def show(self): + launched["shown"] = True + + def raise_(self): + launched["raised"] = True + + monkeypatch.setattr( + "saxshell.xyz2pdb.ui.batch_queue_window.XYZToPDBBatchQueueWindow", + FakeXYZToPDBBatchWindow, + ) + + window._open_xyz2pdb_batch_queue_tool() + + assert launched["initial_input_path"] == frames_dir.resolve() + assert launched["initial_project_dir"] == Path(project_dir).resolve() + assert launched["shown"] is True + assert launched["raised"] is True + assert launched["instance"] in window._child_tool_windows + window.close() 
+ + def test_mdtrajectory_tool_updates_main_project_frames_dir_from_child( qapp, tmp_path, @@ -4182,6 +4514,100 @@ def fake_launch_xyz2pdb_ui(**kwargs): window.close() +def test_xyz2pdb_batch_queue_updates_main_project_pdb_folder_from_child( + qapp, + tmp_path, + monkeypatch, +): + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + pdb_frames_dir = tmp_path / "xyz2pdb_splitxyz_f5fs" + pdb_frames_dir.mkdir() + saved_settings = window.project_manager.load_project(project_dir) + saved_settings.pdb_frames_dir = str(pdb_frames_dir.resolve()) + window.project_manager.save_project(saved_settings) + + class FakeXYZToPDBBatchWindow(QWidget): + project_paths_registered = Signal(object) + + def show(self): + return + + def raise_(self): + return + + fake_window = FakeXYZToPDBBatchWindow() + + monkeypatch.setattr( + "saxshell.xyz2pdb.ui.batch_queue_window.XYZToPDBBatchQueueWindow", + lambda **kwargs: fake_window, + ) + + window._open_xyz2pdb_batch_queue_tool() + fake_window.project_paths_registered.emit( + { + "project_dir": Path(project_dir).resolve(), + "pdb_frames_dir": pdb_frames_dir.resolve(), + } + ) + qapp.processEvents() + + assert ( + window.project_setup_tab.pdb_frames_dir() == pdb_frames_dir.resolve() + ) + assert window.current_settings is not None + assert window.current_settings.resolved_pdb_frames_dir == ( + pdb_frames_dir.resolve() + ) + window.close() + + +def test_cluster_batch_queue_updates_main_project_clusters_folder_from_child( + qapp, + tmp_path, + monkeypatch, +): + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + clusters_dir = tmp_path / "clusters_xyz2pdb_splitxyz_f5fs" + clusters_dir.mkdir() + saved_settings = window.project_manager.load_project(project_dir) + saved_settings.clusters_dir = str(clusters_dir.resolve()) + window.project_manager.save_project(saved_settings) + + class 
FakeClusterBatchWindow(QWidget): + project_paths_registered = Signal(object) + + def show(self): + return + + def raise_(self): + return + + fake_window = FakeClusterBatchWindow() + + monkeypatch.setattr( + "saxshell.cluster.ui.batch_queue_window.ClusterBatchQueueWindow", + lambda **kwargs: fake_window, + ) + + window._open_cluster_batch_queue_tool() + fake_window.project_paths_registered.emit( + { + "project_dir": Path(project_dir).resolve(), + "clusters_dir": clusters_dir.resolve(), + } + ) + qapp.processEvents() + + assert window.project_setup_tab.clusters_dir() == clusters_dir.resolve() + assert window.current_settings is not None + assert window.current_settings.resolved_clusters_dir == ( + clusters_dir.resolve() + ) + window.close() + + def test_cluster_tool_updates_main_project_folder_refs_from_child( qapp, tmp_path, @@ -4252,6 +4678,148 @@ def fake_cluster_window(*args, **kwargs): window.close() +def test_project_writing_tool_launches_reuse_existing_instances( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + launches = { + "mdtrajectory": 0, + "xyz2pdb": 0, + "cluster": 0, + "representativefinder": 0, + } + + class FakeToolWindow(QWidget): + def __init__(self, tool_key: str, *args, **kwargs): + super().__init__() + del args, kwargs + self.tool_key = tool_key + self.show_count = 0 + self.raise_count = 0 + self.activate_count = 0 + + def show(self): + self.show_count += 1 + + def raise_(self): + self.raise_count += 1 + + def activateWindow(self): + self.activate_count += 1 + + fake_windows = {key: FakeToolWindow(key) for key in launches} + + def fake_launch_mdtrajectory_app(**kwargs): + del kwargs + launches["mdtrajectory"] += 1 + return fake_windows["mdtrajectory"] + + def fake_launch_xyz2pdb_ui(**kwargs): + del kwargs + launches["xyz2pdb"] += 1 + return fake_windows["xyz2pdb"] + + def fake_launch_representativefinder_ui(**kwargs): + 
del kwargs + launches["representativefinder"] += 1 + return fake_windows["representativefinder"] + + def fake_cluster_window(*args, **kwargs): + del args, kwargs + launches["cluster"] += 1 + return fake_windows["cluster"] + + monkeypatch.setattr( + "saxshell.mdtrajectory.ui.main_window.launch_mdtrajectory_app", + fake_launch_mdtrajectory_app, + ) + monkeypatch.setattr( + "saxshell.xyz2pdb.ui.main_window.launch_xyz2pdb_ui", + fake_launch_xyz2pdb_ui, + ) + monkeypatch.setattr( + "saxshell.cluster.ui.main_window.ClusterMainWindow", + fake_cluster_window, + ) + monkeypatch.setattr( + "saxshell.representativefinder.ui.main_window." + "launch_representativefinder_ui", + fake_launch_representativefinder_ui, + ) + + launch_methods = { + "mdtrajectory": window._open_mdtrajectory_tool, + "xyz2pdb": window._open_xyz2pdb_tool, + "cluster": window._open_cluster_tool, + "representativefinder": window._open_representative_finder_tool, + } + + for tool_key, open_tool in launch_methods.items(): + open_tool() + open_tool() + + assert launches[tool_key] == 1 + assert fake_windows[tool_key] in window._child_tool_windows + assert fake_windows[tool_key].show_count >= 1 + assert fake_windows[tool_key].raise_count >= 1 + assert fake_windows[tool_key].activate_count >= 1 + + assert "already open" in window.statusBar().currentMessage().lower() + window.close() + + +def test_pdf_tools_are_mutually_exclusive(qapp, tmp_path): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + + class FakeToolWindow(QWidget): + def __init__(self): + super().__init__() + self.show_count = 0 + self.raise_count = 0 + self.activate_count = 0 + + def show(self): + self.show_count += 1 + + def raise_(self): + self.raise_count += 1 + + def activateWindow(self): + self.activate_count += 1 + + batch_window = FakeToolWindow() + window._track_child_tool_window( + batch_window, + single_instance_key="pdf_batch_queue", + ) + 
window._open_pdfsetup_tool() + assert "pdf batch queue is already open" in ( + window.statusBar().currentMessage().lower() + ) + assert batch_window.show_count >= 1 + assert "pdfsetup" not in window._single_instance_child_tool_windows + + window._forget_child_tool_window(batch_window) + pdf_window = FakeToolWindow() + window._track_child_tool_window( + pdf_window, + single_instance_key="pdfsetup", + ) + window._open_pdf_batch_queue_tool() + assert "pdf calculation is already open" in ( + window.statusBar().currentMessage().lower() + ) + assert pdf_window.show_count >= 1 + assert "pdf_batch_queue" not in window._single_instance_child_tool_windows + window.close() + + def test_main_window_refuses_close_when_child_tool_refuses_close( qapp, tmp_path, @@ -4307,6 +4875,37 @@ def raise_(self): assert launched["instance"] in window._child_tool_windows +def test_experimental_overlay_tool_opens_from_visualization_menu( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + launched: dict[str, object] = {} + + class FakeExperimentalOverlayWindow(QWidget): + def __init__(self): + super().__init__() + launched["instance"] = self + + def fake_launch_experimental_data_overlay_ui(): + return FakeExperimentalOverlayWindow() + + monkeypatch.setattr( + "saxshell.saxs.ui.experimental_overlay_window." 
+ "launch_experimental_data_overlay_ui", + fake_launch_experimental_data_overlay_ui, + ) + + window._open_experimental_data_overlay_tool() + + assert launched["instance"] in window._child_tool_windows + assert "experimental data overlay" in window.statusBar().currentMessage() + window.close() + + def test_cluster_dynamics_tool_uses_active_project_dir( qapp, tmp_path, monkeypatch ): @@ -4431,6 +5030,122 @@ def raise_(self): window.close() +def test_cluster_dynamics_cli_setup_uses_active_project_dir( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + frames_dir = tmp_path / "splitxyz_f0fs" + frames_dir.mkdir() + energy_file = tmp_path / "traj.ener" + energy_file.write_text( + "# step time kinetic temperature potential\n" + "1 0.0 1.0 300.0 -10.0\n", + encoding="utf-8", + ) + window.current_settings.frames_dir = str(frames_dir) + window.current_settings.energy_file = str(energy_file) + window.project_manager.save_project(window.current_settings) + launched: dict[str, object] = {} + + class FakeRunFileWindow: + pass + + def fake_launch_clusterdynamics_run_file_ui( + *, + initial_project_dir=None, + initial_frames_dir=None, + initial_energy_file=None, + ): + launched["project_dir"] = initial_project_dir + launched["frames_dir"] = initial_frames_dir + launched["energy_file"] = initial_energy_file + launched["instance"] = FakeRunFileWindow() + return launched["instance"] + + monkeypatch.setattr( + "saxshell.clusterdynamics.ui.run_file_window." 
+ "launch_clusterdynamics_run_file_ui", + fake_launch_clusterdynamics_run_file_ui, + ) + + window._open_clusterdynamics_cli_setup_tool() + + assert launched["frames_dir"] == frames_dir.resolve() + assert launched["energy_file"] == energy_file.resolve() + assert ( + launched["project_dir"] + == Path(window.current_settings.project_dir).resolve() + ) + assert launched["instance"] in window._child_tool_windows + window.close() + + +def test_cluster_dynamics_ml_cli_setup_uses_active_project_dir( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + frames_dir = tmp_path / "splitxyz_f0fs" + frames_dir.mkdir() + clusters_dir = tmp_path / "clusters" + clusters_dir.mkdir() + energy_file = tmp_path / "traj.ener" + energy_file.write_text( + "# step time kinetic temperature potential\n" + "1 0.0 1.0 300.0 -10.0\n", + encoding="utf-8", + ) + window.current_settings.frames_dir = str(frames_dir) + window.current_settings.clusters_dir = str(clusters_dir) + window.current_settings.energy_file = str(energy_file) + window.project_manager.save_project(window.current_settings) + launched: dict[str, object] = {} + + class FakeRunFileWindow: + pass + + def fake_launch_clusterdynamicsml_run_file_ui( + *, + initial_project_dir=None, + initial_frames_dir=None, + initial_energy_file=None, + initial_clusters_dir=None, + initial_experimental_data_file=None, + ): + launched["project_dir"] = initial_project_dir + launched["frames_dir"] = initial_frames_dir + launched["energy_file"] = initial_energy_file + launched["clusters_dir"] = initial_clusters_dir + launched["experimental_data_file"] = initial_experimental_data_file + launched["instance"] = FakeRunFileWindow() + return launched["instance"] + + monkeypatch.setattr( + "saxshell.clusterdynamicsml.ui.run_file_window." 
+ "launch_clusterdynamicsml_run_file_ui", + fake_launch_clusterdynamicsml_run_file_ui, + ) + + window._open_clusterdynamicsml_cli_setup_tool() + + assert launched["frames_dir"] == frames_dir.resolve() + assert launched["energy_file"] == energy_file.resolve() + assert launched["clusters_dir"] == clusters_dir.resolve() + assert ( + launched["experimental_data_file"] + == window.current_settings.resolved_experimental_data_path + ) + assert ( + launched["project_dir"] + == Path(window.current_settings.project_dir).resolve() + ) + assert launched["instance"] in window._child_tool_windows + window.close() + + def test_project_setup_predict_structures_button_opens_cluster_dynamics_ml_tool( qapp, tmp_path, @@ -13254,6 +13969,105 @@ def test_load_experimental_data_file_detects_three_column_headers( assert np.allclose(summary.errors, [0.1, 0.2]) +def test_experimental_overlay_window_loads_multiple_header_styles( + qapp, + tmp_path, +): + del qapp + plain_path = tmp_path / "plain_trace.dat" + plain_path.write_text( + "# q intensity error\n" + "0.05 10.0 0.5\n" + "0.10 9.0 0.4\n" + "0.20 9.5 0.3\n", + encoding="utf-8", + ) + header_path = tmp_path / "instrument_export.txt" + header_path.write_text( + "Detector export\n" + "q_value intensity_value sigma_value\n" + "0.04 100.0 2.0\n" + "0.10 105.0 2.2\n" + "0.30 110.0 2.5\n", + encoding="utf-8", + ) + + window = ExperimentalDataOverlayWindow() + added = window.add_data_files([plain_path, header_path]) + + assert added == 2 + assert len(window.traces) == 2 + assert window.trace_table.rowCount() == 2 + assert window.q_min_spin.value() == pytest.approx(0.04) + assert window.q_max_spin.value() == pytest.approx(0.30) + assert "q_value" in window.trace_table.item(1, 6).text() + assert "intensity_value" in window.trace_table.item(1, 6).text() + assert len(window._left_axis.get_lines()) == 2 + assert window.log_x_axis_button.isChecked() is True + assert window.log_y_axis_button.isChecked() is True + assert 
window.log_x_axis_button.text() == "Log X: On" + assert window.log_y_axis_button.text() == "Log Y: On" + assert window._left_axis.get_xscale() == "log" + assert window._left_axis.get_yscale() == "log" + + axis_combo = window.trace_table.cellWidget(1, window.AXIS_COLUMN) + assert isinstance(axis_combo, QComboBox) + axis_combo.setCurrentIndex(axis_combo.findData("right")) + + assert window.traces[1].axis == "right" + assert window._right_axis is not None + assert len(window._right_axis.get_lines()) == 1 + assert window._right_axis.get_xscale() == "log" + assert window._right_axis.get_yscale() == "log" + + window.log_x_axis_button.setChecked(False) + window.log_y_axis_button.setChecked(False) + + assert window.log_x_axis_button.text() == "Log X: Off" + assert window.log_y_axis_button.text() == "Log Y: Off" + assert window._left_axis.get_xscale() == "linear" + assert window._left_axis.get_yscale() == "linear" + assert window._right_axis is not None + assert window._right_axis.get_xscale() == "linear" + assert window._right_axis.get_yscale() == "linear" + + window.align_y_axes_checkbox.setChecked(False) + window.rescale_axes_button.click() + + assert window.align_y_axes_checkbox.isChecked() is True + + right_ylim_before_q_range = tuple(window._right_axis.get_ylim()) + + window.full_q_range_checkbox.setChecked(False) + window.q_min_spin.setValue(0.08) + window.q_max_spin.setValue(0.20) + + assert window._left_axis.get_xlim() == pytest.approx((0.08, 0.20)) + assert window._right_axis is not None + assert not np.allclose( + window._right_axis.get_ylim(), + right_ylim_before_q_range, + ) + + window._set_trace_color(0, "#123456") + + assert window.traces[0].color == "#123456" + assert window.trace_table.item(0, window.COLOR_COLUMN).text() == "#123456" + assert window._left_axis.get_lines()[0].get_color() == "#123456" + + visible_item = window.trace_table.item(0, window.SHOW_COLUMN) + visible_item.setCheckState(Qt.CheckState.Unchecked) + + assert window.traces[0].visible is 
False + + window.trace_table.selectRow(0) + window._remove_selected_traces() + + assert len(window.traces) == 1 + assert window.trace_table.rowCount() == 1 + window.close() + + def test_experimental_data_header_dialog_allows_manual_column_selection( qapp, tmp_path ): diff --git a/tests/test_xyz2pdb_cli.py b/tests/test_xyz2pdb_cli.py index 8904391..752ae05 100644 --- a/tests/test_xyz2pdb_cli.py +++ b/tests/test_xyz2pdb_cli.py @@ -3,10 +3,22 @@ import json from pathlib import Path +from saxshell.saxs.project_manager import SAXSProjectManager from saxshell.saxshell import main as saxshell_main from saxshell.structure import PDBAtom, PDBStructure from saxshell.xyz2pdb import XYZToPDBWorkflow from saxshell.xyz2pdb.cli import main as xyz2pdb_main +from saxshell.xyz2pdb.mapping_workflow import ( + FreeAtomMappingInput, + MoleculeMappingInput, +) +from saxshell.xyz2pdb.run_config import ( + build_xyz2pdb_run_config, + default_xyz2pdb_run_file_path, + load_xyz2pdb_run_config, + run_xyz2pdb_run_config, + save_xyz2pdb_run_config, +) def _write_reference_pdb(path: Path) -> None: @@ -142,6 +154,118 @@ def test_xyz2pdb_cli_export_runs_complete_headless_workflow( assert "Files written: 2" in captured.out +def test_xyz2pdb_run_config_round_trips_project_relative_paths(tmp_path): + refs_dir = tmp_path / "references" + frames_dir = tmp_path / "frames" + output_dir = tmp_path / "pdb_frames" + config = build_xyz2pdb_run_config( + project_dir=tmp_path, + input_path=frames_dir, + output_dir=output_dir, + reference_library_dir=refs_dir, + molecule_inputs=( + MoleculeMappingInput(reference_name="pbi", residue_name="PBI"), + ), + free_atom_inputs=( + FreeAtomMappingInput(element="O", residue_name="SOL"), + ), + pbc_params={"a": 20.0, "space_group": "P 1"}, + ) + run_file = default_xyz2pdb_run_file_path(tmp_path) + + save_xyz2pdb_run_config(run_file, config) + loaded = load_xyz2pdb_run_config(run_file) + + assert loaded.input_path == "frames" + assert loaded.output_dir == "pdb_frames" + 
assert loaded.reference_library_dir == "references" + assert loaded.molecule_inputs[0].reference_name == "pbi" + assert loaded.free_atom_inputs[0].residue_name == "SOL" + assert loaded.pbc_params == {"a": 20.0, "space_group": "P 1"} + + +def test_xyz2pdb_project_run_updates_project_pdb_frames_dir(tmp_path): + manager = SAXSProjectManager() + project_dir = tmp_path / "project" + manager.create_project(project_dir) + + refs_dir = project_dir / "references" + refs_dir.mkdir() + _write_reference_pdb(refs_dir / "pbi.pdb") + + frames_dir = project_dir / "frames" + frames_dir.mkdir() + _write_xyz(frames_dir / "frame_0000.xyz", i_x=1.0, oxygen_x=2.0) + _write_xyz(frames_dir / "frame_0001.xyz", i_x=1.1, oxygen_x=2.1) + + output_dir = project_dir / "pdb_frames" + config = build_xyz2pdb_run_config( + project_dir=project_dir, + input_path=frames_dir, + output_dir=output_dir, + reference_library_dir=refs_dir, + molecule_inputs=( + MoleculeMappingInput(reference_name="pbi", residue_name="PBI"), + ), + free_atom_inputs=( + FreeAtomMappingInput(element="O", residue_name="SOL"), + ), + ) + run_file = default_xyz2pdb_run_file_path(project_dir) + save_xyz2pdb_run_config(run_file, config) + + summary = run_xyz2pdb_run_config( + project_dir, + load_xyz2pdb_run_config(run_file), + run_file_path=run_file, + ) + + saved_settings = manager.load_project(project_dir) + assert summary.written_count == 2 + assert saved_settings.resolved_pdb_frames_dir == output_dir.resolve() + assert saved_settings.pdb_frames_dir_snapshot is not None + assert (output_dir / "frame_0000.pdb").is_file() + + +def test_xyz2pdb_cli_project_run_uses_project_default_run_file( + tmp_path, + capsys, +): + manager = SAXSProjectManager() + project_dir = tmp_path / "project" + manager.create_project(project_dir) + + refs_dir = project_dir / "references" + refs_dir.mkdir() + _write_reference_pdb(refs_dir / "pbi.pdb") + frames_dir = project_dir / "frames" + frames_dir.mkdir() + _write_xyz(frames_dir / "frame_0000.xyz", 
i_x=1.0, oxygen_x=2.0) + + save_xyz2pdb_run_config( + default_xyz2pdb_run_file_path(project_dir), + build_xyz2pdb_run_config( + project_dir=project_dir, + input_path=frames_dir, + output_dir=project_dir / "pdb_frames", + reference_library_dir=refs_dir, + molecule_inputs=( + MoleculeMappingInput(reference_name="pbi", residue_name="PBI"), + ), + free_atom_inputs=( + FreeAtomMappingInput(element="O", residue_name="SOL"), + ), + ), + ) + + exit_code = xyz2pdb_main(["run", str(project_dir)]) + output = capsys.readouterr().out + + assert exit_code == 0 + assert "XYZ to PDB project run complete." in output + assert "Files written: 1" in output + + def test_xyz2pdb_reference_cli_and_saxshell_forwarding(tmp_path, capsys): refs_dir = tmp_path / "references" refs_dir.mkdir() diff --git a/tests/test_xyz2pdb_ui.py b/tests/test_xyz2pdb_ui.py index d7c0c85..1cb53ed 100644 --- a/tests/test_xyz2pdb_ui.py +++ b/tests/test_xyz2pdb_ui.py @@ -17,7 +17,18 @@ from saxshell.saxs.project_manager import SAXSProjectManager from saxshell.structure import PDBAtom, PDBStructure +from saxshell.xyz2pdb.mapping_workflow import ( + FreeAtomMappingInput, + MoleculeMappingInput, +) +from saxshell.xyz2pdb.ui.batch_queue_window import ( + XYZToPDBBatchItem, + XYZToPDBBatchQueueWindow, + XYZToPDBBatchResult, + XYZToPDBBatchWorker, +) from saxshell.xyz2pdb.ui.main_window import XYZToPDBMainWindow +from saxshell.xyz2pdb.ui.run_file_window import XYZToPDBRunFileWindow from saxshell.xyz2pdb.workflow import ( XYZToPDBAssertionResidueSummary, XYZToPDBAssertionResult, @@ -294,6 +305,198 @@ def test_xyz2pdb_export_registers_pdb_folder_with_project(qapp, tmp_path): window.close() +def test_xyz2pdb_batch_queue_prefills_project_xyz_reference_and_elements( + qapp, + tmp_path, +): + del qapp + manager = SAXSProjectManager() + project_dir = tmp_path / "saxs_project" + settings = manager.create_project(project_dir) + refs_dir = tmp_path / "references" + refs_dir.mkdir() + _write_reference_pdb(refs_dir / "pbi.pdb") + 
frames_dir = tmp_path / "splitxyz_f0fs" + frames_dir.mkdir() + _write_xyz(frames_dir / "frame_0000.xyz", i_x=1.0, oxygen_x=2.0) + settings.frames_dir = str(frames_dir) + manager.save_project(settings) + + window = XYZToPDBBatchQueueWindow( + initial_project_dir=project_dir, + reference_library_dir=refs_dir, + ) + + assert window.queue_list.count() == 1 + list_item = window.queue_list.item(0) + item_id = str(list_item.data(Qt.ItemDataRole.UserRole)) + widget = window._widgets_by_id[item_id] + assert widget.input_path_edit.text() == str(frames_dir.resolve()) + assert { + widget.free_element_combo.itemData(index) + for index in range(widget.free_element_combo.count()) + } == {"I", "O", "Pb"} + assert { + widget.reference_combo.itemData(index) + for index in range(widget.reference_combo.count()) + } == {"pbi"} + assert "Elements: I x1, O x1, Pb x1" in ( + widget.analysis_summary_label.text() + ) + window.close() + + +def test_xyz2pdb_batch_worker_exports_and_registers_pdb_folder( + qapp, + tmp_path, +): + del qapp + manager = SAXSProjectManager() + project_dir = tmp_path / "saxs_project" + settings = manager.create_project(project_dir) + refs_dir = tmp_path / "references" + refs_dir.mkdir() + _write_reference_pdb(refs_dir / "pbi.pdb") + frames_dir = tmp_path / "splitxyz_f0fs" + frames_dir.mkdir() + _write_xyz(frames_dir / "frame_0000.xyz", i_x=1.0, oxygen_x=2.0) + settings.frames_dir = str(frames_dir) + manager.save_project(settings) + + worker = XYZToPDBBatchWorker( + [ + ( + "job-1", + XYZToPDBBatchItem( + item_id="job-1", + project_dir=project_dir, + input_path=frames_dir, + reference_library_dir=refs_dir, + molecule_inputs=( + MoleculeMappingInput( + reference_name="pbi", + residue_name="PBI", + ), + ), + free_atom_inputs=( + FreeAtomMappingInput( + element="O", + residue_name="SOL", + ), + ), + ).to_job(), + ) + ] + ) + failures: list[tuple[str, str]] = [] + finished_items: list[tuple[str, XYZToPDBBatchResult]] = [] + finished_batches: 
list[list[XYZToPDBBatchResult]] = [] + worker.failed.connect( + lambda item_id, message: failures.append((item_id, message)) + ) + worker.item_finished.connect( + lambda item_id, result: finished_items.append((item_id, result)) + ) + worker.finished.connect(finished_batches.append) + + worker.run() + + assert failures == [] + assert len(finished_items) == 1 + item_id, result = finished_items[0] + assert item_id == "job-1" + assert result.output_dir.name == "xyz2pdb_splitxyz_f0fs" + assert result.written_count == 1 + assert (result.output_dir / "frame_0000.pdb").is_file() + saved_settings = manager.load_project(project_dir) + assert saved_settings.resolved_frames_dir == frames_dir.resolve() + assert saved_settings.resolved_pdb_frames_dir == result.output_dir + assert saved_settings.pdb_frames_dir_snapshot is not None + assert finished_batches == [[result]] + + +def test_xyz2pdb_batch_queue_emits_registered_pdb_folder(qapp, tmp_path): + del qapp + project_dir = tmp_path / "saxs_project" + SAXSProjectManager().create_project(project_dir) + output_dir = tmp_path / "xyz2pdb_splitxyz_f0fs" + output_dir.mkdir() + input_dir = tmp_path / "splitxyz_f0fs" + input_dir.mkdir() + + window = XYZToPDBBatchQueueWindow() + widget = window.add_queue_item( + XYZToPDBBatchItem( + item_id="job-1", + project_dir=project_dir, + input_path=input_dir, + ) + ) + updates = [] + window.project_paths_registered.connect(updates.append) + + window._on_item_finished( + widget.item_id, + XYZToPDBBatchResult( + project_dir=project_dir.resolve(), + input_path=input_dir.resolve(), + output_dir=output_dir.resolve(), + written_count=1, + ), + ) + + assert updates == [ + { + "project_dir": project_dir.resolve(), + "pdb_frames_dir": output_dir.resolve(), + } + ] + assert widget.status_label.text() == "Complete" + window.close() + + +def test_xyz2pdb_run_file_window_builds_project_config(qapp, tmp_path): + del qapp + project_dir = tmp_path / "project" + SAXSProjectManager().create_project(project_dir) + + 
refs_dir = project_dir / "references" + refs_dir.mkdir() + _write_reference_pdb(refs_dir / "pbi.pdb") + + frames_dir = project_dir / "frames" + frames_dir.mkdir() + _write_xyz(frames_dir / "frame_0000.xyz", i_x=1.0, oxygen_x=2.0) + + window = XYZToPDBRunFileWindow( + initial_project_dir=project_dir, + initial_input_path=frames_dir, + ) + window.reference_panel.library_dir_edit.setText(str(refs_dir)) + window.refresh_reference_library() + window.mapping_panel._molecule_inputs = [ + MoleculeMappingInput(reference_name="pbi", residue_name="PBI") + ] + window.mapping_panel._refresh_molecule_table() + window.mapping_panel.set_available_elements(("O", "Pb", "I")) + window.mapping_panel.free_element_combo.setCurrentText("O") + window.mapping_panel.free_residue_edit.setText("SOL") + window.mapping_panel._add_free_atom() + window.output_dir_edit.setText(str(project_dir / "pdb_frames")) + window.pbc_params_edit.setPlainText('{"a": 20.0, "b": 21.0, "c": 22.0}') + + config = window._current_config(project_dir) + + assert config.input_path == "frames" + assert config.output_dir == "pdb_frames" + assert config.reference_library_dir == "references" + assert config.molecule_inputs[0].reference_name == "pbi" + assert config.free_atom_inputs[0].element == "O" + assert config.pbc_params["a"] == 20.0 + assert "xyz2pdb run" in window.command_box.toPlainText() + window.close() + + def test_main_window_native_mapping_flow_estimates_and_reports_export_progress_without_json( qapp, tmp_path,