diff --git a/gsp_rl/src/actors/actor.py b/gsp_rl/src/actors/actor.py index f68da8c..88f7c20 100644 --- a/gsp_rl/src/actors/actor.py +++ b/gsp_rl/src/actors/actor.py @@ -23,6 +23,7 @@ AttentionSequenceReplayBuffer ) from gsp_rl.src.actors.learning_aids import NetworkAids +from gsp_rl.src.networks.jepa import JEPAEncoder, JEPAPredictor class Actor(NetworkAids): @@ -132,10 +133,15 @@ def __init__( self.network_input_size = self.input_size if self.gsp: - # For multi-dim GSP output (gsp_output_kind != delta_theta_1d), - # the actor's augmented obs grows by the full output vector width, - # not just 1. agent.py's make_agent_state handles the concatenation. - self.network_input_size += self.gsp_network_output + if getattr(self, 'gsp_jepa_enabled', False): + # JEPA path: actor receives the full latent vector (encoder_dim) + # instead of the legacy scalar/K-dim GSP prediction. + self.network_input_size += getattr(self, 'gsp_encoder_dim', 32) + else: + # For multi-dim GSP output (gsp_output_kind != delta_theta_1d), + # the actor's augmented obs grows by the full output vector width, + # not just 1. agent.py's make_agent_state handles the concatenation. + self.network_input_size += self.gsp_network_output if self.attention_gsp: self.attention_observation = [[0 for _ in range(self.gsp_network_input)] for _ in range(self.gsp_sequence_length)] elif self.recurrent_gsp: @@ -143,10 +149,51 @@ def __init__( self.build_networks(network) self.gsp_networks = None + + # JEPA latent-space encoder path. When enabled, skip the legacy + # DDPG-based GSP build and instantiate encoder + predictor instead. + # The target encoder is an EMA copy of the online encoder (frozen). 
+ self.gsp_encoder_online = None + self.gsp_encoder_target = None + self.gsp_predictor = None + self._jepa_online_optimizer = None + self._jepa_predictor_optimizer = None + if gsp: - if attention: - self.build_gsp_network('attention') - self.build_gsp_network('DDPG') + if getattr(self, 'gsp_jepa_enabled', False): + import copy + _enc_dim = getattr(self, 'gsp_encoder_dim', 32) + _head_lr = getattr(self, 'gsp_head_lr', self.lr) + self.gsp_encoder_online = JEPAEncoder( + input_dim=self.gsp_network_input, + latent_dim=_enc_dim, + ) + # Target encoder: EMA copy — weights frozen, no gradient + self.gsp_encoder_target = copy.deepcopy(self.gsp_encoder_online) + for param in self.gsp_encoder_target.parameters(): + param.requires_grad = False + self.gsp_predictor = JEPAPredictor(latent_dim=_enc_dim) + # Optimizers: online encoder + predictor share one optimizer + self._jepa_online_optimizer = T.optim.Adam( + list(self.gsp_encoder_online.parameters()) + + list(self.gsp_predictor.parameters()), + lr=_head_lr, + ) + # JEPA requires a replay buffer for (state_t, state_{t+k}) pairs. + # We reuse the standard ReplayBuffer; the future state is stored + # in the 'action' slot by convention (matches future_prox label path). + self.gsp_networks = {} + self.gsp_networks['learning_scheme'] = 'JEPA' + self.gsp_networks['replay'] = ReplayBuffer( + self.mem_size, self.gsp_network_input, + self.gsp_network_input, 'Continuous' + ) + self.gsp_networks['learn_step_counter'] = 0 + self.gsp_networks['output_size'] = _enc_dim + else: + if attention: + self.build_gsp_network('attention') + self.build_gsp_network('DDPG') # Information-collapse diagnostic: last GSP learner training loss. # NOTE: this is the loss returned by the GSP learner's inner learn step, which means: @@ -164,6 +211,17 @@ def __init__( self.last_gsp_loss: float | None = None # Populated by learn() when the e2e path fires; reset to None each learn() call. 
self.last_e2e_diagnostics = None + # Phase 4 loss-step correlation diagnostic. Accumulates one float per + # GSP learn step (the Pearson corr between fresh forward-pass preds and + # replay-buffer labels for that batch). Collected by Main.py at episode + # end to produce mean/std attrs in HDF5. Main.py is responsible for + # clearing this list after consuming it (not reset per-tick like + # last_gsp_loss, because it accumulates across an episode). + self.last_gsp_loss_step_corr_samples: list = [] + # JEPA latent stats from the most recent learn_gsp_jepa call. Dict with + # keys {var, rank, pred_mse}. Reset to None each learn() tick (like + # last_gsp_loss). Main.py reads this to call hdf5_writer.record_jepa_*. + self.last_gsp_jepa_stats: dict | None = None def build_networks(self, learning_scheme): if learning_scheme == 'None': @@ -337,7 +395,13 @@ def build_DDPG_gsp(self): 'id':self.id, 'input_size':self.gsp_network_input, 'output_size':self.gsp_network_output, - 'lr': self.lr, + # Phase 4: use gsp_head_lr (independent of trunk/actor LR). + # Default: same as self.lr (from config['LR']), so existing batches are + # bit-for-bit identical. Override via GSP_HEAD_LR in the experiment YAML. + # The GSP critic LR intentionally stays at self.lr — only the actor/head + # that produces predictions (and is trained via supervised MSE) gets the + # independent rate. + 'lr': getattr(self, 'gsp_head_lr', self.lr), 'min_max_action':self.min_max_action, # Task 0 ablation knobs — defaults preserve legacy behavior exactly. # Only the GSP-head actor network receives these; the main policy actor @@ -480,15 +544,22 @@ def learn(self): # Reset per-tick diagnostic signals. None means "no step ran this tick". 
self.last_gsp_loss = None self.last_e2e_diagnostics = None + self.last_gsp_jepa_stats = None # TODO Not sure why we have n_agents*batch_size + batch_size if self.networks['replay'].mem_ctr < self.batch_size: # (self.n_agents*self.batch_size + self.batch_size): return if self.gsp: - if self.networks['learn_step_counter'] % self.gsp_learning_offset == 0: - #print('[DEBUG] Learning Attention', self.networks['learn_step_counter']) - self.learn_gsp() + # H-phase5-4: when GSP_HEAD_FROZEN is true, skip the GSP head's + # optimizer step entirely. Head stays at random init for the + # entire run. Tests whether reward-shaping wins require head + # learning at all, or whether reward density alone explains the + # effect. Default false preserves all prior behavior. + if not getattr(self, 'gsp_head_frozen', False): + if self.networks['learn_step_counter'] % self.gsp_learning_offset == 0: + #print('[DEBUG] Learning Attention', self.networks['learn_step_counter']) + self.learn_gsp() if self.networks['learning_scheme'] == 'DDQN' and getattr(self, 'gsp_e2e_enabled', False) and self.gsp: self.replace_target_network() @@ -529,18 +600,36 @@ def learn_gsp(self): # for root cause analysis. loss = None scheme = self.gsp_networks['learning_scheme'] - if scheme == 'attention': + if scheme == 'JEPA': + loss = self.learn_gsp_jepa(self.gsp_networks) + elif scheme == 'attention': loss = self.learn_attention(self.gsp_networks) elif scheme == 'RDDPG': loss = self.learn_gsp_mse(self.gsp_networks, recurrent=True) elif scheme in {'DDPG', 'TD3'}: loss = self.learn_gsp_mse(self.gsp_networks, recurrent=False) if loss is not None: - # Keep the tuple-skip guard for safety in case learn_attention's - # return type ever changes; learn_gsp_mse returns a plain float. + # learn_gsp_mse returns (loss_float, batch_corr_float). + # learn_gsp_jepa returns (loss_float, latent_stats_dict). + # learn_attention returns a plain float — keep the tuple dispatch. 
if isinstance(loss, tuple): - return - self.last_gsp_loss = float(loss) + loss_val = loss[0] + self.last_gsp_loss = float(loss_val) + extra = loss[1] + if isinstance(extra, dict): + # JEPA path: store latent stats for Main.py to record. + self.last_gsp_jepa_stats = extra + elif isinstance(extra, float): + batch_corr = extra + # Accumulate per-batch loss-step correlations across all GSP learn + # steps within this episode. Main.py reads + # last_gsp_loss_step_corr_samples at episode end, computes + # mean/std, and passes them to hdf5_writer. Attribute is + # initialised in __init__ and cleared by Main.py at episode end. + if not math.isnan(batch_corr): + self.last_gsp_loss_step_corr_samples.append(batch_corr) + else: + self.last_gsp_loss = float(loss) def store_agent_transition(self, s, a, r, s_, d, gsp_obs=None, gsp_label=None): self.store_transition(s, a, r, s_, d, self.networks, gsp_obs=gsp_obs, gsp_label=gsp_label) @@ -912,6 +1001,15 @@ def save_model(self, path): if self.attention_gsp: if self.gsp_networks['learning_scheme'] == 'attention': self.gsp_networks['attention'].save_checkpoint(path) + elif self.gsp and getattr(self, 'gsp_jepa_enabled', False): + # JEPA path: save encoder_online + predictor + target_encoder via torch.save + import torch + jepa_state = { + 'encoder_online': self.gsp_encoder_online.state_dict(), + 'encoder_target': self.gsp_encoder_target.state_dict(), + 'predictor': self.gsp_predictor.state_dict(), + } + torch.save(jepa_state, f"{path}_jepa.pt") elif self.gsp: self.gsp_networks['actor'].save_checkpoint(path, self.gsp) self.gsp_networks['target_actor'].save_checkpoint(path, self.gsp) @@ -947,6 +1045,14 @@ def load_model(self, path): if self.attention_gsp: if self.gsp_networks['learning_scheme'] == 'attention': self.gsp_networks['attention'].load_checkpoint(path) + elif self.gsp and getattr(self, 'gsp_jepa_enabled', False): + import os, torch + jepa_path = f"{path}_jepa.pt" + if os.path.exists(jepa_path): + state = torch.load(jepa_path, 
+ map_location='cpu') + self.gsp_encoder_online.load_state_dict(state['encoder_online']) + self.gsp_encoder_target.load_state_dict(state['encoder_target']) + self.gsp_predictor.load_state_dict(state['predictor']) elif self.gsp: self.gsp_networks['actor'].load_checkpoint(path, self.gsp) self.gsp_networks['target_actor'].load_checkpoint(path, self.gsp) diff --git a/gsp_rl/src/actors/learning_aids.py b/gsp_rl/src/actors/learning_aids.py index 26ca8f5..aa39717 100644 --- a/gsp_rl/src/actors/learning_aids.py +++ b/gsp_rl/src/actors/learning_aids.py @@ -177,6 +177,15 @@ def __init__(self, config): self.beta = config['BETA'] self.lr = config['LR'] + # Phase 4 — independent GSP head learning rate. + # Default: same value as the trunk/actor LR (config['LR']), preserving + # exact legacy behavior for all existing batches. When set to a different + # value, the GSP prediction head's Adam optimizer uses gsp_head_lr while + # the main action-network optimizer continues to use self.lr. + # Only affects the GSP actor/head network; the GSP critic and target + # networks are unchanged (they remain tied to self.lr). + self.gsp_head_lr = float(config.get('GSP_HEAD_LR', self.lr)) + self.epsilon = config['EPSILON'] self.eps_min = config['EPS_MIN'] self.eps_dec = config['EPS_DEC'] @@ -317,6 +326,19 @@ def __init__(self, config): self.update_actor_iter = config['UPDATE_ACTOR_ITER'] self.warmup = config['WARMUP'] self.time_step = 0 + # H-phase5-4: when True, skip the GSP head's optimizer step entirely. + # Head stays at random init for the run. Default False preserves all + # prior behavior. Read by Actor.learn() in actor.py. + self.gsp_head_frozen = bool(config.get('GSP_HEAD_FROZEN', False)) + + # JEPA (Joint Embedding Predictive Architecture) latent-space head. + # When enabled, the legacy scalar future_prox prediction is replaced + # by: online encoder → predictor → latent MSE against EMA target encoder. + # The actor receives the GSP_ENCODER_DIM-d (default 32) encoder output instead of the 1-d gsp_pred.
+ # Default False — all existing runs are unaffected. + self.gsp_jepa_enabled = bool(config.get('GSP_JEPA_ENABLED', False)) + self.gsp_encoder_dim = int(config.get('GSP_ENCODER_DIM', 32)) + self.gsp_encoder_ema_tau = float(config.get('GSP_ENCODER_EMA_TAU', 0.995)) class NetworkAids(Hyperparameters): """Network factory, learning algorithms, action selection, and memory management. @@ -972,7 +994,139 @@ def learn_gsp_mse(self, networks, recurrent: bool = False): networks['actor'].optimizer.step() networks['learn_step_counter'] += 1 - return loss.item() + + # Phase 4 — loss-step correlation diagnostic. + # Compute Pearson correlation between the FRESH forward-pass predictions + # (the same preds that produced the MSE loss) and the replay-buffer labels. + # This is intentionally different from gsp_pred_target_corr in hdf5_logger, + # which accumulates actor-input-path predictions over a full episode (a + # different code path with a 1-timestep lag). Computing per-batch here and + # aggregating in the caller lets us compare "is the loss-path head actually + # learning?" vs "is the actor-input path measurement broken?" + # + # Safety contract: + # - Uses T.no_grad() / detach — zero gradient graph impact. + # - NaN/zero-variance guard: returns float("nan") when undefined. + # - Shape agnostic: flattens both arrays before corrcoef. + # - Recurrent path: preds/labels_shaped not available in that scope, + # so we skip and return nan for consistency. 
+ batch_corr: float = float("nan") + if not recurrent: + with T.no_grad(): + _pred_np = preds.detach().cpu().numpy().flatten() + _lbl_np = labels_shaped.detach().cpu().numpy().flatten() + if _pred_np.size > 1: + _STD_TOL = 1e-12 + _p_std = float(np.nanstd(_pred_np)) + _l_std = float(np.nanstd(_lbl_np)) + if _p_std > _STD_TOL and _l_std > _STD_TOL: + _mask = np.isfinite(_pred_np) & np.isfinite(_lbl_np) + if _mask.sum() > 1: + batch_corr = float(np.corrcoef(_pred_np[_mask], _lbl_np[_mask])[0, 1]) + + return loss.item(), batch_corr + + def _update_jepa_target_encoder(self, tau: float) -> None: + """EMA update: target_p ← tau * target_p + (1 - tau) * online_p. + + Args: + tau: EMA decay coefficient (e.g. 0.995). Higher = slower update. + """ + with T.no_grad(): + for online_p, target_p in zip( + self.gsp_encoder_online.parameters(), + self.gsp_encoder_target.parameters(), + ): + target_p.data.mul_(tau).add_(online_p.data, alpha=1.0 - tau) + + def learn_gsp_jepa(self, networks: dict): + """Train the JEPA latent-space GSP head. + + Samples (state_t, state_{t+k}) pairs from the JEPA replay buffer + (state_t in the 'state' slot, state_{t+k} in the 'action' slot by + convention). Computes: + + z_t = encoder_online(state_t) # online encoding + z_pred = predictor(z_t) # predicted future latent + z_target = encoder_target(state_{t+k}).detach() # EMA target + + loss_pred = MSE(z_pred, z_target) # latent prediction loss + + Optional VICReg variance + covariance penalties on z_t are added + when self.gsp_vicreg_enabled is True (reusing existing helpers). + + After backward + optimizer step, the target encoder is updated via EMA. 
+ + Returns: + None if the replay buffer holds fewer than gsp_batch_size samples; otherwise a tuple (loss_float, latent_stats_dict) where latent_stats_dict has: + {var: float, rank: float, pred_mse: float} + """ + if networks['replay'].mem_ctr < self.gsp_batch_size: + return None + + vicreg_enabled = getattr(self, 'gsp_vicreg_enabled', False) + tau = float(getattr(self, 'gsp_encoder_ema_tau', 0.995)) + enc_device = self.gsp_encoder_online.device + + # Sample directly from the JEPA replay buffer rather than going through + # sample_memory(), which requires an 'actor' or 'q_eval' key in networks + # to determine device. JEPA networks dict has neither — device comes from + # the encoder module itself. + result = networks['replay'].sample_buffer(self.gsp_batch_size) + raw_states, raw_future, _, _, _ = result[0], result[1], result[2], result[3], result[4] + states = T.tensor(raw_states, dtype=T.float32).to(enc_device) + # future_states: stored in the 'action' slot by convention (state_{t+k}) + future_states = T.tensor(raw_future, dtype=T.float32).to(enc_device) + + # Forward through online encoder + predictor + z_t = self.gsp_encoder_online(states) + z_pred = self.gsp_predictor(z_t) + + # Target: forward through frozen target encoder + with T.no_grad(): + z_target = self.gsp_encoder_target(future_states) + + loss_pred = F.mse_loss(z_pred, z_target) + loss = loss_pred + + # Optional VICReg on online encoder output z_t + if vicreg_enabled: + var_coef = float(getattr(self, 'gsp_vicreg_var_coef', 1.0)) + cov_coef = float(getattr(self, 'gsp_vicreg_cov_coef', 0.04)) + # target_std: 1.0 (standard VICReg default) — latent lives in + # unbounded linear space so label-std normalization is not needed.
+ var_loss = vicreg_variance_loss(z_t, target_std=1.0) + cov_loss = vicreg_covariance_loss(z_t) + loss = loss_pred + var_coef * var_loss + cov_coef * cov_loss + + self._jepa_online_optimizer.zero_grad() + loss.backward() + _check_nan(loss, f"JEPA loss at step {networks['learn_step_counter']}") + self._jepa_online_optimizer.step() + + # EMA update of target encoder + self._update_jepa_target_encoder(tau) + + networks['learn_step_counter'] += 1 + + # Compute latent statistics (no grad) + with T.no_grad(): + latent_var = float(z_t.var(dim=0).mean().item()) + # Approximate rank: number of singular values above 1% of max + z_cpu = z_t.detach().cpu() + try: + sv = T.linalg.svdvals(z_cpu) + rank = float((sv > sv[0] * 0.01).sum().item()) + except Exception: + rank = float("nan") + pred_mse = float(loss_pred.item()) + + latent_stats = { + 'var': latent_var, + 'rank': rank, + 'pred_mse': pred_mse, + } + return loss.item(), latent_stats def decrement_epsilon(self): self.epsilon = max(self.epsilon-self.eps_dec, self.eps_min) diff --git a/gsp_rl/src/networks/__init__.py b/gsp_rl/src/networks/__init__.py index f55edd6..219975c 100644 --- a/gsp_rl/src/networks/__init__.py +++ b/gsp_rl/src/networks/__init__.py @@ -24,4 +24,5 @@ def get_device(recurrent: bool = False) -> T.device: from .rddpg import RDDPGActorNetwork, RDDPGCriticNetwork from .td3 import TD3ActorNetwork, TD3CriticNetwork from .lstm import EnvironmentEncoder -from .self_attention import AttentionEncoder \ No newline at end of file +from .self_attention import AttentionEncoder +from .jepa import JEPAEncoder, JEPAPredictor \ No newline at end of file diff --git a/gsp_rl/src/networks/jepa.py b/gsp_rl/src/networks/jepa.py new file mode 100644 index 0000000..3b74058 --- /dev/null +++ b/gsp_rl/src/networks/jepa.py @@ -0,0 +1,100 @@ +"""JEPA (Joint Embedding Predictive Architecture) modules for the GSP head. 
+ +Provides two nn.Module classes used by the JEPA GSP path (built in actor.py, trained in learning_aids.py): + +- JEPAEncoder: maps raw GSP head input (state_t or state_{t+k}) to a latent + vector of shape (batch, latent_dim). Two-layer MLP, raw linear output (no + tanh), with LayerNorm on the hidden representation. + +- JEPAPredictor: maps online encoder latent z_t → predicted future latent + z_{t+k}. Two-layer MLP (fc1 → ReLU → fc2). Output is in the same latent space as the target + encoder (used for MSE prediction loss). + +Design notes: +- No tanh on the encoder output — the latent lives in an unbounded linear + space so that VICReg variance/covariance losses operate without saturation. +- LayerNorm on the hidden (encoder fc1 output) stabilizes training but is + intentionally omitted from the final linear projection (matching the + VICReg expander design in Bardes et al. ICLR 2022). +- The target encoder is created in actor.py as a copy of the online encoder and EMA-updated in + learning_aids.py; this file only defines the shared architecture. + +See: docs/superpowers/specs/2026-04-16-jepa-mvp.md for the design rationale. +""" +import torch as T +import torch.nn as nn + + +class JEPAEncoder(nn.Module): + """Two-layer MLP encoder: state → latent. + + Architecture: + fc1 (input_dim → hidden) → LayerNorm → ReLU → fc2 (hidden → latent_dim) + + The final projection is a raw linear layer — no activation, no LayerNorm. + This preserves gradient flow into the latent space and lets VICReg + variance/covariance regularization operate on unbounded representations. + + Args: + input_dim: Dimensionality of the GSP head input (gsp_network_input). + latent_dim: Output latent dimensionality. Default 32 (GSP_ENCODER_DIM). + hidden: Hidden layer width. Default 128.
+ """ + + def __init__(self, input_dim: int, latent_dim: int = 32, hidden: int = 128): + super().__init__() + self.fc1 = nn.Linear(input_dim, hidden) + self.ln1 = nn.LayerNorm(hidden) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden, latent_dim) + + # Determine device from environment (mirrors existing network convention) + self.device = T.device("cuda:0" if T.cuda.is_available() else "cpu") + self.to(self.device) + + def forward(self, x: T.Tensor) -> T.Tensor: + """Forward pass. + + Args: + x: Input tensor of shape (batch, input_dim). + + Returns: + Latent tensor of shape (batch, latent_dim). + """ + h = self.relu(self.ln1(self.fc1(x))) + return self.fc2(h) + + +class JEPAPredictor(nn.Module): + """Two-layer MLP predictor (fc1 → ReLU → fc2): latent_t → predicted latent_{t+k}. + + Maps the online encoder's latent representation z_t to a prediction of + the target encoder's latent representation z_{t+k}. The MSE loss between + predictor output and the (detached) target latent drives supervised + learning in the latent space. + + Args: + latent_dim: Both input and output dimensionality (latent space). + hidden: Hidden layer width. Default 64.
+ """ + h = self.relu(self.fc1(z)) + return self.fc2(h) diff --git a/tests/test_actor/test_cchain.py b/tests/test_actor/test_cchain.py index 955ba29..fcf6824 100644 --- a/tests/test_actor/test_cchain.py +++ b/tests/test_actor/test_cchain.py @@ -120,11 +120,18 @@ def counting_step(*args, **kwargs): actor.gsp_networks['actor'].optimizer.step = counting_step - loss_val = actor.learn_gsp_mse(actor.gsp_networks) + result = actor.learn_gsp_mse(actor.gsp_networks) - assert loss_val is not None, "learn_gsp_mse must return a loss" - assert isinstance(loss_val, float), f"Expected float, got {type(loss_val)}" + # learn_gsp_mse now returns (loss_float, batch_corr_float) — a tuple. + assert result is not None, "learn_gsp_mse must return a result" + assert isinstance(result, tuple) and len(result) == 2, ( + f"Expected (loss_float, batch_corr_float) tuple, got {type(result)}" + ) + loss_val, batch_corr = result + assert isinstance(loss_val, float), f"First element must be float loss, got {type(loss_val)}" assert np.isfinite(loss_val), f"Loss must be finite, got {loss_val}" + # batch_corr is either a finite float or nan (undefined correlation). + assert isinstance(batch_corr, float), f"Second element must be float corr, got {type(batch_corr)}" # With lambda=0.0, exactly one optimizer step: the MSE step. assert len(step_calls) == 1, ( f"With lambda=0.0 exactly 1 optimizer step should occur, got {len(step_calls)}" diff --git a/tests/test_actor/test_gsp_loss_step_corr.py b/tests/test_actor/test_gsp_loss_step_corr.py new file mode 100644 index 0000000..b9cebbc --- /dev/null +++ b/tests/test_actor/test_gsp_loss_step_corr.py @@ -0,0 +1,264 @@ +"""Unit tests for the Phase 4 gsp_loss_step_corr diagnostic metric. + +learn_gsp_mse() now returns (loss_float, batch_corr_float) instead of a plain +float. batch_corr is the Pearson correlation between the fresh forward-pass +predictions used to compute the MSE loss and the replay-buffer labels for that +same batch. + +Verifies: +1. 
Return type is (float, float) tuple — not a plain float. +2. loss element is finite. +3. batch_corr element is either finite (normal case) or nan (undefined corr). +4. learn_gsp() (called via Actor.learn()) populates + last_gsp_loss_step_corr_samples with at least one finite float when the + replay buffer has enough samples. +5. learn_gsp_mse on a perfectly learnable linear task returns a positive + correlation (head corr improves over training). +6. Gradient graph is unaffected: loss.backward() + optimizer.step() paths + are identical to pre-change behaviour (no new requires_grad tensors leak + from the correlation computation). +""" + +import math + +import numpy as np +import pytest +import torch + +from gsp_rl.src.actors.actor import Actor + + +# --------------------------------------------------------------------------- +# Shared helpers (mirrors test_gsp_direct_mse.py conventions) +# --------------------------------------------------------------------------- + +BASE_CONFIG = { + "GAMMA": 0.99, + "TAU": 0.005, + "ALPHA": 0.001, + "BETA": 0.002, + "LR": 0.001, + "EPSILON": 0.0, + "EPS_MIN": 0.0, + "EPS_DEC": 0.0, + "BATCH_SIZE": 16, + "MEM_SIZE": 1000, + "REPLACE_TARGET_COUNTER": 10, + "NOISE": 0.0, + "UPDATE_ACTOR_ITER": 1, + "WARMUP": 0, + "GSP_LEARNING_FREQUENCY": 1, + "GSP_BATCH_SIZE": 16, +} + +INPUT_SIZE = 8 +OUTPUT_SIZE = 4 +GSP_INPUT_SIZE = 6 +GSP_OUTPUT_SIZE = 1 + + +def make_gsp_actor(extra_config=None): + cfg = dict(BASE_CONFIG) + if extra_config: + cfg.update(extra_config) + return Actor( + id=1, + config=cfg, + network="DDPG", + input_size=INPUT_SIZE, + output_size=OUTPUT_SIZE, + min_max_action=1, + meta_param_size=1, + gsp=True, + gsp_input_size=GSP_INPUT_SIZE, + gsp_output_size=GSP_OUTPUT_SIZE, + ) + + +def _fill_gsp_buffer(actor, n=100, seed=0): + """Fill the GSP replay buffer with (state, label) pairs.""" + rng = np.random.default_rng(seed) + states_list = [] + labels_list = [] + for _ in range(n): + s = 
rng.standard_normal(GSP_INPUT_SIZE).astype(np.float32) + label = np.float32(rng.standard_normal()) + actor.store_gsp_transition(s, label, 0.0, np.zeros_like(s), False) + states_list.append(s) + labels_list.append(float(label)) + return np.array(states_list), np.array(labels_list) + + +def _fill_primary_buffer(actor, n=100, seed=99): + """Fill the primary replay buffer so learn() doesn't short-circuit.""" + rng = np.random.default_rng(seed) + for _ in range(n): + s = rng.random(actor.network_input_size).astype(np.float32) + s_ = rng.random(actor.network_input_size).astype(np.float32) + a = actor.choose_action(s, actor.networks, test=True) + actor.store_transition(s, a, 0.0, s_, False, actor.networks) + + +def _fill_gsp_buffer_linear(actor, n=400, seed=0): + """Fill GSP buffer with a deterministic linear label: label = s[0] * 0.5.""" + rng = np.random.default_rng(seed) + states, labels = [], [] + for _ in range(n): + s = rng.standard_normal(GSP_INPUT_SIZE).astype(np.float32) + label = np.float32(s[0] * 0.5) + actor.store_gsp_transition(s, label, 0.0, np.zeros_like(s), False) + states.append(s) + labels.append(float(label)) + return np.array(states), np.array(labels) + + +# --------------------------------------------------------------------------- +# Test 1: return type is (float, float) tuple +# --------------------------------------------------------------------------- + +def test_learn_gsp_mse_returns_tuple(): + """learn_gsp_mse must return a 2-tuple (loss_float, batch_corr_float).""" + torch.manual_seed(0) + np.random.seed(0) + actor = make_gsp_actor() + _fill_gsp_buffer(actor, seed=0) + _fill_primary_buffer(actor) + + result = actor.learn_gsp_mse(actor.gsp_networks) + + assert isinstance(result, tuple), f"Expected tuple, got {type(result)}" + assert len(result) == 2, f"Expected 2-element tuple, got length {len(result)}" + loss_val, batch_corr = result + assert isinstance(loss_val, float), ( + f"First element (loss) must be float, got {type(loss_val)}" + ) + 
assert isinstance(batch_corr, float), ( + f"Second element (corr) must be float, got {type(batch_corr)}" + ) + + +# --------------------------------------------------------------------------- +# Test 2: loss element is finite +# --------------------------------------------------------------------------- + +def test_learn_gsp_mse_loss_is_finite(): + """Loss element of the returned tuple must be finite.""" + torch.manual_seed(1) + np.random.seed(1) + actor = make_gsp_actor() + _fill_gsp_buffer(actor, seed=1) + _fill_primary_buffer(actor) + + loss_val, _ = actor.learn_gsp_mse(actor.gsp_networks) + assert math.isfinite(loss_val), f"Loss must be finite, got {loss_val}" + + +# --------------------------------------------------------------------------- +# Test 3: batch_corr is finite or nan — never inf +# --------------------------------------------------------------------------- + +def test_learn_gsp_mse_batch_corr_not_inf(): + """batch_corr must be a finite float or nan — never +/-inf.""" + torch.manual_seed(2) + np.random.seed(2) + actor = make_gsp_actor() + _fill_gsp_buffer(actor, seed=2) + _fill_primary_buffer(actor) + + _, batch_corr = actor.learn_gsp_mse(actor.gsp_networks) + assert not math.isinf(batch_corr), ( + f"batch_corr must not be inf, got {batch_corr}" + ) + + +# --------------------------------------------------------------------------- +# Test 4: learn_gsp() via Actor.learn() accumulates samples +# --------------------------------------------------------------------------- + +def test_learn_populates_loss_step_corr_samples(): + """After Actor.learn(), last_gsp_loss_step_corr_samples has >= 1 entry.""" + torch.manual_seed(3) + np.random.seed(3) + actor = make_gsp_actor() + _fill_gsp_buffer(actor, seed=3) + _fill_primary_buffer(actor) + + actor.learn() + + samples = getattr(actor, "last_gsp_loss_step_corr_samples", None) + assert samples is not None, "last_gsp_loss_step_corr_samples must exist on Actor" + assert len(samples) >= 1, ( + f"Expected at least 1 
sample after learn(), got {len(samples)}" + ) + for s in samples: + assert isinstance(s, float), f"Each sample must be float, got {type(s)}" + assert math.isfinite(s), f"Each sample must be finite, got {s}" + + +# --------------------------------------------------------------------------- +# Test 5: correlation is positive on a learnable linear task after training +# --------------------------------------------------------------------------- + +def test_batch_corr_positive_after_training_on_linear_task(): + """After 100 learn steps on a linear label task, mean batch_corr > 0. + + This confirms the loss-step path is measuring real learning signal, not + returning garbage or zero. + """ + torch.manual_seed(42) + np.random.seed(42) + actor = make_gsp_actor() + _fill_gsp_buffer_linear(actor, n=400, seed=42) + _fill_primary_buffer(actor) + + corr_values = [] + for _ in range(100): + _, batch_corr = actor.learn_gsp_mse(actor.gsp_networks) + if math.isfinite(batch_corr): + corr_values.append(batch_corr) + + assert len(corr_values) > 0, "No finite batch_corr values collected" + mean_corr = float(np.mean(corr_values)) + assert mean_corr > 0.0, ( + f"Expected positive mean batch_corr on linear task after training, " + f"got {mean_corr:.4f}" + ) + + +# --------------------------------------------------------------------------- +# Test 6: correlation computation does not affect gradient graph +# --------------------------------------------------------------------------- + +def test_batch_corr_does_not_affect_gradients(): + """Gradient graph must be unaffected by the correlation computation. + + Verifies that the correlation block uses T.no_grad() + detach() and + that no new requires_grad tensors are introduced that would accumulate + across calls (which would break the existing loss.backward() path). 
+ """ + torch.manual_seed(5) + np.random.seed(5) + actor = make_gsp_actor() + _fill_gsp_buffer(actor, seed=5) + _fill_primary_buffer(actor) + + # Clear any stale gradients before the learn step + net = actor.gsp_networks["actor"] + for p in net.parameters(): + p.grad = None + + # Run one learn step — if correlation leaks a graph, backward will raise + loss_val, batch_corr = actor.learn_gsp_mse(actor.gsp_networks) + + # After learn_gsp_mse the optimizer step has already run. optimizer.step() does + # not clear .grad (only zero_grad() does), so gradients may still be populated. + # The loop below skips parameters whose grad is None and requires every + # populated gradient to be free of NaN/Inf. + for name, p in net.named_parameters(): + if p.grad is not None: + assert not torch.isnan(p.grad).any(), ( + f"NaN gradient on {name} after learn_gsp_mse" + ) + assert not torch.isinf(p.grad).any(), ( + f"Inf gradient on {name} after learn_gsp_mse" + ) diff --git a/tests/test_jepa.py b/tests/test_jepa.py new file mode 100644 index 0000000..ea885db --- /dev/null +++ b/tests/test_jepa.py @@ -0,0 +1,150 @@ +"""Tests for the JEPA (Joint Embedding Predictive Architecture) modules. + +Covers: +1. TestEncoderPredictorShape — JEPAEncoder and JEPAPredictor produce the + expected output shapes for a batch of inputs. +2. TestTargetEmaUpdate — After one EMA step with tau=0.5 and known + initial weights, the target encoder parameters equal + 0.5 * old_target + 0.5 * online (i.e. arithmetic mean).
+""" + +import copy + +import torch +import pytest + +from gsp_rl.src.networks.jepa import JEPAEncoder, JEPAPredictor + + +INPUT_DIM = 6 +LATENT_DIM = 32 +BATCH_SIZE = 16 + + +class TestEncoderPredictorShape: + """Shape and finiteness contracts for JEPAEncoder and JEPAPredictor outputs.""" + + def test_encoder_output_shape(self): + enc = JEPAEncoder(input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden=128) + x = torch.randn(BATCH_SIZE, INPUT_DIM).to(enc.device) + z = enc(x) + assert z.shape == (BATCH_SIZE, LATENT_DIM), ( + f"Expected encoder output shape ({BATCH_SIZE}, {LATENT_DIM}), got {z.shape}" + ) + + def test_predictor_output_shape(self): + pred = JEPAPredictor(latent_dim=LATENT_DIM, hidden=64) + z = torch.randn(BATCH_SIZE, LATENT_DIM).to(pred.device) + z_pred = pred(z) + assert z_pred.shape == (BATCH_SIZE, LATENT_DIM), ( + f"Expected predictor output shape ({BATCH_SIZE}, {LATENT_DIM}), got {z_pred.shape}" + ) + + def test_encoder_single_sample(self): + """Single-sample (batch=1) forward should run and return shape (1, LATENT_DIM).""" + enc = JEPAEncoder(input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden=128) + x = torch.randn(1, INPUT_DIM).to(enc.device) + z = enc(x) + assert z.shape == (1, LATENT_DIM) + + def test_encoder_no_nan(self): + """Encoder output must be finite on random input.""" + enc = JEPAEncoder(input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden=128) + x = torch.randn(BATCH_SIZE, INPUT_DIM).to(enc.device) + z = enc(x) + assert torch.isfinite(z).all(), "Encoder output contains NaN or Inf" + + def test_predictor_no_nan(self): + """Predictor output must be finite on random latent.""" + pred = JEPAPredictor(latent_dim=LATENT_DIM, hidden=64) + z = torch.randn(BATCH_SIZE, LATENT_DIM).to(pred.device) + z_pred = pred(z) + assert torch.isfinite(z_pred).all(), "Predictor output contains NaN or Inf" + + +class TestTargetEmaUpdate: + """EMA update arithmetic for the target encoder (update rule is written inline in each test).""" + + def test_ema_tau_half(self): + """With tau=0.5, after one EMA step: + target_p = 0.5 * old_target + 0.5 * online_p +
+ which equals the arithmetic mean of old and online weights. + """ + tau = 0.5 + enc_online = JEPAEncoder(input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden=128) + enc_target = copy.deepcopy(enc_online) + + # Perturb online weights so online != target + with torch.no_grad(): + for p in enc_online.parameters(): + p.add_(torch.ones_like(p) * 2.0) # shift online by +2 + + # Capture old target weights and online weights + old_target_params = {name: p.data.clone() for name, p in enc_target.named_parameters()} + online_params = {name: p.data.clone() for name, p in enc_online.named_parameters()} + + # Perform EMA update in place: target = tau * target + (1 - tau) * online + with torch.no_grad(): + for online_p, target_p in zip( + enc_online.parameters(), enc_target.parameters() + ): + target_p.data.mul_(tau).add_(online_p.data, alpha=1.0 - tau) + + # Verify: each target param = 0.5 * old_target + 0.5 * online + for name, target_p in enc_target.named_parameters(): + expected = 0.5 * old_target_params[name] + 0.5 * online_params[name] + assert torch.allclose(target_p.data, expected, atol=1e-6), ( + f"EMA mismatch at param '{name}': " + f"expected mean of old+online, got divergent values" + ) + + def test_ema_tau_one_freezes_target(self): + """With tau=1.0, the target should not change (fully frozen update).""" + tau = 1.0 + enc_online = JEPAEncoder(input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden=128) + enc_target = copy.deepcopy(enc_online) + + # Snapshot target before update + old_target_params = {name: p.data.clone() for name, p in enc_target.named_parameters()} + + # Perturb online so any online contribution to the target would be detectable + with torch.no_grad(): + for p in enc_online.parameters(): + p.add_(torch.ones_like(p) * 5.0) + + # EMA with tau=1.0: target is scaled by 1 and online is added with weight 0 + with torch.no_grad(): + for online_p, target_p in zip( + enc_online.parameters(), enc_target.parameters() + ): + target_p.data.mul_(tau).add_(online_p.data, alpha=1.0 - tau) + + for name, target_p in enc_target.named_parameters(): + assert torch.allclose(target_p.data, old_target_params[name], atol=1e-6), ( + f"With tau=1.0, 
+ target param '{name}' should be unchanged" + ) + + def test_ema_tau_zero_copies_online(self): + """With tau=0.0, the target should become a copy of online.""" + tau = 0.0 + enc_online = JEPAEncoder(input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden=128) + enc_target = copy.deepcopy(enc_online) + + # Perturb online so it differs from the deep-copied target + with torch.no_grad(): + for p in enc_online.parameters(): + p.add_(torch.ones_like(p) * 3.0) + + online_params = {name: p.data.clone() for name, p in enc_online.named_parameters()} + + # EMA with tau=0.0: target is scaled by 0 and online is added with weight 1 + with torch.no_grad(): + for online_p, target_p in zip( + enc_online.parameters(), enc_target.parameters() + ): + target_p.data.mul_(tau).add_(online_p.data, alpha=1.0 - tau) + + for name, target_p in enc_target.named_parameters(): + assert torch.allclose(target_p.data, online_params[name], atol=1e-6), ( + f"With tau=0.0, target param '{name}' should equal online" + )