Skip to content

[1849][model] Add optional latent Zarr writer#1860

Open
evenmn wants to merge 29 commits intoecmwf:developfrom
metno:feature/latent-zarr-writer
Open

[1849][model] Add optional latent Zarr writer#1860
evenmn wants to merge 29 commits intoecmwf:developfrom
metno:feature/latent-zarr-writer

Conversation

@evenmn
Copy link
Copy Markdown

@evenmn evenmn commented Feb 17, 2026

Description

This PR introduces a writer for the latent vector. To avoid additional config options, I decided to add "latent" as a special case of a stream output, e.g:

output.streams: ["ERA5", "latent"]

The latent output contains the latent vector itself, metadata to map healpix cells to coordinates, as well as extra features which may be useful for explainable AI applications. The combined ERA5 and latent output file takes the form (here for healpix level 5 and 2048 latent channels):

├── 0                                                                                                                                                                                                              
│   ├── ERA5                                                                                                                                                                                                       
│   │   ├── 0                                                                                                                                                                                                      
│   │   │   └── source                                                                                                                                                                                             
│   │   │       ├── coords (40320, 2) float32
│   │   │       ├── data (40320, 2) float32
│   │   │       ├── geoinfo (40320, 0) float32
│   │   │       └── times (40320,) datetime64
│   │   ├── 1
│   │   │   ├── prediction
│   │   │   │   ├── coords (40320, 2) float32
│   │   │   │   ├── data (40320, 1, 1) float32
│   │   │   │   ├── geoinfo (40320, 0) float32
│   │   │   │   └── times (40320,) datetime64
│   │   │   └── target
│   │   │       ├── coords (40320, 2) float32
│   │   │       ├── data (40320, 1) float32
│   │   │       ├── geoinfo (40320, 0) float32
│   │   │       └── times (40320,) datetime64
...
│   └── latent
│          ├── 0                                                                                                                                                                                                      
│          │   ├── coords (12288, 2) float32                                                                                                                                                                          
│          │   ├── geoinfo (12288, 0) float32                                                                                                                                                                         
│          │   ├── latent_state (12288, 2048) float32                                                                                                                                                                 
│          │   ├── latent_state_class_token (0, 2048) float32                                                                                                                                                         
│          │   ├── latent_state_register_tokens (0, 2048) float32                                                                                                                                                     
│          │   └── times (12288,) datetime64 
...

Going further, we can visualize the latent vector:

unzip validation_chkpt00000_rank0000.zip -d test.zarr

import zarr
import cartopy.crs as ccrs
import matplotlib.pyplot as plt

z = zarr.open("test.zarr", mode="r", zarr_version=3)

coords = z["0/latent/0/coords"][:]
latent = z["0/latent/0/latent_state"][:]

lat = coords[:,0]
lon = coords[:,1]
values = latent[:,0]

fig = plt.figure(figsize=(12,6))
ax = plt.axes(projection=ccrs.PlateCarree())

sc = ax.scatter(lon, lat, c=values, s=10, transform=ccrs.PlateCarree())

ax.coastlines()
plt.colorbar(sc, label="latent feature 0")

plt.savefig("latent_feature_0.png", dpi=300)

Healpix level 4:
latent_feature_0

Healpix level 5:
latent_time

Issue Number

Closes #1849

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hegdedoc in the github issue with all the configurations and runs for this experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

…ams: ['..., latent'])

Signed-off-by: evenmn <evenmn@mn.uio.no>
@clessig
Copy link
Copy Markdown
Collaborator

clessig commented Feb 18, 2026

@evenmn : can you please make sure the linter and unit tests pass.

@clessig
Copy link
Copy Markdown
Collaborator

clessig commented Feb 18, 2026

@grassesi : can you have a look at the changes in io.py. How far are you with refactoring the output writing?

evenmn and others added 4 commits February 19, 2026 08:34
Signed-off-by: evenmn <evenmn@mn.uio.no>
Signed-off-by: evenmn <evenmn@mn.uio.no>
Copy link
Copy Markdown
Contributor

@grassesi grassesi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall I like the idea of treating latent output just as another stream. One thing that needs to be addressed is that there is some masking logic that prevents the latent stream from trying to be processed during evaluation (esp. the to_xarray call will not work). Otherwise just some stylistic remarks.
I saw in the Issue that the latents should be eventually exposed via an JSON API? Would it make sense to already implement this and not try to piggyback on ZarrIO?

Comment on lines +618 to +625
# additionally yield latent output items if a latent stream name was provided
if self.latent_stream_name is not None and self.latents:
for s, fo_s in itertools.product(self.samples, self.forecast_steps):
key = ItemKey(int(s), int(fo_s), self.latent_stream_name)
latent_item = self._make_latent_item(key)
if latent_item is not None:
yield latent_item

Copy link
Copy Markdown
Contributor

@grassesi grassesi Feb 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please try to wrap this logic into the above for iteration: having two yielding for loops works but is confusing at best. I see two alternatives:

  1. Mix the writing of latent items into the normal loop. Something like
for s, fo_s, fi_s in itertools.product(
        self.samples, self.forecast_steps, self.streams.keys()
):
    key = ItemKey(int(s), int(fo_s), fi_s)
    if fi_s == LATENT_STREAM:
        latent_item = self._make_latent_item(key)
        if latent_item is not None:
            yield latent_item
    else:
        yield self.extract(ItemKey(int(s), int(fo_s), fi_s))
  1. have the writing of latent items in a separate method:
def latent_items(self):
    if self.latents:
        for s, fo_s in itertools.product(self.samples, self.forecast_steps):
            key = ItemKey(int(s), int(fo_s), LATENT_STREAM)
            latent_item = self._make_latent_item(key)
            if latent_item is not None:
                yield latent_item

...
with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio:
    for subset in data.items():
        zio.write_zarr(subset)
    for latent in data.latent_items():
        zio.write_zarr(latent)

Option 2. is maybe a bit more clearer and more equivalent to the current solution and also provides more flexibility for the future. But Option 1 would be also fine with me.

Comment thread src/weathergen/utils/validation_io.py Outdated
Comment thread src/weathergen/utils/validation_io.py
Comment on lines +590 to +591
# optional name to use for latent pseudo-stream when yielding latent items
latent_stream_name: str | None = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use a named constant for this.

@github-project-automation github-project-automation bot moved this to In Progress in WeatherGen-dev Feb 19, 2026
@github-actions github-actions bot added data Anything related to the datasets used in the project infra Issues related to infrastructure model Related to model training or definition (not generic infra) labels Feb 19, 2026
@evenmn
Copy link
Copy Markdown
Author

evenmn commented Feb 19, 2026

Overall I like the idea of treating latent output just as another stream. One thing that needs to be addressed is that there is some masking logic that prevents the latent stream from trying to be processed during evaluation (esp. the to_xarray call will not work). Otherwise just some stylistic remarks. I saw in the Issue that the latents should be eventually exposed via an JSON API? Would it make sense to already implement this and not try to piggyback on ZarrIO?

Thanks for your feedback. Exposing the latent space via an JSON API is useful when running the model operationally. However, I still think we should export the latent state as a Zarr file, since this is useful for other applications, for instance explanable AI.

@clessig
Copy link
Copy Markdown
Collaborator

clessig commented Feb 19, 2026

Overall I like the idea of treating latent output just as another stream. One thing that needs to be addressed is that there is some masking logic that prevents the latent stream from trying to be processed during evaluation (esp. the to_xarray call will not work). Otherwise just some stylistic remarks. I saw in the Issue that the latents should be eventually exposed via an JSON API? Would it make sense to already implement this and not try to piggyback on ZarrIO?

Thanks for your feedback. Exposing the latent space via an JSON API is useful when running the model operationally. However, I still think we should export the latent state as a Zarr file, since this is useful for other applications, for instance explanable AI.

Yes, json API is a separate issue and this PR should address writing the latent space to disc as zarr.

Copy link
Copy Markdown
Collaborator

@clessig clessig left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good. But validation_io will be refactored and it should be discussed how to best do the latent output going forward.

Comment thread src/weathergen/utils/validation_io.py Outdated
output_streams = {}
for name in output_stream_names:
if name == "latent":
latent_stream_name = name
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the logic here. Wouldn't it be enough to have

if "latent" in output_stream_names

in l 158? Do we expect to have have multiple latent states?

Comment thread src/weathergen/utils/validation_io.py Outdated
per_sample = {}
for lname, lval in latent_pred.items():
if isinstance(lval, LatentState):
for field in ("z_pre_norm", "patch_tokens", "register_tokens", "class_token"):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The latent state that should be relevant for the output are the patch_tokens. These are used for the decoder. To be fully future proof we could have an argument which part of LatentState is written although it might be over-engineering.

Comment thread src/weathergen/utils/validation_io.py Outdated

# collect latent outputs per forecast step and per sample (optional)
latents_all = []
if latent_stream_name is not None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go to a separate function.

Signed-off-by: evenmn <evenmn@mn.uio.no>
Signed-off-by: evenmn <evenmn@mn.uio.no>
Comment on lines +591 to +594
latent_stream_name: str | None = None
latent_stream_name: str | None = LATENT_STREAM
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove: since you are always using the default here anyway it makes no difference if self.latent_stream_name or LATENT_STREAM is used. But using using latent_stream_name clutters up the namespace/interface of OutputBatchData

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.latent_stream_name is currently an alias to LATENT_STREAM which is never None. So please just use if self.latents (This should not disregard my previous comment on these lines.)

Comment thread src/weathergen/utils/validation_io.py Outdated
stream_names = [stream.name for stream in cf.streams]
# include known pseudo-stream names (e.g. latent) so they are treated as known
if io.LATENT_STREAM not in stream_names:
stream_names.append(io.LATENT_STREAM)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove and just use:

if io.LATENT_STREAM in output_stream_names:
    output_streams[io.LATENT_STREAM] = None

in ll. 136

Comment thread src/weathergen/utils/validation_io.py Outdated
for name in output_stream_names:
if name == "latent":
if name == io.LATENT_STREAM:
latent_stream_name = name
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this if clause, instead implement the suggestion I commented above

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use io.LATENT_STREAM here.

@grassesi
Copy link
Copy Markdown
Contributor

Overall looks good. But validation_io will be refactored and it should be discussed how to best do the latent output going forward.

The choosen approach should translate relatively well into the refactored version.

…get_latent_output'

Signed-off-by: evenmn <evenmn@mn.uio.no>
@evenmn
Copy link
Copy Markdown
Author

evenmn commented Feb 23, 2026

Thanks to both of you for the feedback, it truly improved this PR. I believe I have incorporated the suggested changes in my latest commit, but I still need to test the implementation before it gets merged in.

@clessig
Copy link
Copy Markdown
Collaborator

clessig commented Feb 23, 2026

Thanks to both of you for the feedback, it truly improved this PR. I believe I have incorporated the suggested changes in my latest commit, but I still need to test the implementation before it gets merged in.

Can you please test it as far as you can, and then I have a final look.

@tjhunter tjhunter added the app label Feb 26, 2026
@clessig
Copy link
Copy Markdown
Collaborator

clessig commented Mar 5, 2026

@tjhunter : can we open a separate PR to fix the test.

@clessig
Copy link
Copy Markdown
Collaborator

clessig commented Mar 5, 2026

@evenmn : for developing/debugging this, it's probably best to use training with

# validation config; full validation config is merge of training and validation config
validation_config: 

  samples_per_mini_epoch: 256
  shuffle: False

  start_date: 2023-10-01T00:00
  end_date: 2023-12-31T00:00
  
  # whether to track the exponential moving average of weights for validation
  validate_with_ema: 
    enabled : True
    ema_ramp_up_ratio: 0.09
    ema_halflife_in_thousands: 1e-3

  # parameters for validation samples that are written to disk
  output : {
    # number of samples that are written
    num_samples: 8,
    # write samples in normalized model space
    normalized_samples: False,
    # output streams to write; default all
    streams: null,
    }

  # run validation before training starts (mainly for model development)
  validate_before_training: True

rather than the integration test (which also do evaluation, inference etc). This way the io is triggered right at the beginning.

evenmn added 3 commits March 11, 2026 11:55
Signed-off-by: evenmn <evenmn@fys.uio.no>
…coords accessible from validation_io.py

Signed-off-by: evenmn <evenmn@fys.uio.no>
… writes metadata correctly, was added

Signed-off-by: evenmn <evenmn@fys.uio.no>
@evenmn evenmn marked this pull request as ready for review March 11, 2026 11:19
@clessig
Copy link
Copy Markdown
Collaborator

clessig commented Mar 15, 2026

@evenmn : Trainer should not get ModelParams passed. This creates a dependency that shouldn't be there to ensure proper encapsulation. Computing the healpix coords is cheap so if it is needed we can just redo it in validation_io (but using the same function). The cleanest solution would arguably be to attach just the coords to latent output. If it doesn't incur a performance penalty then that's the way to go.

@evenmn
Copy link
Copy Markdown
Author

evenmn commented Mar 15, 2026

@evenmn : Trainer should not get ModelParams passed. This creates a dependency that shouldn't be there to ensure proper encapsulation. Computing the healpix coords is cheap so if it is needed we can just redo it in validation_io (but using the same function). The cleanest solution would arguably be to attach just the coords to latent output. If it doesn't incur a performance penalty then that's the way to go.

@clessig: I found it natural to make the healpix coords part of the ModelParams, but I agree that it's not ideal passing it to the trainer. I will move it to validation_io and see if it degrades the performance (which I doubt it will)

@grassesi grassesi mentioned this pull request Mar 16, 2026
4 tasks
…y. This avoids passing ModelParams to the Trainer

Signed-off-by: evenmn <evenmn@fys.uio.no>
@evenmn
Copy link
Copy Markdown
Author

evenmn commented Mar 16, 2026

@clessig: I just moved the healpix_coords calculations to validation_io.py

evenmn added 4 commits March 26, 2026 12:59
Signed-off-by: evenmn <evenmn@fys.uio.no>
Signed-off-by: evenmn <evenmn@fys.uio.no>
Signed-off-by: evenmn <evenmn@fys.uio.no>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

app data Anything related to the datasets used in the project infra Issues related to infrastructure model Related to model training or definition (not generic infra)

Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

Add support for exporting latent space from WeatherGenerator

5 participants