Sophiex/dev/feat qk rmsnorm#2033

Merged
sophie-xhonneux merged 23 commits into develop from
sophiex/dev/feat-qk-rmsnorm
Apr 10, 2026

@sophie-xhonneux
Contributor

Description

Improves performance; see https://gitlab.jsc.fz-juelich.de/hedgedoc/OgGOfs2-RVOZcEEjB0108w?view
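
For readers unfamiliar with the technique named in the PR title, here is a minimal sketch of RMS normalization with a learnable gain, as it could be applied to a single query or key vector. This is not the code from this PR: the function name `rms_norm`, the `eps` value, and the use of plain Python lists instead of tensors are all assumptions made for illustration.

```python
import math


def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """Scale x by the inverse of its root mean square, then apply a per-element gain.

    Hypothetical helper for illustration; eps is an assumed value.
    """
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


# A toy query vector with the gain initialised to ones (assumed initialisation).
q = [3.0, 4.0]
q_normed = rms_norm(q, [1.0, 1.0])

# With a unit gain, the root mean square of the output is close to 1.
out_rms = math.sqrt(sum(v * v for v in q_normed) / len(q_normed))
```

Unlike LayerNorm, this variant does not subtract the mean and has no bias term, so it needs fewer operations per vector.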

Issue Number

Closes #2032

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hedgedoc in the github issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@github-actions bot added the labels eval (anything related to the model evaluation pipeline), infra (Issues related to infrastructure), and model (Related to model training or definition, not generic infra) on Mar 10, 2026
return config


def _check_qk_norm_type(config: Config) -> Config:
Collaborator

This backfilling leads to problems elsewhere. We can just use cf.get(...) instead.

Contributor Author

I think this was necessary for loading the teacher model, but I forget the details.

Collaborator

Remove. This should not be added here. _check_logging above led to all kinds of problems lately. Using cf.get("config.qk_norm_type", "LayerNorm") is much more robust and useful.
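
A minimal sketch of the two approaches discussed in this thread, assuming a plain dict stands in for the config object. The project's actual Config type and key layout are not shown here, and `backfill_qk_norm_type`, `read_qk_norm_type`, and the `num_heads` entry are hypothetical names used only for illustration.

```python
from typing import Any

DEFAULT_QK_NORM = "LayerNorm"  # default value named in the review comment


def backfill_qk_norm_type(cf: dict[str, Any]) -> dict[str, Any]:
    """Backfilling: write the default into the config itself (mutates cf)."""
    if "qk_norm_type" not in cf:
        cf["qk_norm_type"] = DEFAULT_QK_NORM
    return cf


def read_qk_norm_type(cf: dict[str, Any]) -> str:
    """Default on read: the stored config is left untouched."""
    return cf.get("qk_norm_type", DEFAULT_QK_NORM)


# Hypothetical config saved before the option existed.
old_cf: dict[str, Any] = {"num_heads": 8}
norm_type = read_qk_norm_type(old_cf)
```

With the default applied at the point of use, an older config is never rewritten, so nothing downstream has to tell a backfilled value apart from one the user set explicitly.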

Comment thread packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py Outdated
Comment thread config/config_ema_warm_start.yml Outdated
Comment thread src/weathergen/model/encoder.py
return tokens_global_c


class Local2GlobalSumEngine(torch.nn.Module):
Collaborator

Separate PR

Contributor Author

Assuming the PR with this change is merged first

Collaborator

Which PR is this?

Comment thread src/weathergen/model/engines.py
Collaborator

@clessig left a comment

Thanks, PR looks good (has been rebased to latest develop). Can be merged with two changes:

  • Remove the sum aggregation engine (or merge a clean PR with this before)
  • Move the fix(?) / change to plot_utils.py to a separate PR; also see my comment there

@shmh40: can you approve and merge?

@shmh40 self-requested a review on April 1, 2026, 13:59
Contributor

@shmh40 left a comment

All good with me.

@sophie-xhonneux merged commit f42e906 into develop on Apr 10, 2026
5 checks passed

Labels

  • eval: anything related to the model evaluation pipeline
  • infra: Issues related to infrastructure
  • model:pretrain
  • model: Related to model training or definition (not generic infra)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

Test performance with learnable RMSNorm in attention

3 participants