Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 122 additions & 16 deletions gsp_rl/src/actors/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
AttentionSequenceReplayBuffer
)
from gsp_rl.src.actors.learning_aids import NetworkAids
from gsp_rl.src.networks.jepa import JEPAEncoder, JEPAPredictor


class Actor(NetworkAids):
Expand Down Expand Up @@ -132,21 +133,67 @@ def __init__(

self.network_input_size = self.input_size
if self.gsp:
# For multi-dim GSP output (gsp_output_kind != delta_theta_1d),
# the actor's augmented obs grows by the full output vector width,
# not just 1. agent.py's make_agent_state handles the concatenation.
self.network_input_size += self.gsp_network_output
if getattr(self, 'gsp_jepa_enabled', False):
# JEPA path: actor receives the full latent vector (encoder_dim)
# instead of the legacy scalar/K-dim GSP prediction.
self.network_input_size += getattr(self, 'gsp_encoder_dim', 32)
else:
# For multi-dim GSP output (gsp_output_kind != delta_theta_1d),
# the actor's augmented obs grows by the full output vector width,
# not just 1. agent.py's make_agent_state handles the concatenation.
self.network_input_size += self.gsp_network_output
if self.attention_gsp:
self.attention_observation = [[0 for _ in range(self.gsp_network_input)] for _ in range(self.gsp_sequence_length)]
elif self.recurrent_gsp:
self.recurrent_gsp_network_input = self.gsp_network_input

self.build_networks(network)
self.gsp_networks = None

# JEPA latent-space encoder path. When enabled, skip the legacy
# DDPG-based GSP build and instantiate encoder + predictor instead.
# The target encoder is an EMA copy of the online encoder (frozen).
self.gsp_encoder_online = None
self.gsp_encoder_target = None
self.gsp_predictor = None
self._jepa_online_optimizer = None
self._jepa_predictor_optimizer = None

if gsp:
if attention:
self.build_gsp_network('attention')
self.build_gsp_network('DDPG')
if getattr(self, 'gsp_jepa_enabled', False):
import copy
_enc_dim = getattr(self, 'gsp_encoder_dim', 32)
_head_lr = getattr(self, 'gsp_head_lr', self.lr)
self.gsp_encoder_online = JEPAEncoder(
input_dim=self.gsp_network_input,
latent_dim=_enc_dim,
)
# Target encoder: EMA copy — weights frozen, no gradient
self.gsp_encoder_target = copy.deepcopy(self.gsp_encoder_online)
for param in self.gsp_encoder_target.parameters():
param.requires_grad = False
self.gsp_predictor = JEPAPredictor(latent_dim=_enc_dim)
# Optimizers: online encoder + predictor share one optimizer
self._jepa_online_optimizer = T.optim.Adam(
list(self.gsp_encoder_online.parameters())
+ list(self.gsp_predictor.parameters()),
lr=_head_lr,
)
# JEPA requires a replay buffer for (state_t, state_{t+k}) pairs.
# We reuse the standard ReplayBuffer; the future state is stored
# in the 'action' slot by convention (matches future_prox label path).
self.gsp_networks = {}
self.gsp_networks['learning_scheme'] = 'JEPA'
self.gsp_networks['replay'] = ReplayBuffer(
self.mem_size, self.gsp_network_input,
self.gsp_network_input, 'Continuous'
)
self.gsp_networks['learn_step_counter'] = 0
self.gsp_networks['output_size'] = _enc_dim
else:
if attention:
self.build_gsp_network('attention')
self.build_gsp_network('DDPG')

# Information-collapse diagnostic: last GSP learner training loss.
# NOTE: this is the loss returned by the GSP learner's inner learn step, which means:
Expand All @@ -164,6 +211,17 @@ def __init__(
self.last_gsp_loss: float | None = None
# Populated by learn() when the e2e path fires; reset to None each learn() call.
self.last_e2e_diagnostics = None
# Phase 4 loss-step correlation diagnostic. Accumulates one float per
# GSP learn step (the Pearson corr between fresh forward-pass preds and
# replay-buffer labels for that batch). Collected by Main.py at episode
# end to produce mean/std attrs in HDF5. Main.py is responsible for
# clearing this list after consuming it (not reset per-tick like
# last_gsp_loss, because it accumulates across an episode).
self.last_gsp_loss_step_corr_samples: list = []
# JEPA latent stats from the most recent learn_gsp_jepa call. Dict with
# keys {var, rank, pred_mse}. Reset to None each learn() tick (like
# last_gsp_loss). Main.py reads this to call hdf5_writer.record_jepa_*.
self.last_gsp_jepa_stats: dict | None = None

def build_networks(self, learning_scheme):
if learning_scheme == 'None':
Expand Down Expand Up @@ -337,7 +395,13 @@ def build_DDPG_gsp(self):
'id':self.id,
'input_size':self.gsp_network_input,
'output_size':self.gsp_network_output,
'lr': self.lr,
# Phase 4: use gsp_head_lr (independent of trunk/actor LR).
# Default: same as self.lr (from config['LR']), so existing batches are
# bit-for-bit identical. Override via GSP_HEAD_LR in the experiment YAML.
# The GSP critic LR intentionally stays at self.lr — only the actor/head
# that produces predictions (and is trained via supervised MSE) gets the
# independent rate.
'lr': getattr(self, 'gsp_head_lr', self.lr),
'min_max_action':self.min_max_action,
# Task 0 ablation knobs — defaults preserve legacy behavior exactly.
# Only the GSP-head actor network receives these; the main policy actor
Expand Down Expand Up @@ -480,15 +544,22 @@ def learn(self):
# Reset per-tick diagnostic signals. None means "no step ran this tick".
self.last_gsp_loss = None
self.last_e2e_diagnostics = None
self.last_gsp_jepa_stats = None

# TODO Not sure why we have n_agents*batch_size + batch_size
if self.networks['replay'].mem_ctr < self.batch_size: # (self.n_agents*self.batch_size + self.batch_size):
return

if self.gsp:
if self.networks['learn_step_counter'] % self.gsp_learning_offset == 0:
#print('[DEBUG] Learning Attention', self.networks['learn_step_counter'])
self.learn_gsp()
# H-phase5-4: when GSP_HEAD_FROZEN is true, skip the GSP head's
# optimizer step entirely. Head stays at random init for the
# entire run. Tests whether reward-shaping wins require head
# learning at all, or whether reward density alone explains the
# effect. Default false preserves all prior behavior.
if not getattr(self, 'gsp_head_frozen', False):
if self.networks['learn_step_counter'] % self.gsp_learning_offset == 0:
#print('[DEBUG] Learning Attention', self.networks['learn_step_counter'])
self.learn_gsp()

if self.networks['learning_scheme'] == 'DDQN' and getattr(self, 'gsp_e2e_enabled', False) and self.gsp:
self.replace_target_network()
Expand Down Expand Up @@ -529,18 +600,36 @@ def learn_gsp(self):
# for root cause analysis.
loss = None
scheme = self.gsp_networks['learning_scheme']
if scheme == 'attention':
if scheme == 'JEPA':
loss = self.learn_gsp_jepa(self.gsp_networks)
elif scheme == 'attention':
loss = self.learn_attention(self.gsp_networks)
elif scheme == 'RDDPG':
loss = self.learn_gsp_mse(self.gsp_networks, recurrent=True)
elif scheme in {'DDPG', 'TD3'}:
loss = self.learn_gsp_mse(self.gsp_networks, recurrent=False)
if loss is not None:
# Keep the tuple-skip guard for safety in case learn_attention's
# return type ever changes; learn_gsp_mse returns a plain float.
# learn_gsp_mse returns (loss_float, batch_corr_float).
# learn_gsp_jepa returns (loss_float, latent_stats_dict).
# learn_attention returns a plain float — keep the tuple dispatch.
if isinstance(loss, tuple):
return
self.last_gsp_loss = float(loss)
loss_val = loss[0]
self.last_gsp_loss = float(loss_val)
extra = loss[1]
if isinstance(extra, dict):
# JEPA path: store latent stats for Main.py to record.
self.last_gsp_jepa_stats = extra
elif isinstance(extra, float):
batch_corr = extra
# Accumulate per-batch loss-step correlations across all GSP learn
# steps within this episode. Main.py reads
# last_gsp_loss_step_corr_samples at episode end, computes
# mean/std, and passes them to hdf5_writer. Attribute is
# initialised in __init__ and cleared by Main.py at episode end.
if not math.isnan(batch_corr):
self.last_gsp_loss_step_corr_samples.append(batch_corr)
else:
self.last_gsp_loss = float(loss)

def store_agent_transition(self, s, a, r, s_, d, gsp_obs=None, gsp_label=None):
self.store_transition(s, a, r, s_, d, self.networks, gsp_obs=gsp_obs, gsp_label=gsp_label)
Expand Down Expand Up @@ -912,6 +1001,15 @@ def save_model(self, path):
if self.attention_gsp:
if self.gsp_networks['learning_scheme'] == 'attention':
self.gsp_networks['attention'].save_checkpoint(path)
elif self.gsp and getattr(self, 'gsp_jepa_enabled', False):
# JEPA path: save encoder_online + predictor + target_encoder via torch.save
import torch
jepa_state = {
'encoder_online': self.gsp_encoder_online.state_dict(),
'encoder_target': self.gsp_encoder_target.state_dict(),
'predictor': self.gsp_predictor.state_dict(),
}
torch.save(jepa_state, f"{path}_jepa.pt")
elif self.gsp:
self.gsp_networks['actor'].save_checkpoint(path, self.gsp)
self.gsp_networks['target_actor'].save_checkpoint(path, self.gsp)
Expand Down Expand Up @@ -947,6 +1045,14 @@ def load_model(self, path):
if self.attention_gsp:
if self.gsp_networks['learning_scheme'] == 'attention':
self.gsp_networks['attention'].load_checkpoint(path)
elif self.gsp and getattr(self, 'gsp_jepa_enabled', False):
import os, torch
jepa_path = f"{path}_jepa.pt"
if os.path.exists(jepa_path):
state = torch.load(jepa_path, map_location='cpu')
self.gsp_encoder_online.load_state_dict(state['encoder_online'])
self.gsp_encoder_target.load_state_dict(state['encoder_target'])
self.gsp_predictor.load_state_dict(state['predictor'])
elif self.gsp:
self.gsp_networks['actor'].load_checkpoint(path, self.gsp)
self.gsp_networks['target_actor'].load_checkpoint(path, self.gsp)
Expand Down
156 changes: 155 additions & 1 deletion gsp_rl/src/actors/learning_aids.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,15 @@ def __init__(self, config):
self.beta = config['BETA']
self.lr = config['LR']

# Phase 4 — independent GSP head learning rate.
# Default: same value as the trunk/actor LR (config['LR']), preserving
# exact legacy behavior for all existing batches. When set to a different
# value, the GSP prediction head's Adam optimizer uses gsp_head_lr while
# the main action-network optimizer continues to use self.lr.
# Only affects the GSP actor/head network; the GSP critic and target
# networks are unchanged (they remain tied to self.lr).
self.gsp_head_lr = float(config.get('GSP_HEAD_LR', self.lr))

self.epsilon = config['EPSILON']
self.eps_min = config['EPS_MIN']
self.eps_dec = config['EPS_DEC']
Expand Down Expand Up @@ -317,6 +326,19 @@ def __init__(self, config):
self.update_actor_iter = config['UPDATE_ACTOR_ITER']
self.warmup = config['WARMUP']
self.time_step = 0
# H-phase5-4: when True, skip the GSP head's optimizer step entirely.
# Head stays at random init for the run. Default False preserves all
# prior behavior. Read in actor.py:502 in the learn() loop.
self.gsp_head_frozen = bool(config.get('GSP_HEAD_FROZEN', False))

# JEPA (Joint Embedding Predictive Architecture) latent-space head.
# When enabled, the legacy scalar future_prox prediction is replaced
# by: online encoder → predictor → latent MSE against EMA target encoder.
# The actor receives the 32-d encoder output instead of the 1-d gsp_pred.
# Default False — all existing runs are unaffected.
self.gsp_jepa_enabled = bool(config.get('GSP_JEPA_ENABLED', False))
self.gsp_encoder_dim = int(config.get('GSP_ENCODER_DIM', 32))
self.gsp_encoder_ema_tau = float(config.get('GSP_ENCODER_EMA_TAU', 0.995))

class NetworkAids(Hyperparameters):
"""Network factory, learning algorithms, action selection, and memory management.
Expand Down Expand Up @@ -972,7 +994,139 @@ def learn_gsp_mse(self, networks, recurrent: bool = False):
networks['actor'].optimizer.step()

networks['learn_step_counter'] += 1
return loss.item()

# Phase 4 — loss-step correlation diagnostic.
# Compute Pearson correlation between the FRESH forward-pass predictions
# (the same preds that produced the MSE loss) and the replay-buffer labels.
# This is intentionally different from gsp_pred_target_corr in hdf5_logger,
# which accumulates actor-input-path predictions over a full episode (a
# different code path with a 1-timestep lag). Computing per-batch here and
# aggregating in the caller lets us compare "is the loss-path head actually
# learning?" vs "is the actor-input path measurement broken?"
#
# Safety contract:
# - Uses T.no_grad() / detach — zero gradient graph impact.
# - NaN/zero-variance guard: returns float("nan") when undefined.
# - Shape agnostic: flattens both arrays before corrcoef.
# - Recurrent path: preds/labels_shaped not available in that scope,
# so we skip and return nan for consistency.
batch_corr: float = float("nan")
if not recurrent:
with T.no_grad():
_pred_np = preds.detach().cpu().numpy().flatten()
_lbl_np = labels_shaped.detach().cpu().numpy().flatten()
if _pred_np.size > 1:
_STD_TOL = 1e-12
_p_std = float(np.nanstd(_pred_np))
_l_std = float(np.nanstd(_lbl_np))
if _p_std > _STD_TOL and _l_std > _STD_TOL:
_mask = np.isfinite(_pred_np) & np.isfinite(_lbl_np)
if _mask.sum() > 1:
batch_corr = float(np.corrcoef(_pred_np[_mask], _lbl_np[_mask])[0, 1])

return loss.item(), batch_corr

def _update_jepa_target_encoder(self, tau: float) -> None:
"""EMA update: target_p ← tau * target_p + (1 - tau) * online_p.

Args:
tau: EMA decay coefficient (e.g. 0.995). Higher = slower update.
"""
with T.no_grad():
for online_p, target_p in zip(
self.gsp_encoder_online.parameters(),
self.gsp_encoder_target.parameters(),
):
target_p.data.mul_(tau).add_(online_p.data, alpha=1.0 - tau)

def learn_gsp_jepa(self, networks: dict):
"""Train the JEPA latent-space GSP head.

Samples (state_t, state_{t+k}) pairs from the JEPA replay buffer
(state_t in the 'state' slot, state_{t+k} in the 'action' slot by
convention). Computes:

z_t = encoder_online(state_t) # online encoding
z_pred = predictor(z_t) # predicted future latent
z_target = encoder_target(state_{t+k}).detach() # EMA target

loss_pred = MSE(z_pred, z_target) # latent prediction loss

Optional VICReg variance + covariance penalties on z_t are added
when self.gsp_vicreg_enabled is True (reusing existing helpers).

After backward + optimizer step, the target encoder is updated via EMA.

Returns:
Tuple (loss_float, latent_stats_dict) where latent_stats_dict has:
{var: float, rank: float, pred_mse: float}
"""
if networks['replay'].mem_ctr < self.gsp_batch_size:
return None

vicreg_enabled = getattr(self, 'gsp_vicreg_enabled', False)
tau = float(getattr(self, 'gsp_encoder_ema_tau', 0.995))
enc_device = self.gsp_encoder_online.device

# Sample directly from the JEPA replay buffer rather than going through
# sample_memory(), which requires a 'actor' or 'q_eval' key in networks
# to determine device. JEPA networks dict has neither — device comes from
# the encoder module itself.
result = networks['replay'].sample_buffer(self.gsp_batch_size)
raw_states, raw_future, _, _, _ = result[0], result[1], result[2], result[3], result[4]
states = T.tensor(raw_states, dtype=T.float32).to(enc_device)
# future_states: stored in the 'action' slot by convention (state_{t+k})
future_states = T.tensor(raw_future, dtype=T.float32).to(enc_device)

# Forward through online encoder + predictor
z_t = self.gsp_encoder_online(states)
z_pred = self.gsp_predictor(z_t)

# Target: forward through frozen target encoder
with T.no_grad():
z_target = self.gsp_encoder_target(future_states)

loss_pred = F.mse_loss(z_pred, z_target)
loss = loss_pred

# Optional VICReg on online encoder output z_t
if vicreg_enabled:
var_coef = float(getattr(self, 'gsp_vicreg_var_coef', 1.0))
cov_coef = float(getattr(self, 'gsp_vicreg_cov_coef', 0.04))
# target_std: 1.0 (standard VICReg default) — latent lives in
# unbounded linear space so label-std normalization is not needed.
var_loss = vicreg_variance_loss(z_t, target_std=1.0)
cov_loss = vicreg_covariance_loss(z_t)
loss = loss_pred + var_coef * var_loss + cov_coef * cov_loss

self._jepa_online_optimizer.zero_grad()
loss.backward()
_check_nan(loss, f"JEPA loss at step {networks['learn_step_counter']}")
self._jepa_online_optimizer.step()

# EMA update of target encoder
self._update_jepa_target_encoder(tau)

networks['learn_step_counter'] += 1

# Compute latent statistics (no grad)
with T.no_grad():
latent_var = float(z_t.var(dim=0).mean().item())
# Approximate rank: number of singular values above 1% of max
z_cpu = z_t.detach().cpu()
try:
sv = T.linalg.svdvals(z_cpu)
rank = float((sv > sv[0] * 0.01).sum().item())
except Exception:
rank = float("nan")
pred_mse = float(loss_pred.item())

latent_stats = {
'var': latent_var,
'rank': rank,
'pred_mse': pred_mse,
}
return loss.item(), latent_stats

def decrement_epsilon(self):
self.epsilon = max(self.epsilon-self.eps_dec, self.eps_min)
Expand Down
Loading