feat(grpo): whitelist input_features for audio-aware multimodal GRPO #260
Open
cliangyu wants to merge 1 commit into
Qwen3-Omni-style models consume audio as log-mel spectrograms shaped
[batch, mel_bins, time] under the key 'input_features'. Today the
GRPO/GFPO generation and collation pipeline recognises vision/video
tensors but silently drops audio:
1. normalize_generation_model_kwargs() whitelists only keys listed in
GENERATION_MODEL_INPUT_KEYS; input_features is missing, so audio
is dropped before reaching model.__call__.
2. The GRPO data collator branches by FLATTENABLE_MULTIMODAL_KEYS
(utils.py) vs prompt-aligned left-pad. input_features routed to
the else-branch gets its trailing (mel-time) dim padded to
max_prompt_length, which corrupts the audio.
3. GROUPED_MULTIMODAL_MODEL_INPUT_KEYS governs per-step flattening
in _flatten_grouped_multimodal_model_value; audio falls through
to the modality-agnostic return, which is the correct shape for
[B, mel_bins, time] already.
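
Item 2 can be illustrated with a small NumPy sketch. The helper
left_pad_last_dim below is hypothetical: it only mimics what a
prompt-aligned left-pad does to a trailing dimension and is not the
collator's actual code. The shapes and max_prompt_length are assumed
values chosen for illustration.

```python
import numpy as np


def left_pad_last_dim(array: np.ndarray, target_len: int, pad_value: float = 0.0) -> np.ndarray:
    """Hypothetical stand-in for a prompt-aligned left-pad on the trailing dim."""
    pad = max(target_len - array.shape[-1], 0)
    # Pad only the last axis, and only on its left side.
    pad_width = [(0, 0)] * (array.ndim - 1) + [(pad, 0)]
    return np.pad(array, pad_width, constant_values=pad_value)


mel_bins, mel_time = 128, 1500   # assumed spectrogram size
max_prompt_length = 2048         # assumed token budget, unrelated to mel_time

input_features = np.ones((1, mel_bins, mel_time), dtype=np.float32)
padded = left_pad_last_dim(input_features, max_prompt_length)

# Number of all-zero frames prepended to the spectrogram.
leading_zero_frames = int((padded[0, 0] == 0).sum())

print(input_features.shape)  # (1, 128, 1500)
print(padded.shape)          # (1, 128, 2048)
print(leading_zero_frames)   # 548
```

The time axis now has the prompt-token length rather than the audio
length, so the encoder would see frames that were never in the audio.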
Adding 'input_features' to the three key sets makes audio flow through
the same path vision/video already use. The signature-filter backstop
at training_utils.normalize_generation_model_kwargs (filter_kwargs_for_callable)
means models that don't take input_features are unaffected.
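
A minimal sketch of the whitelist-plus-signature idea, using only the
standard library. WHITELIST_SKETCH, filter_by_whitelist_and_signature
and the two toy model functions are hypothetical names for this
illustration; the real filter_kwargs_for_callable may differ in detail.

```python
import inspect
from typing import Any, Callable, Mapping

# Illustrative whitelist; the real GENERATION_MODEL_INPUT_KEYS has other entries too.
WHITELIST_SKETCH = frozenset({"pixel_values", "pixel_values_videos", "input_features"})


def filter_by_whitelist_and_signature(fn: Callable[..., Any], kwargs: Mapping[str, Any]) -> dict:
    """Keep kwargs that are whitelisted and that fn can actually accept."""
    params = inspect.signature(fn).parameters
    accepts_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    return {
        key: value
        for key, value in kwargs.items()
        if key in WHITELIST_SKETCH and (accepts_var_kwargs or key in params)
    }


def audio_model(input_ids, input_features=None):
    """Toy callable that declares input_features."""
    return input_features


def vision_model(input_ids, pixel_values=None):
    """Toy callable that does not declare input_features."""
    return pixel_values


batch = {"input_features": "mel", "pixel_values": "img", "unrelated": 1}
audio_kwargs = filter_by_whitelist_and_signature(audio_model, batch)
vision_kwargs = filter_by_whitelist_and_signature(vision_model, batch)

print(audio_kwargs)   # {'input_features': 'mel'}
print(vision_kwargs)  # {'pixel_values': 'img'}
```

In this sketch the vision-only callable never sees input_features even
though the key is whitelisted, which is the sense in which
over-inclusion is inert.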
This also flips the base_trainer.generate_unified has_model_kwargs
heuristic to True for audio, which auto-disables eSurge (vision-only
rollout engine) and routes to compiled generation, the path already
validated for audio via merge_multimodal_embeddings at audio_token_id.
Tested against Qwen3-Omni-30B-A3B-Instruct on LibriSpeech ASR (GRPO
audio→text).
Refs: discussion in feat/audio-keys-grpo branch.
# Whitelist `input_features` for audio-aware multimodal GRPO

## Summary

Audio-capable multimodal models such as Qwen3-Omni-30B-A3B accept log-mel spectrograms under the keyword argument `input_features` (shape `[batch, mel_bins, time]`). Today the GRPO/GFPO generation and collation pipeline whitelists vision and video tensors but silently drops `input_features` before the model forward, because three key sets never list it:

- `easydel.trainers.training_utils.GENERATION_MODEL_INPUT_KEYS`
- `easydel.trainers.training_utils.GROUPED_MULTIMODAL_MODEL_INPUT_KEYS`
- `easydel.trainers.utils.FLATTENABLE_MULTIMODAL_KEYS` (seed frozenset; intersected with `GENERATION_MODEL_INPUT_KEYS` at import time)

This PR adds `"input_features"` to all three sets. Total diff is +15 / -0.

## What was broken before
Tracing a Qwen3-Omni audio tensor through the GRPO pipeline:

| Stage | Behaviour when `input_features` is absent from whitelists |
|---|---|
| `normalize_generation_model_kwargs` (`training_utils.py:152`) | Dropped before `filter_kwargs_for_callable`; never reaches the model |
| `GRPODataCollator` (`utils.py:1549-1592`) | Routed to the `else` branch, which pads the trailing dim to `max_prompt_length` via `_maybe_left_pad_prompt_aligned_array`. For `[mel_bins, time ≈ 1500]` this pads mel-time to the prompt-token count, which is catastrophic |
| `_flatten_grouped_multimodal_model_value` (`training_utils.py:203-220`) | Modality-agnostic return is already the correct shape for `[B, mel_bins, time]`, but only reachable if the key is grouped |
| `_expand_generation_value` (`infra/mixins/generation.py`) | `jnp.repeat(axis=0)` fallthrough |
| `base_trainer.generate_unified:2561-2566` | eSurge is vision-only (`inference/esurge/` has no `input_features` / audio paths), but `has_model_kwargs=True` auto-disables eSurge and falls back to compiled generation. Patch flips this flag True for audio. |

## Why this is safe
- `normalize_generation_model_kwargs` filters the whitelist through `filter_kwargs_for_callable` (`training_utils.py:181`). Models that do not accept `input_features` (vision-only models, text-only models) will have it dropped at the signature boundary before `model.__call__`. Over-inclusion is inert.
- `_normalize_flattenable_multimodal_array` (`utils.py:1712`) only special-cases vision/video reshapes; audio falls through unchanged, which is correct for already-flat mel tensors.
- `_flatten_grouped_multimodal_model_value` returns audio unchanged for the same reason.
- `PROMPT_ALIGNED_LEFT_PAD_KEYS` is not touched, because mel-time is not token-aligned.

## Tested against
Qwen3-Omni-30B-A3B-Instruct on LibriSpeech ASR (GRPO audio→text). Without this patch, the data collator produces tensors with shape `[..., max_prompt_length]` instead of `[..., mel_time]` and training diverges immediately; with this patch the collator routes through the multimodal flatten/stack path and training is stable.

## Tests

Source-text (AST) tests cover the invariant at the repository level and don't need JAX:
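
As a rough illustration of what a source-text test of this kind could look like, the sketch below parses a module's source with the standard `ast` module and checks that a string literal appears in a named assignment. The `SOURCE` string, the helper `string_constants_assigned_to`, and the test name are hypothetical; they are not the tests shipped in this PR, and a real test would read the module file from disk instead.

```python
import ast

# Hypothetical stand-in for a module's source text.
SOURCE = '''
GENERATION_MODEL_INPUT_KEYS = frozenset({"pixel_values", "pixel_values_videos", "input_features"})
OTHER_KEYS = frozenset({"input_ids"})
'''


def string_constants_assigned_to(source: str, name: str) -> set:
    """Collect string literals on the right-hand side of a module-level `name = ...`.

    Only plain assignments are handled; annotated assignments are ignored in this sketch.
    """
    for node in ast.parse(source).body:
        if isinstance(node, ast.Assign) and any(
            isinstance(target, ast.Name) and target.id == name for target in node.targets
        ):
            return {
                child.value
                for child in ast.walk(node.value)
                if isinstance(child, ast.Constant) and isinstance(child.value, str)
            }
    return set()


def test_input_features_is_whitelisted():
    keys = string_constants_assigned_to(SOURCE, "GENERATION_MODEL_INPUT_KEYS")
    assert "input_features" in keys


test_input_features_is_whitelisted()
```

Because the check works on source text alone, nothing is imported or executed from the package under test, which is why a test of this shape would not need JAX.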

A behaviour test (`pytest.importorskip("easydel")`) verifies that `normalize_generation_model_kwargs` keeps `input_features` when a model callable declares it and that `compact_generation_model_kwargs` surfaces it for `has_model_kwargs` detection.

## Out of scope (separate PRs planned)
- `Qwen3OmniMoe.audio(input_features)` at `modeling_qwen3_omni_moe.py:3962` is called with no `attention_mask`, so variable-length audio gets silently contaminated by zero-padded frames. Safe workaround: pad all audio to 30 s (`max_source_positions_for_audio=1500`). A follow-up will thread `feature_attention_mask` through.
- `_update_model_kwargs_for_generation` pops `inputs_embeds`, `deepstack_visual_embeds`, and `visual_pos_masks`, but not `input_features` / `pixel_values` / `pixel_values_videos`. Audio/vision encoders therefore re-run on every decode step: correctness is preserved, at the cost of a ~2-10× rollout slowdown.
- In eSurge, `MultiModalManager`, `BatchMetadata`, `model_runner`, `batch_preparer`, and the scheduler prefix-cache keys all assume vision-only multimodal. A full port is a multi-week project; the compiled-generation fallback path activated by this PR is sufficient for research-grade audio RL workloads in the meantime.

## Checklist
- Add `input_features` to `GENERATION_MODEL_INPUT_KEYS`
- Add `input_features` to `GROUPED_MULTIMODAL_MODEL_INPUT_KEYS`
- Add `input_features` to the seed frozenset of `FLATTENABLE_MULTIMODAL_KEYS`
- Do not add `input_features` to `PROMPT_ALIGNED_LEFT_PAD_KEYS`
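
The entries for `GENERATION_MODEL_INPUT_KEYS` and the `FLATTENABLE_MULTIMODAL_KEYS` seed depend on each other, since the summary notes that the seed frozenset is intersected with `GENERATION_MODEL_INPUT_KEYS` at import time. A minimal sketch of that interaction follows; the `*_SKETCH` names and the set contents are made up for illustration and do not reproduce the real sets.

```python
# Illustrative contents only; the real sets in easydel contain other keys.
GENERATION_KEYS_SKETCH = frozenset(
    {"input_ids", "pixel_values", "pixel_values_videos", "input_features"}
)
FLATTENABLE_SEED_SKETCH = frozenset(
    {"pixel_values", "pixel_values_videos", "input_features"}
)

# Mirrors the "intersected at import time" behaviour described in the summary.
FLATTENABLE_SKETCH = FLATTENABLE_SEED_SKETCH & GENERATION_KEYS_SKETCH

# If the key were added to the seed but not to the generation whitelist,
# the intersection would drop it again.
without_generation_entry = FLATTENABLE_SEED_SKETCH & (
    GENERATION_KEYS_SKETCH - {"input_features"}
)

print("input_features" in FLATTENABLE_SKETCH)        # True
print("input_features" in without_generation_entry)  # False
```

Under these assumptions, adding the key to the seed alone would have no effect, which is consistent with the PR touching all three sets together.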