
Add changes for pushforward trick #1997

Open

SavvasMel wants to merge 6 commits into ecmwf:develop from SavvasMel:SavvasMel/develop/pushf_trick

Conversation

@SavvasMel
Contributor

Description

This PR adds the necessary changes for the pushforward trick.

Issue Number

Closes #1740

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and written the run_id(s) in a comment: launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hedgedoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@github-actions github-actions bot added model Related to model training or definition (not generic infra) science Scientific questions labels Mar 9, 2026
@SavvasMel SavvasMel requested a review from clessig March 12, 2026 11:57
# reshard_after_forward=False keeps FE parameters unsharded
# during the multi-step rollout loop.
# Needed for pushforward trick.
fully_shard(module, reshard_after_forward=False, **fsdp_kwargs)
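To illustrate why `reshard_after_forward=False` helps here, below is a minimal torch-free sketch. `MockShardedModule` and its counters are hypothetical stand-ins, not real FSDP internals: with resharding after each forward, every step of a multi-step rollout pays for a fresh all-gather of the parameters, while keeping them unsharded gathers only once.

```python
# Hedged sketch: a toy stand-in for FSDP's unshard/reshard behaviour.
# MockShardedModule and its counters are illustrative, not real FSDP internals.
class MockShardedModule:
    def __init__(self, reshard_after_forward):
        self.reshard_after_forward = reshard_after_forward
        self.sharded = True
        self.all_gathers = 0  # how often full parameters are re-materialized

    def forward(self, tokens):
        if self.sharded:
            self.all_gathers += 1  # all-gather the full parameters
            self.sharded = False
        if self.reshard_after_forward:
            self.sharded = True  # free the full parameters again
        return tokens

def rollout(module, tokens, num_steps):
    # Multi-step rollout as in the pushforward trick: repeated forwards.
    for _ in range(num_steps):
        tokens = module.forward(tokens)
    return tokens

eager = MockShardedModule(reshard_after_forward=True)
lazy = MockShardedModule(reshard_after_forward=False)
rollout(eager, tokens=[0], num_steps=4)  # eager.all_gathers == 4
rollout(lazy, tokens=[0], num_steps=4)   # lazy.all_gathers == 1
```

The real `fully_shard` call above applies the same idea to the forecast-engine parameters across rollout steps.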
Collaborator

@sophie-xhonneux : is this maybe related to the problem we are seeing with the EMATeacher where we need to reshard?

tokens = self.forecast_engine(tokens, step, model_params.rope_coords)

# Add empty predictions for all streams (vectorized / batched if possible)
for stream_name in self.stream_names:
Collaborator

We can avoid this when we create output with the correct length. I thought we do this anyway, and hence the step argument.
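One way to realize this suggestion, sketched with hypothetical names and assuming the output is a per-stream list indexed by step: allocate the full-length output once, so the per-step loop appending empty predictions is unnecessary and the `step` argument indexes directly.

```python
# Hedged sketch: preallocate per-stream outputs at the correct length up front
# (make_output and the stream names are hypothetical), so empty predictions
# never need to be appended inside the rollout loop.
def make_output(stream_names, num_steps):
    return {name: [None] * num_steps for name in stream_names}

output = make_output(["stream_a", "stream_b"], num_steps=4)
output["stream_a"][2] = "prediction"  # the step argument indexes directly
```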

)

if needs_full_prediction:
tokens = self.forecast_engine(tokens, step, coords=model_params.rope_coords)
Collaborator

Remove coords= for consistency

or step == max(batch.get_output_idxs())
)

if needs_full_prediction:
Collaborator

We already call the forecast engine in l702. Don't we then call it twice in one iteration?

Also, if we set self.forecast_engine to Identity when the number of blocks is 0, we avoid the condition above.

Contributor Author

  1. Sorry the one in l702 must have been introduced due to wrong merge, I will delete it.

  2. Many thanks for the heads-up, I will correct this too.

Contributor Author

@clessig just to make sure that I understand: if the number of blocks is not > 0, the forecasting engine is set to None. Should I introduce an identity function instead in this case?

Collaborator

Yes exactly
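A sketch of the agreed change, with a hypothetical factory (the real code builds the actual engine): returning an identity callable instead of None when there are zero blocks lets call sites such as `tokens = self.forecast_engine(tokens, step, ...)` run unconditionally, with no None check.

```python
class IdentityEngine:
    """Hypothetical no-op stand-in for the forecast engine (num_blocks == 0)."""
    def __call__(self, tokens, *args, **kwargs):
        return tokens

def build_forecast_engine(num_blocks):
    # Hypothetical factory: with zero blocks, return an identity so callers
    # can invoke the engine unconditionally instead of checking for None.
    if num_blocks == 0:
        return IdentityEngine()
    raise NotImplementedError("construction of the real engine elided")

engine = build_forecast_engine(0)
tokens = ["t0", "t1"]
assert engine(tokens, step=3, coords=None) is tokens
```

In the actual PyTorch code, `torch.nn.Identity` could play this role, though it forwards only its first argument.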

output = self.predict_decoders(model_params, step, tokens, batch, output)
# latent predictions (raw and with SSL heads)
output = self.predict_latent(model_params, step, tokens, batch, output)
needs_full_prediction = (
Collaborator

If you choose slightly more compact variable names, this fits on one line and is more readable.

Collaborator

Eg. pushforward instead of pushforward_trick

needs_full_prediction = (
not pushforward_trick
or not self.training
or step == max(batch.get_output_idxs())
Collaborator

You use this to determine the number of forecast steps?

Contributor Author

I use this to determine the last step for which we will take gradients.
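The condition in the snippet above can be read as a small predicate (a sketch; the names mirror the diff): a full prediction is produced whenever the pushforward trick is off, outside training, or at the last output step, which per the author's comment is the one step that receives gradients.

```python
def needs_full_prediction(pushforward, training, step, output_idxs):
    # Full (decoded) predictions are required when the pushforward trick is
    # disabled, outside training, or at the final output step, i.e. the last
    # step through which gradients are taken.
    return (not pushforward) or (not training) or step == max(output_idxs)

# Under the trick, intermediate rollout steps skip the full prediction:
assert needs_full_prediction(True, True, step=1, output_idxs=[0, 1, 2]) is False
assert needs_full_prediction(True, True, step=2, output_idxs=[0, 1, 2]) is True
assert needs_full_prediction(False, True, step=1, output_idxs=[0, 1, 2]) is True
assert needs_full_prediction(True, False, step=1, output_idxs=[0, 1, 2]) is True
```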


Labels

model Related to model training or definition (not generic infra) science Scientific questions

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

Run pushforward trick experiments

2 participants