Skip to content

feat(gsp-rl): JEPA-MVP encoder + Phase 4 metric/config additions#4

Merged
jdbloom merged 5 commits into
mainfrom
feat/jepa-mvp
May 5, 2026
Merged

feat(gsp-rl): JEPA-MVP encoder + Phase 4 metric/config additions#4
jdbloom merged 5 commits into
mainfrom
feat/jepa-mvp

Conversation

@jdbloom

@jdbloom jdbloom commented May 5, 2026

Copy link
Copy Markdown
Owner

Summary

Catches GSP-RL main up to where the production code has actually been running. 5 commits, all already running successfully on both Mac + Ubuntu against stelaris main:

  • JEPA-MVP (`2f3a58d`, `37b42af`): latent-space encoder head + JEPA-aware save/load_model
  • GSP_HEAD_FROZEN flag (`7c14cc1`): H-phase5-4 frozen-head control
  • gsp_loss_step_corr per-batch diagnostic metric (`852faff`)
  • GSP_HEAD_LR independent head training config (`e44aa96`)

This PR is the GSP-RL counterpart to stelaris NESTLab#19 + RL-CT #3 — closing the last drift gap that the new PR-before-run gate caught immediately on first launch (smoke j931 on Mac + j864 on Ubuntu both refused to run with this branch unmerged).

Test plan

  • Already running successfully in production for ~3 weeks against stelaris feat/e2e-gsp-training
  • All Phase 5 + JEPA-MVP work + h7-schema-cmp + h8-action-range training cells used this code

🤖 Generated with Claude Code

Joshua Bloom and others added 5 commits April 25, 2026 15:32
…head training

Add GSP_HEAD_LR to Hyperparameters (default: config['LR'], so existing batches
are bit-for-bit identical). The GSP prediction head's Adam optimizer now uses
gsp_head_lr while the main action-network optimizer continues to use self.lr.

- learning_aids.py: self.gsp_head_lr = config.get('GSP_HEAD_LR', self.lr)
- actor.py build_DDPG_gsp(): 'lr': getattr(self, 'gsp_head_lr', self.lr)
  Only the GSP actor/prediction head uses the independent rate; the GSP critic
  and all trunk/action-network parameters remain at self.lr.

The GSP head already had its own optimizer (DDPGActorNetwork.optimizer, built
in build_DDPG_gsp) independent from the main policy optimizer. This change
exposes the LR of that existing per-head optimizer via config.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…sp_mse

learn_gsp_mse() now returns (loss_float, batch_corr_float) instead of a plain
float. batch_corr is the Pearson correlation between the fresh forward-pass
predictions that produced the MSE loss and the replay-buffer labels for that
same batch — computed with T.no_grad()/detach(), zero impact on gradient graph.

This is intentionally distinct from gsp_pred_target_corr (the existing
actor-input-path metric in hdf5_logger), which measures through the 1-timestep-
lagged stored predictions used by the actor. Comparing the two will reveal
whether the head IS learning on the loss path but the actor-input measurement
is broken, or whether the head genuinely fails to learn.

actor.py: capture batch_corr from the tuple return and accumulate into
last_gsp_loss_step_corr_samples list (per-episode, cleared by Main.py).
test_cchain.py: update assertion from isinstance(loss, float) to tuple unpack.
test_gsp_loss_step_corr.py: 6 new unit tests covering return type, finiteness,
NaN handling, accumulation via learn(), positive corr on linear task, and
gradient-graph isolation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Skip self.learn_gsp() call in Actor.learn() loop when GSP_HEAD_FROZEN is True.
Head stays at random init for the entire run. Tests whether reward-shaping wins
require head learning at all, or whether reward density alone explains the
H-phase5-2 shape_med PASS.

Default False preserves all prior behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Add JEPAEncoder (2-layer MLP + LayerNorm, raw linear output) and
  JEPAPredictor (1-layer MLP) in gsp_rl/src/networks/jepa.py
- When GSP_JEPA_ENABLED=True in Actor.__init__: instantiate online encoder,
  EMA-frozen target encoder (deep copy, requires_grad=False), predictor,
  and a joint optimizer (Adam, lr=GSP_HEAD_LR); skip the legacy build_gsp_network()
  path. network_input_size grows by GSP_ENCODER_DIM (default 32) instead of 1.
- Add Hyperparameters attributes: gsp_jepa_enabled, gsp_encoder_dim (32),
  gsp_encoder_ema_tau (0.995). Default False preserves all existing runs.
- Add learn_gsp_jepa(): samples (state_t, future_state) from JEPA replay,
  computes latent MSE + optional VICReg, backward, EMA update. Returns
  (loss_float, {var, rank, pred_mse}) latent stats dict.
- Add _update_jepa_target_encoder() EMA helper.
- Dispatch JEPA learning scheme in learn_gsp(); store latent stats on
  self.last_gsp_jepa_stats for Main.py consumption.
- 8 new tests in tests/test_jepa.py (shapes + EMA math). Full suite: 396 passed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
JEPA path skips legacy build_gsp_network, so gsp_networks['actor'] doesn't
exist. Old save_model assumed it was always present → KeyError at first
per-N-episode model save. Add JEPA branch: torch.save the encoder_online
+ encoder_target + predictor state_dicts to a sibling _jepa.pt file.
Symmetric load_model recovers them.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jdbloom jdbloom merged commit 578ba44 into main May 5, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant