
Commit c032d76

refactor: modularize training pipeline and stabilize model initialization
- Refactor train.py by extracting logic into training submodules:
  - arguments.py: CLI parameter definitions
  - builder.py: model and criterion setup
  - checkpoint.py: robust saving and loading logic
- Fix the learning-rate plateau by replacing disjoint schedulers with SequentialLR, which properly chains the linear warmup and cosine decay phases.
- Simplify the SpatialTranscriptFormer architecture:
  - Remove the redundant log_temperature parameter to reduce gradient variance.
  - L1-normalize the MSigDB pathway weight initialization to prevent exponential prediction explosion at startup.
- Enhance load_checkpoint with robust error handling for EOFError (corrupted files) and ValueError (architecture/optimizer mismatches) to ensure graceful fallbacks.
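The scheduler change described above — chaining a linear warmup into cosine decay — can be sketched with PyTorch's stock `SequentialLR`. The base LR, warmup length, and epoch count below are placeholders, not the values train.py actually uses:

```python
import torch
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

model = torch.nn.Linear(4, 4)  # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

warmup_epochs, total_epochs = 10, 100  # placeholder values
warmup = LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs)
decay = CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs)

# SequentialLR hands off from warmup to decay at the milestone epoch,
# avoiding the LR discontinuity that two disjoint schedulers can produce.
scheduler = SequentialLR(
    optimizer, schedulers=[warmup, decay], milestones=[warmup_epochs]
)

lrs = []
for _ in range(total_epochs):
    optimizer.step()  # training step elided
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])
```

Stepping once per epoch, the LR ramps up over the first `warmup_epochs` steps and then decays smoothly to near zero, with no jump at the handoff.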
1 parent 63d0912 commit c032d76

File tree

16 files changed (+1225, −1499 lines)

README.md

Lines changed: 4 additions & 4 deletions
````diff
@@ -62,11 +62,11 @@ stf-download --species "Homo sapiens" --local_dir hest_data
 We provide presets for baseline models and scaled versions of the SpatialTranscriptFormer.
 
 ```bash
-# Recommended: Run the Interaction model with 4 transformer layers
-python scripts/run_preset.py --preset stf_interaction_l4
+# Recommended: Run the Interaction model (Small)
+python scripts/run_preset.py --preset stf_small
 
-# Run the lightweight 2-layer version
-python scripts/run_preset.py --preset stf_interaction_l2
+# Run the lightweight Tiny version
+python scripts/run_preset.py --preset stf_tiny
 
 # Run baselines
 python scripts/run_preset.py --preset he2rna_baseline
````

docs/MODELS.md

Lines changed: 3 additions & 7 deletions
````diff
@@ -34,12 +34,8 @@ The SpatialTranscriptFormer models the **interaction between biological pathways
 By default, the model operates in **Full Interaction** mode where all four information flows are active. Users can selectively disable any combination using the `--interactions` flag to explore architectural variants:
 
 ```bash
-# Default: Full Interaction (all quadrants enabled)
---interactions p2p p2h h2p h2h
-
-# Pathway Bottleneck: block H↔H to force all inter-patch
-# communication through the pathway bottleneck
---interactions p2p p2h h2p
+# Default: Small Interaction (CTransPath, 4 layers)
+python scripts/run_preset.py --preset stf_small
 ```
 
 > [!TIP]
@@ -53,7 +49,7 @@ Three additional design principles support these interactions:
 
 - **Biological Initialisation** — The gene reconstruction weights are initialised from MSigDB Hallmark gene sets, providing a biologically-grounded starting point that the model refines during training.
 
-### 2.2 Spatial Learning
+## 2.2 Spatial Learning
 
 The spatial relationships of gene expression are central to this model. It is not sufficient to predict correct expression magnitudes at each spot independently — the model must capture **where** on the tissue pathways are active and how that spatial pattern varies across the slide. Two mechanisms enforce this:
````
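The "Biological Initialisation" principle above pairs with the commit's L1-normalisation fix. A minimal sketch of the idea — the binary membership masks and their 5% density are placeholders, since the real initialisation reads Hallmark gene sets from MSigDB:

```python
import torch

torch.manual_seed(0)

# Placeholder MSigDB membership masks: a binary (pathways x genes) matrix.
# The real code derives this from the Hallmark gene sets; 5% density is made up.
num_pathways, num_genes = 50, 1000
membership = (torch.rand(num_pathways, num_genes) < 0.05).float()

# L1-normalise each pathway row so its weights sum to 1. Without this,
# pathways with many member genes start with a large summed weight, which
# the commit message blames for exponential prediction blow-ups at startup.
row_sums = membership.sum(dim=1, keepdim=True).clamp(min=1.0)
weights = membership / row_sums
```

After normalisation every pathway contributes a bounded total weight regardless of how many genes it contains, so the initial forward pass stays on a sane scale.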

docs/TRAINING_GUIDE.md

Lines changed: 14 additions & 15 deletions
````diff
@@ -147,23 +147,22 @@ python -m spatial_transcript_former.train \
 
 > **Note**: Without `--pathway-init`, the model disables the `AuxiliaryPathwayLoss` and relies entirely on the main reconstruction objectives and the L1 sparsity penalty. (I am yet to obtain results with this method)...
-### Robust Counting: ZINB + Auxiliary Loss
+### Recommended: Using Presets
 
-For raw count data with high sparsity, using the ZINB distribution and auxiliary pathway supervision is recommended.
+For most cases, it is recommended to use the provided presets:
 
 ```bash
-python -m spatial_transcript_former.train \
-  --data-dir A:\hest_data \
-  --model interaction \
-  --backbone ctranspath \
-  --pathway-init \
-  --loss zinb \
-  --pathway-loss-weight 0.5 \
-  --lr 5e-5 \
-  --batch-size 4 \
-  --whole-slide \
-  --precomputed \
-  --epochs 200
+# Tiny (2 layers, 256 dim)
+python scripts/run_preset.py --preset stf_tiny
+
+# Small (4 layers, 384 dim) - Recommended
+python scripts/run_preset.py --preset stf_small
+
+# Medium (6 layers, 512 dim)
+python scripts/run_preset.py --preset stf_medium
+
+# Large (12 layers, 768 dim)
+python scripts/run_preset.py --preset stf_large
 ```
 
 ### Choosing Interaction Modes
@@ -201,7 +200,7 @@ Submit with:
 sbatch hpc/array_train.slurm
 ```
 
-### Collecting Results
+### Collecting Results (Currently broken!)
 
 After experiments complete, aggregate all `results_summary.json` files into a comparison table:
````

scripts/inspect_outputs.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import torch
2+
import json
3+
import os
4+
import argparse
5+
import numpy as np
6+
from spatial_transcript_former.models import SpatialTranscriptFormer
7+
from spatial_transcript_former.data.utils import get_sample_ids, setup_dataloaders
8+
9+
10+
class Args:
11+
pass
12+
13+
14+
args = Args()
15+
args.data_dir = "A:\\hest_data"
16+
args.epochs = 2000
17+
args.output_dir = "runs/stf_tiny"
18+
args.model = "interaction"
19+
args.backbone = "ctranspath"
20+
args.precomputed = True
21+
args.whole_slide = True
22+
args.pathway_init = True
23+
args.use_amp = True
24+
args.log_transform = True
25+
args.loss = "mse_pcc"
26+
args.resume = True
27+
args.n_layers = 2
28+
args.token_dim = 256
29+
args.n_heads = 4
30+
args.batch_size = 1
31+
args.vis_sample = "TENX29"
32+
args.max_samples = 1
33+
args.organ = None
34+
args.num_genes = 1000
35+
args.n_neighbors = 6
36+
args.use_global_context = False
37+
args.global_context_size = 0
38+
args.augment = False
39+
args.feature_dir = None
40+
args.seed = 42
41+
args.warmup_epochs = 10
42+
args.sparsity_lambda = 0.0
43+
44+
device = "cuda" if torch.cuda.is_available() else "cpu"
45+
46+
genes_path = "global_genes.json"
47+
with open(genes_path, "r") as f:
48+
gene_list = json.load(f)[:1000]
49+
args.num_genes = len(gene_list)
50+
51+
final_ids = get_sample_ids(
52+
args.data_dir, precomputed=args.precomputed, backbone=args.backbone, max_samples=1
53+
)
54+
train_loader, _ = setup_dataloaders(args, final_ids, [])
55+
56+
model = SpatialTranscriptFormer(
57+
num_genes=args.num_genes,
58+
backbone_name=args.backbone,
59+
pretrained=False,
60+
token_dim=args.token_dim,
61+
n_heads=args.n_heads,
62+
n_layers=args.n_layers,
63+
num_pathways=50,
64+
use_spatial_pe=True,
65+
output_mode="counts",
66+
)
67+
68+
ckpt_path = os.path.join(args.output_dir, "latest_model_interaction.pth")
69+
if os.path.exists(ckpt_path):
70+
print("Loading", ckpt_path)
71+
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
72+
model.load_state_dict(ckpt["model_state_dict"], strict=False)
73+
else:
74+
print("No ckpt found!")
75+
76+
model.to(device)
77+
model.eval()
78+
79+
with torch.no_grad():
80+
for batch in train_loader:
81+
feats, genes, coords, mask = [x.to(device) for x in batch]
82+
out = model(feats, rel_coords=coords, mask=mask, return_dense=True)
83+
preds = out
84+
85+
preds = torch.expm1(preds) if args.log_transform else preds
86+
targets = torch.expm1(genes) if args.log_transform else genes
87+
88+
patch_idx = None
89+
for i in range(mask.shape[1]):
90+
if not mask[0, i]:
91+
patch_idx = i
92+
break
93+
94+
with open(
95+
"C:/Users/wispy/.gemini/antigravity/brain/6a31ec6d-2f34-4f97-96b8-e437c2640219/model_output_sample.md",
96+
"w",
97+
) as f:
98+
f.write("# Model Output Sample (stf_tiny with simplifications)\n\n")
99+
if patch_idx is not None:
100+
f.write("### Target vs Prediction for a Single Valid Patch\n")
101+
f.write("Showing the first 20 genes (absolute expression counts).\n\n")
102+
103+
f.write("| Gene Index | Target Count (True) | Predicted Count |\n")
104+
f.write("|------------|----------------------|-----------------|\n")
105+
106+
t_vals = targets[0, patch_idx, :20].cpu().numpy()
107+
p_vals = preds[0, patch_idx, :20].cpu().numpy()
108+
109+
for i in range(20):
110+
f.write(f"| {i} | {t_vals[i]:.2f} | {p_vals[i]:.2f} |\n")
111+
112+
f.write("\n### Summary Statistics Across All Patches in Batch\n")
113+
f.write(f"- Target Mean: {targets[~mask].mean().item():.4f}\n")
114+
f.write(f"- Target Max: {targets[~mask].max().item():.4f}\n")
115+
f.write(f"- Pred Mean: {preds[~mask].mean().item():.4f}\n")
116+
f.write(f"- Pred Max: {preds[~mask].max().item():.4f}\n")
117+
f.write(f"- Pred Min: {preds[~mask].min().item():.4f}\n")
118+
else:
119+
f.write("No valid patches found in sample.\n")
120+
121+
print("Sample logic written to artifact.")
122+
break
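inspect_outputs.py undoes the training-time `log1p` transform with `torch.expm1` before reporting counts. The round trip is exact up to floating point:

```python
import torch

counts = torch.tensor([0.0, 1.0, 9.0, 99.0])
logged = torch.log1p(counts)     # what a --log-transform pipeline would store
recovered = torch.expm1(logged)  # the inverse applied before reporting counts
assert torch.allclose(recovered, counts)
```

Using `log1p`/`expm1` instead of `log`/`exp` keeps zero counts (the overwhelming majority in sparse spatial data) mapped exactly to zero.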

scripts/migrate_logs_to_sqlite.py

Lines changed: 29 additions & 0 deletions
```python
import os
import pandas as pd
import sqlite3
import argparse


def migrate_csv_to_sqlite(run_dir):
    csv_path = os.path.join(run_dir, "training_log.csv")
    db_path = os.path.join(run_dir, "training_logs.sqlite")

    if not os.path.exists(csv_path):
        print(f"No CSV found at {csv_path}")
        return

    print(f"Migrating {csv_path} to {db_path}...")
    df = pd.read_csv(csv_path)

    with sqlite3.connect(db_path) as conn:
        df.to_sql("metrics", conn, if_exists="replace", index=False)
    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--run-dir", type=str, required=True, help="Path to run directory"
    )
    args = parser.parse_args()
    migrate_csv_to_sqlite(args.run_dir)
```
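Once migrated, the `metrics` table can be queried directly with the stdlib `sqlite3` module. The column names below (`epoch`, `train_loss`, `val_loss`) are assumptions — they depend on whatever header training_log.csv actually carries:

```python
import sqlite3

# In-memory stand-in for training_logs.sqlite; the epoch/train_loss/val_loss
# columns are assumed, mirroring a typical training log.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE metrics (epoch INTEGER, train_loss REAL, val_loss REAL)")
conn.executemany(
    "INSERT INTO metrics VALUES (?, ?, ?)",
    [(1, 0.90, 0.95), (2, 0.70, 0.80), (3, 0.60, 0.85)],
)

# The kind of question that is awkward against a growing CSV but trivial in
# SQL: which epoch had the best validation loss?
best_epoch, best_val = conn.execute(
    "SELECT epoch, val_loss FROM metrics ORDER BY val_loss LIMIT 1"
).fetchone()
print(best_epoch, best_val)
```

This is the payoff of the migration: ad-hoc aggregation without re-parsing the whole CSV on every query.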

scripts/predict_sample.py

Lines changed: 112 additions & 0 deletions
```python
#!/usr/bin/env python
import argparse
import os
import torch
import json
from spatial_transcript_former.visualization import run_inference_plot


# Dummy class to hold loaded arguments
class RunArgs:
    def __init__(self, **entries):
        self.__dict__.update(entries)


def parse_args():
    parser = argparse.ArgumentParser("Predict sample pathways")
    parser.add_argument(
        "--sample-id",
        required=True,
        type=str,
        help="Sample ID to run inference on (e.g. TENX156)",
    )
    parser.add_argument(
        "--run-dir",
        required=True,
        type=str,
        help="Directory containing model weights and args.json",
    )
    parser.add_argument(
        "--output-dir", type=str, default=".", help="Where to save the output plot"
    )
    parser.add_argument(
        "--epoch", type=int, default=0, help="Epoch number to label the plot with"
    )
    return parser.parse_args()


def main():
    cli_args = parse_args()

    # Load args from run_dir
    args_path = os.path.join(cli_args.run_dir, "results_summary.json")
    if not os.path.exists(args_path):
        raise FileNotFoundError(f"Missing {args_path}")

    with open(args_path, "r") as f:
        summary_dict = json.load(f)
        run_args_dict = summary_dict.get("config", {})

    run_args = RunArgs(**run_args_dict)
    run_args.output_dir = cli_args.output_dir
    run_args.run_dir = cli_args.run_dir

    # Optional arguments that might be missing from older args.json
    if not hasattr(run_args, "log_transform"):
        run_args.log_transform = False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Re-initialize the model based on run_args
    if run_args.model == "baseline":
        from spatial_transcript_former.models import SpatialTranscriptFormer

        model = SpatialTranscriptFormer(
            backbone=run_args.backbone,
            num_genes=run_args.num_genes,
            dropout=run_args.dropout,
            n_neighbors=run_args.n_neighbors,
        )
    elif run_args.model == "interaction":
        from spatial_transcript_former.models import SpatialTranscriptFormer

        model = SpatialTranscriptFormer(
            num_genes=run_args.num_genes,
            backbone_name=run_args.backbone,
            pretrained=run_args.pretrained,
            token_dim=getattr(run_args, "token_dim", 384),
            n_heads=getattr(run_args, "n_heads", 6),
            n_layers=getattr(run_args, "n_layers", 4),
            num_pathways=getattr(run_args, "num_pathways", 0),
            use_spatial_pe=getattr(run_args, "use_spatial_pe", True),
            output_mode="zinb" if getattr(run_args, "loss", "") == "zinb" else "counts",
            interactions=getattr(run_args, "interactions", None),
        )
    else:
        raise ValueError(f"Unknown model type: {run_args.model}")

    model.to(device)

    # Note: we explicitly load the *best* model if it exists, otherwise the latest
    ckpt_path = os.path.join(cli_args.run_dir, f"best_model_{run_args.model}.pth")
    if not os.path.exists(ckpt_path):
        ckpt_path = os.path.join(cli_args.run_dir, f"latest_model_{run_args.model}.pth")

    if os.path.exists(ckpt_path):
        print(f"Loading checkpoint from {ckpt_path}...")
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
        else:
            model.load_state_dict(checkpoint)
    else:
        print(
            f"Warning: No checkpoint found in {cli_args.run_dir}. Using untrained model."
        )

    print(f"Running inference for sample {cli_args.sample_id}...")
    run_inference_plot(model, run_args, cli_args.sample_id, cli_args.epoch, device)


if __name__ == "__main__":
    main()
```
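The commit message describes hardening `load_checkpoint` against corrupted files (EOFError) and architecture mismatches (ValueError), of which predict_sample.py's plain load is the happy path. A minimal sketch of that fallback pattern — `load_checkpoint_safely` is a hypothetical name, and the real logic lives in the new training/checkpoint.py:

```python
import os
import torch


def load_checkpoint_safely(model, ckpt_path, device="cpu"):
    """Hypothetical sketch of graceful checkpoint fallback; the real
    load_checkpoint in training/checkpoint.py may differ in signature."""
    if not os.path.exists(ckpt_path):
        print(f"No checkpoint at {ckpt_path}; starting from scratch.")
        return False
    try:
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        # Accept both wrapped ({"model_state_dict": ...}) and bare state dicts
        state = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
        model.load_state_dict(state, strict=False)
        return True
    except EOFError:
        print(f"{ckpt_path} looks truncated or corrupted; ignoring it.")
    except (ValueError, RuntimeError) as err:
        print(f"{ckpt_path} does not match this model: {err}")
    return False
```

Returning a boolean instead of raising lets the caller fall back to an untrained model, mirroring the warning branch in predict_sample.py.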
