Conversation
| # reshard_after_forward=False keeps FE parameters unsharded | ||
| # during the multi-step rollout loop. | ||
| # Needed for pushforward trick. | ||
| fully_shard(module, reshard_after_forward=False, **fsdp_kwargs) |
@sophie-xhonneux : is this maybe related to the problem we are seeing with the EMATeacher where we need to reshard?
| tokens = self.forecast_engine(tokens, step, model_params.rope_coords) | ||
| # Add empty predictions for all streams (vectorized / batched if possible) | ||
| for stream_name in self.stream_names: |
We can avoid this if we create output with the correct length up front. I thought we did this anyway, hence the step argument.
| ) | ||
| if needs_full_prediction: | ||
| tokens = self.forecast_engine(tokens, step, coords=model_params.rope_coords) |
Remove coords= for consistency
| or step == max(batch.get_output_idxs()) | ||
| ) | ||
| if needs_full_prediction: |
We already call the forecast engine in l702. Don't we then call it twice in one iteration?
Also, if we set self.forecast_engine to Identity when the number of blocks is 0, we avoid the condition above.
- Sorry, the one in l702 must have been introduced by a bad merge; I will delete it.
- Many thanks for the heads-up, I will correct this too.
@clessig just to make sure that I understand: if the number of blocks is not > 0, the forecast engine is set to None. Should I introduce an identity function instead in this case?
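One way the identity fallback discussed here could look, as a minimal sketch only. `IdentityForecastEngine` and `build_forecast_engine` are hypothetical names, not part of the PR, and the sketch uses plain Python so it runs without PyTorch; in the actual model the class would presumably subclass `torch.nn.Module`. Note that `torch.nn.Identity.forward` takes a single input, so a call with the three arguments `(tokens, step, rope_coords)` would likely need a small wrapper like this rather than `nn.Identity` directly.

```python
class IdentityForecastEngine:
    """Hypothetical no-op stand-in used when the forecast engine has zero blocks."""

    def __call__(self, tokens, step, coords=None):
        # Ignore step and coords and hand the tokens back unchanged.
        return tokens


def build_forecast_engine(num_blocks, make_engine):
    """Return a real engine for num_blocks > 0, otherwise the identity stand-in.

    make_engine is a hypothetical factory for the real forecast engine.
    """
    if num_blocks > 0:
        return make_engine(num_blocks)
    return IdentityForecastEngine()


# With zero blocks the caller no longer needs an "is None" check.
engine = build_forecast_engine(0, make_engine=lambda n: None)
tokens = [0.1, 0.2, 0.3]
result = engine(tokens, 1, None)
```

With this in place the rollout loop can call the engine unconditionally, which is the simplification suggested in the comment above.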
| output = self.predict_decoders(model_params, step, tokens, batch, output) | ||
| # latent predictions (raw and with SSL heads) | ||
| output = self.predict_latent(model_params, step, tokens, batch, output) | ||
| needs_full_prediction = ( |
If you choose slightly more compact variable names, this fits on one line and is more readable.
E.g., pushforward instead of pushforward_trick.
| needs_full_prediction = ( | ||
| not pushforward_trick | ||
| or not self.training | ||
| or step == max(batch.get_output_idxs()) |
You use this to determine the number of forecast steps?
I use this to determine the last step for which we will take gradients.
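The condition under discussion can be isolated as a small predicate. This is a sketch under assumptions: the function signature and the `output_idxs` list are hypothetical stand-ins for the values the PR reads from `self.training` and `batch.get_output_idxs()`, and the short name `pushforward` follows the renaming suggested above.

```python
def needs_full_prediction(pushforward, training, step, output_idxs):
    """Return True when the full prediction path should run at this step.

    With the pushforward trick enabled during training, only the last
    output step qualifies; otherwise every step does.
    """
    last_step = max(output_idxs)
    return (not pushforward) or (not training) or step == last_step


# Hypothetical stand-in for batch.get_output_idxs().
output_idxs = [1, 2, 3]

# Training with pushforward: only the final output step is a full prediction.
train_steps = [s for s in output_idxs if needs_full_prediction(True, True, s, output_idxs)]

# Evaluation with pushforward: every step is a full prediction.
eval_steps = [s for s in output_idxs if needs_full_prediction(True, False, s, output_idxs)]
```

Computing `max(output_idxs)` once before the rollout loop, rather than on every step, would be a natural follow-up if the index list is long.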
Description
This PR adds the necessary changes for the pushforward trick.
Issue Number
Closes #1740
Is this PR a draft? Mark it as draft.
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60