feat(grpo): whitelist input_features for audio-aware multimodal GRPO #260

Open
cliangyu wants to merge 1 commit into erfanzar:main from cliangyu:feat/audio-keys-grpo
@cliangyu

Whitelist input_features for audio-aware multimodal GRPO

Summary

Audio-capable multimodal models such as Qwen3-Omni-30B-A3B accept log-mel spectrograms under the keyword argument input_features (shape [batch, mel_bins, time]). Today the GRPO/GFPO generation and collation pipeline whitelists vision and video tensors but silently drops input_features before the model forward, because three key sets never list it:

  1. easydel.trainers.training_utils.GENERATION_MODEL_INPUT_KEYS
  2. easydel.trainers.training_utils.GROUPED_MULTIMODAL_MODEL_INPUT_KEYS
  3. easydel.trainers.utils.FLATTENABLE_MULTIMODAL_KEYS (seed frozenset; intersected with GENERATION_MODEL_INPUT_KEYS at import time)

This PR adds "input_features" to all three sets. Total diff is +15 / -0.
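
A minimal sketch of why the key has to appear in more than one set, assuming only what the list above states: the flattenable set is a seed frozenset intersected with GENERATION_MODEL_INPUT_KEYS at import time. The helper name `flattenable_keys` and the reduced key sets are illustrative assumptions, not the EasyDeL definitions.

```python
def flattenable_keys(seed, generation_keys):
    """Model the import-time intersection described in the summary."""
    return frozenset(seed) & frozenset(generation_keys)


# Reduced, illustrative key sets; the real sets contain more entries.
VISION_KEYS = {"pixel_values", "pixel_values_videos"}
seed_with_audio = VISION_KEYS | {"input_features"}

# Seed updated, generation whitelist not updated: the intersection drops audio.
before = flattenable_keys(seed_with_audio, VISION_KEYS)

# Both sets updated, as this PR does: audio survives the intersection.
after = flattenable_keys(seed_with_audio, VISION_KEYS | {"input_features"})

print("input_features" in before, "input_features" in after)  # False True
```

Under this model, adding the key to the seed alone would have no effect, which is why the generation whitelist and the seed are changed together.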

What was broken before

Tracing a Qwen3-Omni audio tensor through the GRPO pipeline:

| Stage | Code path | Symptom when input_features is absent from whitelists |
| --- | --- | --- |
| Prep | normalize_generation_model_kwargs (training_utils.py:152) | Tensor is dropped before filter_kwargs_for_callable and never reaches the model |
| Collator | GRPODataCollator (utils.py:1549-1592) | Key falls into the else branch, which pads the trailing dim to max_prompt_length via _maybe_left_pad_prompt_aligned_array. For [mel_bins, time ≈ 1500] this pads the mel-time axis to the prompt-token count, which corrupts the audio |
| Step flatten | _flatten_grouped_multimodal_model_value (training_utils.py:203-220) | Falls through to return-unchanged, which is correct for [B, mel_bins, time], but that path is only reachable if the key is grouped |
| Rollout fwd | _expand_generation_value (infra/mixins/generation.py) | No symptom: already handles audio correctly via the generic jnp.repeat(axis=0) fallthrough |
| eSurge | base_trainer.generate_unified:2561-2566 | The eSurge engine is vision-only (inference/esurge/ has no input_features or audio paths), but has_model_kwargs=True auto-disables eSurge and falls back to compiled generation. With this patch, audio inputs set that flag to True |
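
The collator row can be illustrated at the shape level. This is a sketch under stated assumptions, not the GRPODataCollator code: `collated_shape` is a hypothetical helper, the else branch is modelled as forcing the trailing dim to max_prompt_length (as the row describes), and mel_bins=128 and max_prompt_length=2048 are example values.

```python
def collated_shape(key, shape, max_prompt_length, flattenable_keys):
    """Return the per-example shape after the modelled collator branch."""
    if key in flattenable_keys:
        # Multimodal flatten/stack path: the tensor shape is preserved.
        return tuple(shape)
    # Else branch: trailing dim treated as prompt tokens and padded to
    # max_prompt_length, regardless of what that axis actually means.
    return tuple(shape[:-1]) + (max_prompt_length,)


mel_shape = (128, 1500)  # [mel_bins, time]; 128 bins is an example value

without_patch = collated_shape("input_features", mel_shape, 2048, frozenset())
with_patch = collated_shape(
    "input_features", mel_shape, 2048, frozenset({"input_features"})
)

print(without_patch)  # (128, 2048)
print(with_patch)     # (128, 1500)
```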

Why this is safe

  • normalize_generation_model_kwargs filters the whitelist through filter_kwargs_for_callable (training_utils.py:181). Models that do not accept input_features (vision-only models, text-only models) will have it dropped at the signature boundary before model.__call__. Over-inclusion is inert.
  • _normalize_flattenable_multimodal_array (utils.py:1712) only special-cases vision/video reshapes; audio falls through unchanged, which is correct for already-flat mel tensors.
  • _flatten_grouped_multimodal_model_value returns audio unchanged for the same reason.
  • PROMPT_ALIGNED_LEFT_PAD_KEYS is not touched — mel-time is not token-aligned.
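
The signature-boundary argument in the first bullet can be sketched with the standard library. `filter_kwargs_by_signature` below is a hypothetical stand-in written for illustration; it is not EasyDeL's filter_kwargs_for_callable, and its handling of `**kwargs` is an assumption of this sketch.

```python
import inspect


def filter_kwargs_by_signature(fn, kwargs):
    """Keep only the kwargs that fn's signature can accept."""
    params = inspect.signature(fn).parameters
    # Assumption of this sketch: a callable with **kwargs accepts everything.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in params}


def vision_only_forward(input_ids, pixel_values=None):
    """Stand-in for a model that does not accept audio."""


def audio_forward(input_ids, input_features=None):
    """Stand-in for an audio-capable model."""


batch = {"input_ids": [1, 2, 3], "input_features": [[0.0, 0.1]]}

vision_kwargs = filter_kwargs_by_signature(vision_only_forward, batch)
audio_kwargs = filter_kwargs_by_signature(audio_forward, batch)

print(sorted(vision_kwargs))  # ['input_ids']
print(sorted(audio_kwargs))   # ['input_features', 'input_ids']
```

With a filter of this kind in front of the model call, a whitelisted key that a model does not declare is removed before the call, so over-inclusion in the whitelist has no effect on that model.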

Tested against

Qwen3-Omni-30B-A3B-Instruct on LibriSpeech ASR (GRPO audio→text). Without this patch, the data collator produces tensors with shape [..., max_prompt_length] instead of [..., mel_time] and training diverges immediately; with this patch the collator routes through the multimodal flatten/stack path and training is stable.

Tests

Source-text (AST) tests cover the invariant at the repository level and don't need JAX:

# audio_video_rl/tests/test_audio_keys_patch.py
def test_input_features_in_generation_model_input_keys() -> None:
    tree = ast.parse(TRAINING_UTILS.read_text())
    value = _find_assign(tree, "GENERATION_MODEL_INPUT_KEYS")
    assert "input_features" in _literal_strings_in(value)

A behaviour test (pytest.importorskip("easydel")) verifies normalize_generation_model_kwargs keeps input_features when a model callable declares it and compact_generation_model_kwargs surfaces it for has_model_kwargs detection.

Out of scope (separate PRs planned)

  1. Audio attention mask: Qwen3OmniMoe.audio(input_features) at modeling_qwen3_omni_moe.py:3962 is called with no attention_mask, so variable-length audio gets silently contaminated by zero-padded frames. Safe workaround: pad all audio to 30 s (max_source_positions_for_audio=1500). Follow-up will thread feature_attention_mask through.
  2. Decode-step media pop: _update_model_kwargs_for_generation pops inputs_embeds, deepstack_visual_embeds, visual_pos_masks but not input_features / pixel_values / pixel_values_videos. Audio/vision encoders re-run every decode step — correctness preserved, ~2-10× rollout slowdown.
  3. eSurge audio port: MultiModalManager, BatchMetadata, model_runner, batch_preparer, and the scheduler prefix-cache keys all assume vision-only multimodal. A full port is a multi-week project; the compiled-generation fallback path activated by this PR is sufficient for research-grade audio RL workloads in the meantime.
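
The fixed-length workaround from item 1 amounts to giving every clip the same time dimension. The sketch below shows only that shape invariant on nested lists; `pad_mel_to_fixed_frames` and its parameters are hypothetical, the frame count is an example value, and a real pipeline would more likely pad in the feature extractor before the mel spectrogram is computed.

```python
def pad_mel_to_fixed_frames(mel, target_frames, pad_value=0.0):
    """Right-pad each mel-bin row of a [mel_bins, time] list to target_frames."""
    padded = []
    for row in mel:
        if len(row) > target_frames:
            raise ValueError("clip is longer than the fixed target length")
        padded.append(list(row) + [pad_value] * (target_frames - len(row)))
    return padded


# Two clips with different lengths and 2 mel bins each (toy sizes).
short_clip = [[0.1, 0.2], [0.3, 0.4]]
long_clip = [[0.5, 0.6, 0.7], [0.8, 0.9, 1.0]]

batch = [pad_mel_to_fixed_frames(clip, target_frames=4) for clip in (short_clip, long_clip)]

# Every clip now has the same [mel_bins, time] shape.
print([(len(clip), len(clip[0])) for clip in batch])  # [(2, 4), (2, 4)]
```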

Checklist

  • Adds input_features to GENERATION_MODEL_INPUT_KEYS
  • Adds input_features to GROUPED_MULTIMODAL_MODEL_INPUT_KEYS
  • Adds input_features to the seed frozenset of FLATTENABLE_MULTIMODAL_KEYS
  • Does not add input_features to PROMPT_ALIGNED_LEFT_PAD_KEYS
  • No behavioural change for non-audio models (signature filter)
  • Tests (source-AST invariants + TPU-gated behaviour test)

Qwen3-Omni-style models consume audio as log-mel spectrograms shaped
[batch, mel_bins, time] under the key 'input_features'. Today the
GRPO/GFPO generation and collation pipeline recognises vision/video
tensors but silently drops audio:

  1. normalize_generation_model_kwargs() whitelists only keys listed in
     GENERATION_MODEL_INPUT_KEYS; input_features is missing, so audio
     is dropped before reaching model.__call__.

  2. The GRPO data collator branches by FLATTENABLE_MULTIMODAL_KEYS
     (utils.py) vs prompt-aligned left-pad. input_features routed to
     the else-branch gets its trailing (mel-time) dim padded to
     max_prompt_length, which corrupts the audio.

  3. GROUPED_MULTIMODAL_MODEL_INPUT_KEYS governs per-step flattening
     in _flatten_grouped_multimodal_model_value; audio falls through
     to the modality-agnostic return, which is the correct shape for
     [B, mel_bins, time] already.

Adding 'input_features' to the three key sets makes audio flow through
the same path vision/video already use. The signature-filter backstop
at training_utils.normalize_generation_model_kwargs (filter_kwargs_for_callable)
means models that don't take input_features are unaffected.

This also flips the base_trainer.generate_unified has_model_kwargs
heuristic True for audio, which auto-disables eSurge (vision-only
rollout engine) and routes to compiled generation — the path already
validated for audio via merge_multimodal_embeddings at audio_token_id.

Tested against Qwen3-Omni-30B-A3B-Instruct on LibriSpeech ASR (GRPO
audio→text).

Refs: discussion in feat/audio-keys-grpo branch.

cliangyu marked this pull request as ready for review on April 23, 2026 23:57