diff --git a/autoregressive_codice_prova.py b/autoregressive_codice_prova.py
new file mode 100644
index 000000000..69b194d35
--- /dev/null
+++ b/autoregressive_codice_prova.py
@@ -0,0 +1,185 @@
+import torch
+import matplotlib.pyplot as plt
+
+from pina import Trainer
+from pina.optim import TorchOptimizer
+from pina.problem import AbstractProblem
+from pina.condition.data_condition import DataCondition
+from pina.solver import AutoregressiveSolver
+
+NUM_TIMESTEPS = 100
+NUM_FEATURES = 15
+USE_TEST_MODEL = False
+
+# ============================================================================
+# DATA
+# ============================================================================
+
+torch.manual_seed(42)
+
+y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES)
+y[0] = torch.rand(NUM_FEATURES)  # Random initial state
+
+for t in range(NUM_TIMESTEPS - 1):
+    y[t + 1] = 0.95 * y[t]  # + 0.05 * torch.sin(y[t].sum())
+
+# ============================================================================
+# TRAINING
+# ============================================================================
+
+class SimpleModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layers = torch.nn.Sequential(
+            torch.nn.Linear(y.shape[1], 15),
+            torch.nn.Tanh(),
+            # torch.nn.Dropout(0.1),
+            torch.nn.Linear(15, y.shape[1]),
+        )
+
+    def forward(self, x):
+        return x + self.layers(x)
+
+
+class TestModel(torch.nn.Module):
+    """
+    Debug model that implements the EXACT transformation rule
+    y[t+1] = 0.95 * y[t], so the expected training loss is zero.
+    """
+
+    def __init__(self, data_series=None):
+        super().__init__()
+        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        next_state = 0.95 * x  # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True))
+        return next_state + 0.0 * self.dummy_param
+
+# create a problem with multiple data conditions
+class Problem(AbstractProblem):
+    output_variables = None
+    input_variables = None
+
+    # create three unroll datasets: short, medium, and long
+    y_short = AutoregressiveSolver.unroll(
+        y, unroll_length=4, num_unrolls=20, randomize=False
+    )
+    y_medium = AutoregressiveSolver.unroll(
+        y, unroll_length=10, num_unrolls=15, randomize=False
+    )
+    y_long = AutoregressiveSolver.unroll(
+        y, unroll_length=20, num_unrolls=10, randomize=False
+    )
+
+    conditions = {}
+
+    inactive_conditions = {
+        "short": DataCondition(input=y_short),
+        "medium": DataCondition(input=y_medium),
+        "long": DataCondition(input=y_long),
+    }
+
+    # Settings kept separate from the DataCondition objects
+    conditions_settings = {
+        "short": {"eps": 0.1},
+        "medium": {"eps": 1.0},
+        "long": {"eps": 2.0},
+    }
+
+
+problem = Problem()
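+
+# Quick sanity check on the unroll datasets built above (the shapes follow
+# directly from the unroll calls; this assumes the class attributes remain
+# plain tensors on the problem instance): each dataset is
+# [num_unrolls, unroll_length + 1, NUM_FEATURES].
+assert problem.y_short.shape == (20, 5, NUM_FEATURES)
+assert problem.y_medium.shape == (15, 11, NUM_FEATURES)
+assert problem.y_long.shape == (10, 21, NUM_FEATURES)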
+ """ + # if data is provided, (re)register condition in inactive store + if data is not None: + problem.inactive_conditions[name] = DataCondition(input=data) + + problem.conditions = {} + problem.conditions[name] = problem.inactive_conditions[name] + + if settings is not None: + problem.conditions_settings[name] = settings + +# configure solver and trainer +solver = AutoregressiveSolver( + problem=problem, + model=TestModel() if USE_TEST_MODEL else SimpleModel(), + optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.011), +) + + +print("Beginning phase 1: training with 'short' condition only") +activate_condition(problem, "short") +trainer1 = Trainer(solver, max_epochs=300, accelerator="cpu", enable_model_summary=False) +trainer1.train() + +print("Beginning phase 2: training with 'medium' condition added") +activate_condition(problem, "medium") +trainer2 = Trainer(solver, max_epochs=500, accelerator="cpu", enable_model_summary=False) +trainer2.train() + +print("Beginning phase 3: training with 'long' condition added") +activate_condition(problem, "long") +trainer3 = Trainer(solver, max_epochs=900, accelerator="cpu", enable_model_summary=False) +trainer3.train() + + +# ============================================================================ +test_start_idx = 50 +num_prediction_steps = 49 +initial_state = y[test_start_idx] # Shape: [features] +predictions = solver.predict(initial_state, num_prediction_steps) +actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1] + +print("\n=== PREDICTION DEBUG ===") +for i in range(min(10, num_prediction_steps)): + pred_val = predictions[i].mean().item() + actual_val = actual[i].mean().item() + error = (predictions[i] - actual[i]).abs().mean().item() + print(f"Step {i}: pred={pred_val:.4f}, actual={actual_val:.4f}, error={error:.4f}") + +total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:]) +print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}") + +# visualize single dof +dof_to_plot = [0, 3, 6, 9, 12] +colors = [ + "r", + "g", + "b", + "c", + "m", + "y", + "k", +] +plt.figure(figsize=(10, 6)) +for dof, color in zip(dof_to_plot, colors): + plt.plot( + range(test_start_idx, test_start_idx + num_prediction_steps + 1), + actual[:, dof].numpy(), + label="Actual", + marker="o", + color=color, + markerfacecolor="none", + ) + plt.plot( + range(test_start_idx, test_start_idx + num_prediction_steps + 1), + predictions[:, dof].numpy(), + label="Predicted", + marker="x", + color=color, + ) + +plt.title(f"Autoregressive Predictions vs Actual, MRSE: {total_mse:.6f}") +plt.legend() +plt.xlabel("Timestep") +plt.savefig(f"autoregressive_predictions.png") +plt.close() diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py index 43f18078f..e7d48e2b3 100644 --- a/pina/solver/__init__.py +++ b/pina/solver/__init__.py @@ -18,6 +18,7 @@ "DeepEnsembleSupervisedSolver", "DeepEnsemblePINN", "GAROM", + "AutoregressiveSolver", ] from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface @@ -41,3 +42,7 @@ DeepEnsemblePINN, ) from .garom import GAROM +from .autoregressive_solver import ( + AutoregressiveSolver, + AutoregressiveSolverInterface, +) diff --git a/pina/solver/autoregressive_solver/__init__.py b/pina/solver/autoregressive_solver/__init__.py new file mode 100644 index 000000000..9ef7c43e1 --- /dev/null +++ b/pina/solver/autoregressive_solver/__init__.py @@ -0,0 +1,4 @@ +__all__ = ["AutoregressiveSolver", "AutoregressiveSolverInterface"] + +from .autoregressive_solver import 
+from .autoregressive_solver_interface import AutoregressiveSolverInterface
diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py
new file mode 100644
index 000000000..e377804c7
--- /dev/null
+++ b/pina/solver/autoregressive_solver/autoregressive_solver.py
@@ -0,0 +1,305 @@
+"""Module for the Autoregressive solver."""
+
+import logging
+
+import torch
+
+from pina.condition import DataCondition
+from pina.solver.solver import SingleSolverInterface
+from .autoregressive_solver_interface import AutoregressiveSolverInterface
+
+logger = logging.getLogger(__name__)
+
+
+class AutoregressiveSolver(
+    AutoregressiveSolverInterface, SingleSolverInterface
+):
+    r"""
+    Autoregressive solver for learning dynamical systems.
+
+    This solver learns a one-step transition function
+    :math:`\mathcal{M}: \mathbb{R}^n \rightarrow \mathbb{R}^n` that maps
+    a state :math:`\mathbf{y}_t` to the next state :math:`\mathbf{y}_{t+1}`.
+
+    During training, the model is unrolled over multiple time steps to
+    learn long-term dynamics. Given an initial state :math:`\mathbf{y}_0`,
+    the model generates predictions recursively:
+
+    .. math::
+        \hat{\mathbf{y}}_{t+1} = \mathcal{M}(\hat{\mathbf{y}}_t),
+        \quad \hat{\mathbf{y}}_0 = \mathbf{y}_0
+
+    The loss is computed over the entire unroll window:
+
+    .. math::
+        \mathcal{L} = \sum_{t=1}^{T} w_t \|\hat{\mathbf{y}}_t - \mathbf{y}_t\|^2
+
+    where, if ``eps`` is specified, the weights
+    :math:`w_t = \operatorname{softmax}_t\left(-\varepsilon \sum_{s \leq t}
+    \mathcal{L}_s\right)` down-weight steps whose accumulated loss is large,
+    which stabilizes training; if ``eps`` is ``None``, uniform weights
+    :math:`w_t = 1/T` are used.
+    """
+
+    accepted_conditions_types = DataCondition
+
+    def __init__(
+        self,
+        problem,
+        model,
+        loss=None,
+        optimizer=None,
+        scheduler=None,
+        weighting=None,
+        use_lt=False,
+        N_epochs_with_same_weights=10,
+    ):
+        """
+        Initialization of the :class:`AutoregressiveSolver` class.
+
+        :param AbstractProblem problem: The problem instance containing
+            the time series data conditions.
+        :param torch.nn.Module model: Neural network that predicts the
+            next state given the current state.
+        :param torch.nn.Module loss: Loss function to minimize.
+            If ``None``, :class:`torch.nn.MSELoss` is used.
+            Default is ``None``.
+        :param TorchOptimizer optimizer: Optimizer for training.
+            If ``None``, :class:`torch.optim.Adam` is used.
+            Default is ``None``.
+        :param TorchScheduler scheduler: Learning rate scheduler.
+            If ``None``, no scheduling is applied. Default is ``None``.
+        :param WeightingInterface weighting: Weighting scheme for
+            combining losses from multiple conditions.
+            If ``None``, uniform weighting is used. Default is ``None``.
+        :param bool use_lt: Whether to use LabelTensors.
+            Default is ``False``.
+        :param int N_epochs_with_same_weights: Number of epochs to keep the
+            same adaptive weights before recomputing them. Default is ``10``.
+        """
+
+        super().__init__(
+            problem=problem,
+            model=model,
+            loss=loss,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            weighting=weighting,
+            use_lt=use_lt,
+        )
+        # cache for per-condition adaptive weights and epoch-based update
+        # control; an epoch counter is the most generic way to implement
+        # periodic weight updates here
+        self._cached_weights = {}
+        self._epochs_since_update = 0
+        self.N_epochs_with_same_weights = N_epochs_with_same_weights
+
+    @staticmethod
+    def unroll(
+        data, unroll_length: int, num_unrolls=None, randomize: bool = True
+    ):
+        """
+        Create unroll windows from time series data.
+
+        This is a pre-processing step. It slices the input time series into
+        overlapping windows of length ``Twin = unroll_length + 1`` along the
+        time axis (axis 0). Each window contains the initial state and the
+        subsequent target states used to compute a multi-step loss.
+
+        :param torch.Tensor data: Time series tensor with shape
+            ``[T, *state_shape]``. The first axis is interpreted as time.
+        :param int unroll_length: Number of transitions in each window.
+            Each window has length ``unroll_length + 1``.
+        :param int num_unrolls: Maximum number of windows to return. If
+            ``None``, all valid windows are returned. Default is ``None``.
+        :param bool randomize: If ``True``, starting indices are randomly
+            permuted before applying ``num_unrolls``. Default is ``True``.
+        :return: Tensor of unroll windows with shape
+            ``[Nw, unroll_length + 1, *state_shape]``. If no valid windows
+            exist, returns an empty tensor with shape
+            ``[0, unroll_length + 1, *state_shape]``.
+        :rtype: torch.Tensor
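+
+        Example (a minimal sketch of the windowing)::
+
+            >>> import torch
+            >>> data = torch.arange(12.0).reshape(6, 2)  # T=6, 2 features
+            >>> w = AutoregressiveSolver.unroll(data, unroll_length=2, randomize=False)
+            >>> w.shape  # 4 windows, each of 3 consecutive states
+            torch.Size([4, 3, 2])
+            >>> torch.equal(w[0], data[0:3])
+            True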
+        """
+        starts = AutoregressiveSolver.decide_starting_indices(
+            data, unroll_length, num_unrolls, randomize
+        )
+        if starts.numel() == 0:
+            return torch.empty(
+                (0, unroll_length + 1, *data.shape[1:]), device=data.device
+            )
+
+        windows = [
+            data[int(start) : int(start) + unroll_length + 1]
+            for start in starts
+        ]
+
+        # [num_unrolls, unroll_length + 1, *data.shape[1:]]
+        return torch.stack(windows, dim=0)
+
+    @staticmethod
+    def decide_starting_indices(
+        data, unroll_length, num_unrolls=None, randomize=True
+    ):
+        """
+        Determine starting indices for unroll windows.
+
+        Computes valid starting positions, ensuring each window has enough
+        subsequent time steps for the specified unroll length.
+
+        :param torch.Tensor data: Time series tensor with shape
+            ``[T, *state_shape]``.
+        :param int unroll_length: Number of transitions in each window.
+        :param int num_unrolls: Maximum number of indices to return. If
+            ``None``, all valid indices are returned. Default is ``None``.
+        :param bool randomize: If ``True``, indices are randomly permuted
+            before applying ``num_unrolls``. Default is ``True``.
+        :return: 1D tensor of starting indices with dtype ``torch.long``.
+        :rtype: torch.Tensor
+        """
+        n_step = int(data.shape[0])
+        twin = int(unroll_length + 1)
+        last_start = n_step - twin
+        if last_start < 0:
+            return torch.empty(0, dtype=torch.long, device=data.device)
+
+        indices = torch.arange(last_start + 1, device=data.device)
+
+        if randomize:
+            indices = indices[torch.randperm(len(indices), device=data.device)]
+
+        if num_unrolls is not None and num_unrolls < len(indices):
+            indices = indices[:num_unrolls]
+
+        return indices
+
+    def loss_data(
+        self, unroll, eps=None, aggregation_strategy=None, condition_name=None
+    ):
+        """
+        Compute the autoregressive multi-step data loss.
+
+        The input ``unroll`` is expected to be a batch of precomputed unroll
+        windows with shape ``[B, Twin, *state_shape]``. The first element
+        along the ``Twin`` axis is used as the current state, and the
+        following elements are the targets.
+
+        :param torch.Tensor unroll: Batch of unroll windows with shape
+            ``[B, Twin, *state_shape]`` where ``Twin = unroll_length + 1``.
+        :param float eps: If provided, applies per-step weighting through
+            :meth:`get_weights`. If ``None``, uniform normalized weights are
+            used. Default is ``None``.
+        :param callable aggregation_strategy: Reduction applied to the
+            weighted per-step losses. If ``None``, :func:`torch.sum` is used.
+            Default is ``None``.
+        :param str condition_name: Name of the condition, used as the key of
+            the per-condition weight cache. Default is ``None``.
+        :return: Scalar loss value for the given batch.
+        :rtype: torch.Tensor
+        """
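+        # Shape walkthrough (illustrative): with unroll = [B, Twin, F], the
+        # loop below applies the model Twin - 1 times; each step_loss is a
+        # scalar, so step_losses has shape [Twin - 1] and the weights sum to 1.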
+        # batch dimension is unroll.shape[0], i.e. the number of unroll windows
+        Twin = unroll.shape[1]
+
+        current_state = unroll[:, 0, ...]  # first time step of each window
+        losses = []
+        for step in range(1, Twin):
+
+            predicted_state = self.forward(
+                current_state
+            )  # [num_unrolls, features]
+            target_state = unroll[:, step, ...]  # [num_unrolls, features]
+            step_loss = self._loss_fn(predicted_state, target_state)
+            losses.append(step_loss)
+
+            if logger.isEnabledFor(logging.DEBUG) and (
+                step <= 3 or torch.isnan(step_loss)
+            ):
+                logger.debug(
+                    "  Step %d: loss=%.4e, pred=[%.3f, %.3f]",
+                    step,
+                    float(step_loss.item()),
+                    float(predicted_state.min()),
+                    float(predicted_state.max()),
+                )
+
+            current_state = predicted_state
+
+        step_losses = torch.stack(losses)  # [unroll_length]
+
+        with torch.no_grad():
+            condition_name = condition_name or "default"
+            weights = self.get_weights(condition_name, step_losses, eps)
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "  Losses: %s", step_losses.detach().cpu().numpy().round(4)
+                )
+                logger.debug("  Weights: %s", weights.cpu().numpy().round(4))
+                logger.debug(
+                    "  Weight ratio: %.1f", float(weights.max() / weights.min())
+                )
+
+        if aggregation_strategy is None:
+            aggregation_strategy = torch.sum
+
+        return aggregation_strategy(step_losses * weights)
+
+    @staticmethod
+    def _compute_adaptive_weights(step_losses, eps):
+        """
+        Actual computation of the adaptive weights.
+
+        :param torch.Tensor step_losses: 1D tensor of per-step losses.
+        :param float eps: Weighting parameter. If ``None``, uniform
+            normalized weights are returned.
+        :return: Weights tensor, normalized to sum to 1.
+        :rtype: torch.Tensor
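+
+        Example (worked numbers for ``eps=1.0``)::
+
+            >>> import torch
+            >>> losses = torch.tensor([1.0, 2.0, 3.0])  # cumsum = [1, 3, 6]
+            >>> w = AutoregressiveSolver._compute_adaptive_weights(losses, 1.0)
+            >>> torch.round(w, decimals=4)  # softmax([-1, -3, -6])
+            tensor([0.8756, 0.1185, 0.0059])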
+        """
+        logger.debug("updating weights, eps=%s", eps)
+
+        if eps is None:
+            return torch.ones_like(step_losses) / step_losses.numel()
+
+        log_w = torch.clamp(-eps * torch.cumsum(step_losses, dim=0), -20, 20)
+        return torch.softmax(log_w, dim=0)
+
+    def get_weights(self, condition_name, step_losses, eps):
+        """
+        Return cached weights, or compute and cache new ones.
+
+        :param str condition_name: Name of the condition.
+        :param torch.Tensor step_losses: 1D tensor of per-step losses.
+        :param float eps: Weighting parameter.
+        :return: Weights tensor.
+        :rtype: torch.Tensor
+        """
+        cached = self._cached_weights.get(condition_name, None)
+        if cached is None:
+            cached = self._compute_adaptive_weights(step_losses, eps).cpu()
+            self._cached_weights[condition_name] = cached
+        return cached.to(step_losses.device)
+
+    def on_train_epoch_end(self):
+        """
+        Hook called by Lightning at the end of each training epoch.
+        Forces a periodic recalculation of the weights by clearing the cache.
+        """
+        self._epochs_since_update += 1
+        if self._epochs_since_update >= self.N_epochs_with_same_weights:
+            self._cached_weights.clear()
+            self._epochs_since_update = 0
+
+    def predict(self, initial_state, num_steps):
+        """
+        Generate predictions by recursively applying the model.
+
+        Starting from ``initial_state``, applies the model repeatedly to
+        generate a trajectory of length ``num_steps + 1`` (including the
+        initial state).
+
+        :param torch.Tensor initial_state: Starting state. Supported shapes:
+            - ``[n_features]`` (unbatched, 1D)
+            - ``[B, n_features]`` (batched)
+            More general tensors ``[*state_shape]`` / ``[B, *state_shape]``
+            are also supported, provided the model can process them.
+        :param int num_steps: Number of future time steps to predict.
+        :return: Predicted trajectory including the initial state. Shape:
+            - ``[num_steps + 1, *state_shape]`` for unbatched input
+            - ``[num_steps + 1, B, *state_shape]`` for batched input
+        :rtype: torch.Tensor
+        """
+        self.eval()  # set the model to evaluation mode (not restored on exit)
+
+        current_state = initial_state
+
+        added_batch = False
+        if current_state.dim() == 1:
+            current_state = current_state.unsqueeze(0)
+            added_batch = True
+
+        predictions = [current_state]
+        with torch.no_grad():
+            for step in range(num_steps):
+                next_state = self.forward(current_state)
+                predictions.append(next_state)
+                current_state = next_state
+
+        out = torch.stack(predictions, dim=0)
+        if added_batch:
+            out = out[:, 0, ...]  # remove the batch dimension
+
+        return out
diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py
new file mode 100644
index 000000000..788f6c081
--- /dev/null
+++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py
@@ -0,0 +1,105 @@
+"""Module for the Autoregressive solver interface."""
+
+from abc import abstractmethod
+
+import torch
+from torch.nn.modules.loss import _Loss
+
+from ..solver import SolverInterface
+from ...utils import check_consistency
+from ...loss.loss_interface import LossInterface
+
+
+class AutoregressiveSolverInterface(SolverInterface):
+    """
+    Base class for autoregressive solvers.
+
+    The training pipeline expects
+    :class:`~pina.condition.data_condition.DataCondition` conditions. In the
+    recommended configuration, each ``DataCondition`` input is a collection
+    of unroll windows with shape ``[Nw, Twin, *state_shape]``, where
+    ``Twin = unroll_length + 1``. The Trainer batches along the first axis,
+    producing ``[B, Twin, *state_shape]`` tensors passed to :meth:`loss_data`.
+    """
+
+    def __init__(self, loss=None, **kwargs):
+        """
+        Initialization of the :class:`AutoregressiveSolverInterface` class.
+
+        :param torch.nn.Module loss: Loss function to minimize. If ``None``,
+            :class:`torch.nn.MSELoss` is used. Default is ``None``.
+        :param kwargs: Additional keyword arguments forwarded to
+            :class:`~pina.solver.solver.SolverInterface`.
+        """
+
+        super().__init__(**kwargs)
+
+        if loss is None:
+            loss = torch.nn.MSELoss()
+
+        check_consistency(loss, (LossInterface, _Loss), subclass=False)
+        self._loss_fn = loss
+
+    def optimization_cycle(self, batch):
+        """
+        Optimization cycle for this family of solvers. Iterates over the
+        conditions in the batch and applies the solver-specific
+        :meth:`loss_data` to each, forwarding any per-condition settings
+        found on the problem.
+
+        :param list[tuple[str, dict]] batch: List of tuples where each
+            tuple contains a condition name and a dictionary with the
+            ``"input"`` key mapping to a batch of unroll windows.
+        :return: Dictionary mapping condition names to computed loss values.
+        :rtype: dict[str, torch.Tensor]
+        """
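+        # Batch layout sketch (names illustrative):
+        #   [("short", {"input": w_short}), ("long", {"input": w_long})]
+        # where each windows tensor has shape [B, Twin, *state_shape].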
+ """ + pass + + @abstractmethod + def predict(self, initial_state, num_steps): + """ + Generate predictions by recursively applying the model. + + :param torch.Tensor initial_state: Starting state. Supported shapes are: + - ``[*state_shape]`` (unbatched) + - ``[B, *state_shape]`` (batched) + :param int num_steps: Number of future time steps to predict. + :return: Predicted trajectory including the initial state. Shape: + - ``[num_steps + 1, *state_shape]`` if unbatched input + - ``[num_steps + 1, B, *state_shape]`` if batched input + :rtype: torch.Tensor + """ + pass + + @property + def loss(self): + """ + The loss function to be minimized. + + :return: The loss function to be minimized. + :rtype: torch.nn.Module + """ + return self._loss_fn diff --git a/tests/test_solver/test_autoregressive_solver.py b/tests/test_solver/test_autoregressive_solver.py new file mode 100644 index 000000000..0a6e77767 --- /dev/null +++ b/tests/test_solver/test_autoregressive_solver.py @@ -0,0 +1,155 @@ +import pytest +import torch + +from pina import Trainer +from pina.optim import TorchOptimizer +from pina.problem import AbstractProblem +from pina.condition.data_condition import DataCondition +from pina.solver import AutoregressiveSolver + +NUM_TIMESTEPS = 10 +NUM_FEATURES = 3 + + +def _make_series(T=NUM_TIMESTEPS, F=NUM_FEATURES): + torch.manual_seed(42) + y = torch.zeros(T, F) + y[0] = torch.rand(F) + for t in range(T - 1): + y[t + 1] = 0.95 * y[t] + return y + + +@pytest.fixture +def y_data(): + return _make_series() + + +# crate a test Model +class ExactModel(torch.nn.Module): + """ + This model implements the EXACT transformation rule. + y[t+1] = 0.95 * y[t] + Expected loss is zero + """ + + def __init__(self, data_series=None): + super().__init__() + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, x): + next_state = 0.95 * x + return next_state + 0.0 * self.dummy_param + + +# Tests start here ============================================== + + +def test_unroll_shape_and_content(y_data): + # unroll_length=4 -> Twin=5 + w = AutoregressiveSolver.unroll( + y_data, unroll_length=4, num_unrolls=2, randomize=False + ) + assert w.shape == (2, 5, NUM_FEATURES) + # deterministic starts: 0 and 1 + assert torch.allclose(w[0], y_data[0:5]) + assert torch.allclose(w[1], y_data[1:6]) + + +def test_decide_starting_indices_edge_cases(y_data): + idx = AutoregressiveSolver.decide_starting_indices( + y_data, unroll_length=3, num_unrolls=None, randomize=False + ) + # T=10, Twin=4 => last_start=6 => 0..6 + assert torch.equal(idx, torch.arange(7)) + + idx_empty = AutoregressiveSolver.decide_starting_indices( + y_data, + unroll_length=NUM_TIMESTEPS + 5, + num_unrolls=None, + randomize=False, + ) + assert idx_empty.numel() == 0 + + +def test_exact_model(y_data): + + windows = AutoregressiveSolver.unroll( + y_data, unroll_length=5, num_unrolls=4, randomize=False + ) + + class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = { + "data_condition": DataCondition(input=windows), + } + + conditions_settings = { + "data_condition": {"eps": None, "aggregation_strategy": torch.sum}, + } + solver = AutoregressiveSolver( + problem=Problem(), + conditions_settings=conditions_settings, + model=ExactModel(), + optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.01), + ) + + loss = solver.loss_data(windows, **conditions_settings["data_condition"]) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) + + +def test_predict_matches_ground_truth(y_data): + class 
+
+
+def test_adaptive_weights_are_finite_and_normalized():
+    step_losses = torch.tensor([1.0, 2.0, 3.0])
+    w = AutoregressiveSolver._compute_adaptive_weights(step_losses, eps=1.0)
+    assert torch.isfinite(w).all()
+    assert torch.isclose(w.sum(), torch.tensor(1.0), atol=1e-6)
+
+    w2 = AutoregressiveSolver._compute_adaptive_weights(step_losses, eps=None)
+    assert torch.isclose(w2.sum(), torch.tensor(1.0), atol=1e-6)
+
+
+def test_trainer_integration_one_epoch(y_data):
+    windows = AutoregressiveSolver.unroll(
+        y_data, unroll_length=5, num_unrolls=None, randomize=False
+    )
+
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {"data": DataCondition(input=windows)}
+        conditions_settings = {
+            "data": {"eps": None, "aggregation_strategy": torch.sum}
+        }
+
+    solver = AutoregressiveSolver(
+        problem=Problem(),
+        model=ExactModel(),
+        optimizer=TorchOptimizer(torch.optim.AdamW, lr=1e-2),
+    )
+
+    trainer = Trainer(
+        solver=solver,
+        max_epochs=1,
+        accelerator="cpu",
+    )
+    trainer.train()
+
+    # Just check we didn't produce NaNs somewhere
+    with torch.no_grad():
+        loss = solver.loss_data(
+            windows[:4], eps=None, aggregation_strategy=torch.sum
+        )
+        assert torch.isfinite(loss)
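+
+
+def test_weight_cache_refresh(y_data):
+    # A sketch of the caching contract, assuming the behavior implemented in
+    # `get_weights` / `on_train_epoch_end`: weights stay fixed for
+    # N_epochs_with_same_weights epochs, then the cache is cleared.
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {"data": DataCondition(input=y_data)}
+
+    solver = AutoregressiveSolver(
+        problem=Problem(), model=ExactModel(), N_epochs_with_same_weights=2
+    )
+
+    losses = torch.tensor([1.0, 2.0])
+    first = solver.get_weights("data", losses, eps=1.0)
+    # while cached, a different eps is ignored
+    assert torch.equal(first, solver.get_weights("data", losses, eps=5.0))
+
+    solver.on_train_epoch_end()
+    solver.on_train_epoch_end()  # second epoch clears the cache
+    refreshed = solver.get_weights("data", losses, eps=5.0)
+    assert not torch.equal(first, refreshed)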