Chatterbox tts#1976
Signed-off-by: Ssofja <sofiakostandian@gmail.com>
Greptile Summary
This PR introduces
Confidence Score: 2/5
Not safe to merge as-is: inference failures from non-OSError/RuntimeError exceptions will abort the entire batch rather than being isolated per-turn, and cross-run cache hits produce output tasks where the reference_voice metadata does not match the audio actually on disk.
Two defects affect correctness on the main code path. First, _generate_turn_audio swallows only OSError and RuntimeError, so any other exception from the Chatterbox library escapes both the inner handler and the outer except OSError in process_batch, terminating the batch loop mid-flight. Second, the deterministic output filename encodes only conversation_id + speaker + text, not which reference voice was selected; a fresh run may assign a different reference and return cached audio with incorrect reference_voice metadata.
nemo_curator/stages/audio/tts/chatterbox_tts.py: both defects are concentrated here, in _generate_turn_audio (exception handling) and process_batch (cache-hit metadata path).
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Caller
participant Stage as ChatterboxTTSStage
participant Model as ChatterboxTTS / Multilingual
participant RefAudio as Reference Dataset
participant Disk as Output Dir / Temp Dir
Caller->>Stage: setup()
Stage->>Disk: makedirs(output_audio_dir)
Stage->>Disk: "mkdtemp(chatterbox_ref_*)"
Stage->>Model: from_pretrained(device)
Stage->>RefAudio: "glob wavs/*/*.wav or */*/*.flac"
Stage-->>Caller: ready
Caller->>Stage: process_batch(tasks)
loop for each AudioTask
Stage->>Stage: _assign_reference(speaker, conv_id)
alt new (conv_id, speaker) pair
Stage->>RefAudio: pick reference file
alt wavs layout + RTTM exists
Stage->>Disk: strip silences → temp_dir/ref.wav
else MLS layout
Stage->>Disk: concatenate segments → temp_dir/ref_spk.wav
end
end
Stage->>Stage: _output_filename(conv_id, speaker, text)
alt file already exists on disk
Stage->>Disk: sf.read(audio_path)
else
Stage->>Model: generate(text, audio_prompt_path, ...)
Model-->>Stage: wav tensor
Stage->>Stage: _normalize_audio(wav)
Stage->>Disk: sf.write(audio_path, audio_data)
end
Stage-->>Caller: AudioTask + audio_filepath, duration, reference_voice
end
Caller->>Stage: teardown()
Stage->>Disk: shutil.rmtree(temp_dir)
Reviews (6): Last reviewed commit: "change chatterbox tests structure"
| if language is not None and language.lower() not in SUPPORTED_LANGUAGES: | ||
| raise ValueError( | ||
| f"Unsupported language '{language}'. " | ||
| f"Supported: {', '.join(sorted(SUPPORTED_LANGUAGES))}" | ||
| ) |
The language code is validated with .lower() but stored as-is and later passed directly to the model as language_id. If a caller passes "RU" or "FR", it clears the SUPPORTED_LANGUAGES check (because "ru" is in the set), but the raw uppercase string is forwarded to ChatterboxMultilingualTTS.generate. The Chatterbox API expects lowercase ISO 639-1 codes, so inference would either fail or silently produce wrong-language output.
| if language is not None and language.lower() not in SUPPORTED_LANGUAGES: | |
| raise ValueError( | |
| f"Unsupported language '{language}'. " | |
| f"Supported: {', '.join(sorted(SUPPORTED_LANGUAGES))}" | |
| ) | |
| if language is not None and language.lower() not in SUPPORTED_LANGUAGES: | |
| raise ValueError( | |
| f"Unsupported language '{language}'. " | |
| f"Supported: {', '.join(sorted(SUPPORTED_LANGUAGES))}" | |
| ) | |
| if language is not None: | |
| language = language.lower() |
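The normalise-then-validate order in the suggestion can be sketched as a standalone helper. This is a minimal sketch, not the stage's code: the `normalize_language` name is hypothetical, and the `SUPPORTED_LANGUAGES` set is copied from the language list in the PR description rather than from the source file.

```python
from typing import Optional

# Assumption: mirrors the stage's constant; copied from the PR description.
SUPPORTED_LANGUAGES = frozenset({
    "ar", "da", "de", "el", "en", "es", "fi", "fr", "he", "hi", "it", "ja",
    "ko", "ms", "nl", "no", "pl", "pt", "ru", "sv", "sw", "tr", "zh",
})


def normalize_language(language: Optional[str]) -> Optional[str]:
    """Lower-case the code first, then validate the normalised value."""
    if language is None:
        return None  # no language requested
    normalized = language.lower()
    if normalized not in SUPPORTED_LANGUAGES:
        raise ValueError(
            f"Unsupported language '{language}'. "
            f"Supported: {', '.join(sorted(SUPPORTED_LANGUAGES))}"
        )
    # The caller stores and forwards this value, never the raw input.
    return normalized
```

Storing the return value means the string that passed validation is the same string that later reaches the model.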
| out_path = os.path.join( | ||
| self.temp_dir, os.path.basename(audio_filepath) | ||
| ) |
Both _get_reference_audio_wavs and _get_reference_audio_mls write their RTTM-processed / concatenated output into self.temp_dir using os.path.basename(audio_filepath). When different dialogs contain a speaker file with the same name (e.g. dialog001/Alice.wav and dialog002/Alice.wav), the second write silently overwrites the first temp file. Any speaker already assigned temp_dir/Alice.wav then synthesises audio with the wrong voice without any warning.
| out_path = os.path.join( | |
| self.temp_dir, os.path.basename(audio_filepath) | |
| ) | |
| unique_name = hashlib.md5(audio_filepath.encode()).hexdigest()[:8] + "_" + os.path.basename(audio_filepath) | |
| out_path = os.path.join(self.temp_dir, unique_name) |
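A small standalone sketch of the suggested naming scheme, showing that two source files with the same basename no longer map to the same temp file. The `unique_temp_name` helper and the example paths are illustrative only.

```python
import hashlib
import os


def unique_temp_name(audio_filepath: str) -> str:
    """Prefix the basename with a short hash of the full source path."""
    digest = hashlib.md5(audio_filepath.encode("utf-8")).hexdigest()[:8]
    return f"{digest}_{os.path.basename(audio_filepath)}"


# Same speaker file name in two different dialogs (hypothetical paths).
first = unique_temp_name("wavs/dialog001/Alice.wav")
second = unique_temp_name("wavs/dialog002/Alice.wav")
```

The hash is used only to disambiguate names, not for security, so a short MD5 prefix is enough for this purpose.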
| @staticmethod | ||
| def _output_filename(conversation_id: str, speaker: str, text: str) -> str: | ||
| """Deterministic filename: ``{conv_id_short}_{speaker}_{text_hash}.wav``.""" | ||
| conv_short = conversation_id[:12] if len(conversation_id) > 12 else conversation_id | ||
| text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10] | ||
| return f"{conv_short}_{speaker}_{text_hash}.wav" |
Truncating the conversation ID to 12 characters means two conversations whose IDs share a 12-character prefix (common with structured IDs such as session1_conv001 / session1_conv002) generate the same filename for the same speaker and text. On a subsequent run the cached file from the first conversation is reused for the second even though a different reference voice may have been assigned, producing a silent audio/metadata mismatch.
| @staticmethod | |
| def _output_filename(conversation_id: str, speaker: str, text: str) -> str: | |
| """Deterministic filename: ``{conv_id_short}_{speaker}_{text_hash}.wav``.""" | |
| conv_short = conversation_id[:12] if len(conversation_id) > 12 else conversation_id | |
| text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10] | |
| return f"{conv_short}_{speaker}_{text_hash}.wav" | |
| @staticmethod | |
| def _output_filename(conversation_id: str, speaker: str, text: str) -> str: | |
| """Deterministic filename: ``{conv_id_hash}_{speaker}_{text_hash}.wav``.""" | |
| conv_hash = hashlib.md5(conversation_id.encode("utf-8")).hexdigest()[:12] | |
| text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10] | |
| return f"{conv_hash}_{speaker}_{text_hash}.wav" |
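The collision described above can be reproduced outside the stage. The sketch below places the original truncation next to the suggested hashing; the function names `truncated_name` and `hashed_name` are made up for the comparison.

```python
import hashlib


def truncated_name(conversation_id: str, speaker: str, text: str) -> str:
    """Original scheme: keep only the first 12 characters of the ID."""
    text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10]
    return f"{conversation_id[:12]}_{speaker}_{text_hash}.wav"


def hashed_name(conversation_id: str, speaker: str, text: str) -> str:
    """Suggested scheme: hash the whole ID, then shorten the digest."""
    conv_hash = hashlib.md5(conversation_id.encode("utf-8")).hexdigest()[:12]
    text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10]
    return f"{conv_hash}_{speaker}_{text_hash}.wav"


# Two IDs that share their first 12 characters ("session1_con").
a = ("session1_conv001", "Alice", "Hello there.")
b = ("session1_conv002", "Alice", "Hello there.")
```

With these inputs `truncated_name` returns the same filename for both conversations, while `hashed_name` keeps them apart.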
| reference_wav = self._assign_reference(speaker, conversation_id) | ||
|
|
||
| filename = self._output_filename(conversation_id, speaker, text) | ||
| audio_path = os.path.join(self.output_audio_dir, filename) | ||
|
|
||
| if os.path.exists(audio_path): | ||
| audio_data, _ = sf.read(audio_path) | ||
| else: | ||
| audio_data = self._generate_turn_audio( | ||
| text, reference_wav, conversation_id | ||
| ) | ||
| sf.write(audio_path, audio_data, self.sample_rate) | ||
|
|
||
| duration = len(audio_data) / self.sample_rate | ||
|
|
||
| out_data = dict(data) | ||
| out_data["audio_filepath"] = audio_path | ||
| out_data["duration"] = duration | ||
| out_data["reference_voice"] = Path(reference_wav).parent.name |
Path(reference_wav).parent.name returns the temp-directory name (e.g. chatterbox_ref_abc123) whenever the reference has been RTTM-processed or comes from the MLS layout, because both code paths write to self.temp_dir/<filename>. Only the raw wavs path (no RTTM) has a meaningful parent (the dialog ID). The emitted reference_voice value should be the MLS speaker ID or the dialog/speaker tag, not an ephemeral temp-dir name.
| reference_wav = self._assign_reference(speaker, conversation_id) | |
| filename = self._output_filename(conversation_id, speaker, text) | |
| audio_path = os.path.join(self.output_audio_dir, filename) | |
| if os.path.exists(audio_path): | |
| audio_data, _ = sf.read(audio_path) | |
| else: | |
| audio_data = self._generate_turn_audio( | |
| text, reference_wav, conversation_id | |
| ) | |
| sf.write(audio_path, audio_data, self.sample_rate) | |
| duration = len(audio_data) / self.sample_rate | |
| out_data = dict(data) | |
| out_data["audio_filepath"] = audio_path | |
| out_data["duration"] = duration | |
| out_data["reference_voice"] = Path(reference_wav).parent.name | |
| reference_wav, ref_id = self._assign_reference(speaker, conversation_id) | |
| filename = self._output_filename(conversation_id, speaker, text) | |
| audio_path = os.path.join(self.output_audio_dir, filename) | |
| if os.path.exists(audio_path): | |
| audio_data, _ = sf.read(audio_path) | |
| else: | |
| audio_data = self._generate_turn_audio( | |
| text, reference_wav, conversation_id | |
| ) | |
| sf.write(audio_path, audio_data, self.sample_rate) | |
| duration = len(audio_data) / self.sample_rate | |
| out_data = dict(data) | |
| out_data["audio_filepath"] = audio_path | |
| out_data["duration"] = duration | |
| out_data["reference_voice"] = ref_id |
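One way `_assign_reference` could compute the `ref_id` it returns is to derive it from the original source path before any temp-file processing. This is a sketch under assumptions: the `reference_id` helper is hypothetical, and the two layouts are taken from the PR description (`wavs/<dialog>/<speaker>.wav` and MLS `<spk>/<book>/<seg>.flac`).

```python
from pathlib import PurePosixPath


def reference_id(source_path: str, layout: str) -> str:
    """Derive a stable label from the original reference file path.

    Assumed layouts:
      wavs: wavs/<dialog>/<speaker>.wav  -> "<dialog>/<speaker>"
      mls:  <spk>/<book>/<seg>.flac      -> "<spk>"
    """
    path = PurePosixPath(source_path)
    if layout == "mls":
        # Two levels up from the segment file is the speaker directory.
        return path.parent.parent.name
    return f"{path.parent.name}/{path.stem}"


# What the original expression yields for a processed reference:
temp_parent = PurePosixPath("/tmp/chatterbox_ref_abc123/Alice.wav").parent.name
```

Because the label comes from the dataset path, it stays the same whether or not the file was later rewritten into the temp directory.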
| top_p: float = 1.0, | ||
| normalize_audio: bool = True, | ||
| normalize_level: float = -20.0, | ||
| **kwargs, |
| text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10] | ||
| return f"{conv_short}_{speaker}_{text_hash}.wav" | ||
|
|
||
| def _ensure_ready(self) -> None: |
| if not tasks: | ||
| return [] | ||
|
|
||
| self._ensure_ready() |
Remove. We should never call setup in process_batch/process.
| return [] | ||
|
|
||
| self._ensure_ready() | ||
| os.makedirs(self.output_audio_dir, exist_ok=True) |
Should this be in setup?
| if self.normalize_audio: | ||
| wav = self._normalize_audio(wav) | ||
|
|
||
| return wav.squeeze(0).numpy() |
Missing .cpu(): CUDA inference always silently fails
wav.squeeze(0).numpy() is called directly on the tensor returned by self.model.generate(). The default device is "cuda", which means the model lives on GPU and its output is a CUDA tensor. Calling .numpy() on a CUDA tensor raises a TypeError ("can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first."), which is swallowed by the surrounding except Exception block, so every single GPU inference call silently falls through and returns 2 seconds of zeros. The tests never catch this because _build_stage hard-codes device="cpu".
| return wav.squeeze(0).numpy() | |
| return wav.squeeze(0).cpu().numpy() |
Signed-off-by: Ssofja <sofiakostandian@gmail.com>
| if os.path.exists(audio_path): | ||
| audio_data, _ = sf.read(audio_path) | ||
| else: | ||
| audio_data = self._generate_turn_audio( | ||
| text, reference_wav, conversation_id | ||
| ) | ||
| sf.write(audio_path, audio_data, self.sample_rate) | ||
|
|
||
| duration = len(audio_data) / self.sample_rate | ||
|
|
||
| out_data = dict(data) | ||
| out_data["audio_filepath"] = audio_path | ||
| out_data["duration"] = duration | ||
| out_data["reference_voice"] = ref_id | ||
|
|
||
| output_tasks.append( | ||
| AudioTask( | ||
| data=out_data, | ||
| task_id=task.task_id, | ||
| dataset_name=task.dataset_name, | ||
| ) | ||
| ) |
The PR's stated design goal is "graceful failure handling: TTS errors produce a silence placeholder instead of crashing the pipeline," but only _generate_turn_audio has a try/except. The sf.write call (and the preceding os.path.exists + sf.read for cache hits) sits outside any guard. Any OSError here (disk full, wrong permissions, or a speaker name that contains / or \ generating a path whose parent directory doesn't exist) will raise an unhandled exception and abort the entire remaining batch. Wrapping the file I/O block in its own try/except keeps the one-turn failure contained and matches the silence-fallback contract already in place for inference.
| if os.path.exists(audio_path): | |
| audio_data, _ = sf.read(audio_path) | |
| else: | |
| audio_data = self._generate_turn_audio( | |
| text, reference_wav, conversation_id | |
| ) | |
| sf.write(audio_path, audio_data, self.sample_rate) | |
| duration = len(audio_data) / self.sample_rate | |
| out_data = dict(data) | |
| out_data["audio_filepath"] = audio_path | |
| out_data["duration"] = duration | |
| out_data["reference_voice"] = ref_id | |
| output_tasks.append( | |
| AudioTask( | |
| data=out_data, | |
| task_id=task.task_id, | |
| dataset_name=task.dataset_name, | |
| ) | |
| ) | |
| try: | |
| if os.path.exists(audio_path): | |
| audio_data, _ = sf.read(audio_path) | |
| else: | |
| audio_data = self._generate_turn_audio( | |
| text, reference_wav, conversation_id | |
| ) | |
| sf.write(audio_path, audio_data, self.sample_rate) | |
| except Exception as e: | |
| logger.error(f"File I/O failed for task {task.task_id}: {e}") | |
| output_tasks.append(task) | |
| continue | |
| duration = len(audio_data) / self.sample_rate | |
| out_data = dict(data) | |
| out_data["audio_filepath"] = audio_path | |
| out_data["duration"] = duration | |
| out_data["reference_voice"] = ref_id | |
| output_tasks.append( | |
| AudioTask( | |
| data=out_data, | |
| task_id=task.task_id, | |
| dataset_name=task.dataset_name, | |
| ) | |
| ) |
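The per-turn isolation pattern in the suggestion can be sketched without the audio dependencies. Everything here is illustrative: `process_turns` and `synthesize_and_save` are hypothetical stand-ins for the batch loop and for the read-or-generate-and-write step, and the sketch catches `OSError` as discussed in the comment (the suggested diff uses the broader `Exception`).

```python
import logging
from typing import Callable, Dict, List

logger = logging.getLogger(__name__)


def process_turns(
    turns: List[Dict[str, str]],
    synthesize_and_save: Callable[[Dict[str, str]], str],
) -> List[Dict[str, str]]:
    """Run every turn; a failing turn is passed through without audio."""
    results: List[Dict[str, str]] = []
    for turn in turns:
        try:
            audio_path = synthesize_and_save(turn)
        except OSError as exc:
            logger.error("File I/O failed for turn %s: %s", turn.get("id"), exc)
            results.append(turn)  # pass-through: no audio_filepath key
            continue
        results.append({**turn, "audio_filepath": audio_path})
    return results
```

The output list always has one entry per input turn, so a single disk error does not change how later turns are handled.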
| if not chunks: | ||
| return files[0], chosen | ||
|
|
||
| concatenated = torch.cat(chunks, dim=1) | ||
| out_path = os.path.join(self.temp_dir, f"ref_{chosen}.wav") | ||
| ta.save(out_path, concatenated, last_sr) | ||
| return out_path, chosen |
The torch.cat call and the ta.save that follows sit outside the per-file try/except. If any two FLAC files in chunks have a different number of channels (e.g., one mono, one stereo), torch.cat(chunks, dim=1) raises a RuntimeError that propagates all the way out of process_batch uncaught, aborting the batch entirely. Wrapping the concatenation and save in their own guard, and falling back to the raw first file to match the existing empty-chunks fallback, keeps the failure isolated to the one speaker assignment.
| if not chunks: | |
| return files[0], chosen | |
| concatenated = torch.cat(chunks, dim=1) | |
| out_path = os.path.join(self.temp_dir, f"ref_{chosen}.wav") | |
| ta.save(out_path, concatenated, last_sr) | |
| return out_path, chosen | |
| if not chunks: | |
| return files[0], chosen | |
| try: | |
| concatenated = torch.cat(chunks, dim=1) | |
| out_path = os.path.join(self.temp_dir, f"ref_{chosen}.wav") | |
| ta.save(out_path, concatenated, last_sr) | |
| return out_path, chosen | |
| except Exception as e: | |
| logger.warning(f"MLS concatenation failed for speaker {chosen}: {e}") | |
| return files[0], chosen |
Signed-off-by: Ssofja <sofiakostandian@gmail.com>
| if os.path.exists(audio_path): | ||
| audio_data, _ = sf.read(audio_path) | ||
| else: | ||
| audio_data = self._generate_turn_audio( | ||
| text, reference_wav, conversation_id | ||
| ) | ||
| sf.write(audio_path, audio_data, self.sample_rate) | ||
| except OSError as e: | ||
| logger.error(f"File I/O failed for task {task.task_id}: {e}") | ||
| output_tasks.append(task) | ||
| continue |
Silence files permanently cached, blocking retry on transient failures
When _generate_turn_audio returns silence (after a GPU OOM or other runtime error), the silence array is written to disk via sf.write(audio_path, audio_data, self.sample_rate) at line 485. On a subsequent re-run, os.path.exists(audio_path) returns True and the cached silence is served back as valid output without any retry. There is no way to distinguish a legitimately synthesised near-silence from a failure-generated silence, so transient failures permanently poison the idempotent cache for that task. If the intent is to allow retries, the silence fallback should not be written to disk — the task should be returned without audio_filepath, matching the empty-text passthrough pattern already used at line 516.
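A minimal sketch of the "do not cache failures" idea, using raw bytes in place of audio arrays so it runs without soundfile. The `load_or_generate` helper is hypothetical, and signalling failure with `None` is an assumption made for this sketch; the stage currently returns a silence array instead.

```python
import os
import tempfile
from typing import Callable, Optional


def load_or_generate(
    path: str, generate: Callable[[], Optional[bytes]]
) -> Optional[bytes]:
    """Return cached bytes, or generate them; never cache a failed result."""
    if os.path.exists(path):
        with open(path, "rb") as fh:
            return fh.read()
    data = generate()
    if data is None:  # generation failed: leave no file behind
        return None
    with open(path, "wb") as fh:
        fh.write(data)
    return data


with tempfile.TemporaryDirectory() as out_dir:
    path = os.path.join(out_dir, "turn.wav")
    failed = load_or_generate(path, lambda: None)        # transient failure
    cached_after_failure = os.path.exists(path)          # nothing was written
    retried = load_or_generate(path, lambda: b"audio")   # retry succeeds
    reused = load_or_generate(path, lambda: None)        # served from cache
```

Because the failed attempt leaves no file, the second call regenerates instead of serving a placeholder.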
| key = f"{conversation_id}_{speaker}" | ||
|
|
||
| if key in self.speaker_to_reference: | ||
| return self.speaker_to_reference[key], self.speaker_to_ref_id[key] | ||
|
|
||
| if self._reference_layout == "mls": | ||
| already_taken_ids = { | ||
| self.speaker_to_ref_id[k] | ||
| for k in self.speaker_to_ref_id | ||
| if k.startswith(f"{conversation_id}_") | ||
| } |
Key collision in _assign_reference when conversation IDs or speaker names share an underscore boundary
The lookup key f"{conversation_id}_{speaker}" is ambiguous: (conv="A_B", speaker="C") and (conv="A", speaker="B_C") both produce key "A_B_C". When both pairs appear in the same batch (or across sessions), the second pair silently reuses the reference cached for the first, assigning it the wrong voice without any warning. The k.startswith(f"{conversation_id}_") deduplication check also has a symmetric false-positive: keys from a conversation whose ID starts with conversation_id + "_" are incorrectly counted as "already taken" voices. Using a tuple key (conversation_id, speaker) or a separator that cannot appear in either field (e.g. \x00) would eliminate both problems.
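The ambiguity is easy to demonstrate with plain dictionaries. In this sketch the voice labels and the `taken_in_conversation` helper are invented for illustration; `setdefault` stands in for the "reuse the cached reference if the key exists" branch.

```python
from typing import Dict, Set, Tuple

string_keys: Dict[str, str] = {}
tuple_keys: Dict[Tuple[str, str], str] = {}

# Two distinct (conversation_id, speaker) pairs from the comment above.
pairs = [("A_B", "C", "voice_1"), ("A", "B_C", "voice_2")]
for conversation_id, speaker, voice in pairs:
    string_keys.setdefault(f"{conversation_id}_{speaker}", voice)
    tuple_keys.setdefault((conversation_id, speaker), voice)


def taken_in_conversation(
    assignments: Dict[Tuple[str, str], str], conversation_id: str
) -> Set[str]:
    """Voices already used in one conversation, matched on the exact ID."""
    return {
        voice
        for (conv, _speaker), voice in assignments.items()
        if conv == conversation_id
    }
```

With tuple keys, the "already taken" check compares the conversation ID exactly, so the `startswith` false positive disappears as well.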
Signed-off-by: Ssofja <sofiakostandian@gmail.com>
|
|
||
| def process(self, task: AudioTask) -> AudioTask: | ||
| """Generate audio for a single conversation turn.""" | ||
| return self.process_batch([task])[0] |
We can just raise as not implemented instead.
| return stage | ||
|
|
||
|
|
||
| class TestConstruction: |
None of the tests in this class are needed IMO.
| assert stage.resources.gpus == 1 | ||
|
|
||
|
|
||
| class TestModelLoading: |
Same here, I don't think these tests are needed.
| ref_path, _ref_id = stage._assign_reference("Alice", "conv001") | ||
| assert os.path.exists(ref_path) | ||
|
|
||
| class TestEdgeCases: |
Let's avoid the overly-specific class TestXYZ in favor of a single test class for each ProcessingStage.
Signed-off-by: Ssofja <sofiakostandian@gmail.com>
| except (OSError, RuntimeError) as e: | ||
| logger.error(f"TTS generation failed: {e}") | ||
| return np.zeros(self.sample_rate * 2) |
Narrow exception catch breaks graceful-failure contract
_generate_turn_audio only swallows OSError and RuntimeError. Chatterbox can raise ValueError (e.g. invalid audio_prompt_path format), AttributeError (model not fully initialised), TypeError (unexpected kwarg), or any other library-specific exception type. Those exceptions escape this handler, propagate into process_batch's outer try/except OSError (line 487) which also won't catch them, and then abort the rest of the batch loop entirely. Changing the inner handler to except Exception would keep every inference failure isolated to the one turn and preserve the silence-fallback contract across the board.
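A sketch of the broader handler, with a plain list of floats standing in for the NumPy array so the example has no third-party dependencies. The `generate_or_silence` name and the `SAMPLE_RATE` value are assumptions made for illustration.

```python
import logging
from typing import Callable, List

logger = logging.getLogger(__name__)

SAMPLE_RATE = 24_000  # hypothetical value for this sketch


def generate_or_silence(generate: Callable[[], List[float]]) -> List[float]:
    """Return generated samples, or two seconds of silence on any failure."""
    try:
        return generate()
    except Exception as exc:  # deliberately broad: isolate one turn
        logger.error("TTS generation failed: %s", exc)
        return [0.0] * (SAMPLE_RATE * 2)
```

`except Exception` still lets `KeyboardInterrupt` and `SystemExit` propagate, since neither derives from `Exception`.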
| if os.path.exists(audio_path): | ||
| audio_data, _ = sf.read(audio_path) | ||
| else: | ||
| audio_data = self._generate_turn_audio( | ||
| text, reference_wav, conversation_id | ||
| ) | ||
| sf.write(audio_path, audio_data, self.sample_rate) | ||
| except OSError as e: | ||
| logger.error(f"File I/O failed for task {task.task_id}: {e}") | ||
| output_tasks.append(task) | ||
| continue | ||
|
|
||
| duration = len(audio_data) / self.sample_rate | ||
|
|
||
| out_data = dict(data) | ||
| out_data["audio_filepath"] = audio_path | ||
| out_data["duration"] = duration | ||
| out_data["reference_voice"] = ref_id |
Cross-run cache hit causes reference_voice metadata to disagree with audio content
The output filename is derived solely from conversation_id + speaker + text (no reference voice component). self._rng is seeded from random.Random() (no fixed seed), and speaker_to_ref_id is cleared on teardown. On a fresh run the RNG assigns a different reference voice for the same (conversation_id, speaker) pair, say voice B instead of voice A. The file from the previous run already exists on disk (generated with A), so line 481 reads it back and line 497 records reference_voice = B. The audio content the caller receives sounds like voice A, but the metadata claims voice B: a silent discrepancy that poisons any downstream training pipeline that depends on reference_voice for speaker labelling. Including a stable hash of the selected ref_id in the filename, or recording the chosen reference ID alongside the cached file (e.g. a .json sidecar), would keep the audio and metadata aligned.
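The first option, folding the reference ID into the cache key, might look like the sketch below. The `output_filename` signature with a fourth `ref_id` argument and the `_short_hash` helper are hypothetical; the digest lengths for the conversation ID and text follow the earlier suggestion.

```python
import hashlib


def _short_hash(value: str, length: int) -> str:
    """Hex MD5 digest of value, truncated to the given length."""
    return hashlib.md5(value.encode("utf-8")).hexdigest()[:length]


def output_filename(
    conversation_id: str, speaker: str, text: str, ref_id: str
) -> str:
    """Cache key that changes whenever the selected reference voice changes."""
    return (
        f"{_short_hash(conversation_id, 12)}_{speaker}_"
        f"{_short_hash(ref_id, 8)}_{_short_hash(text, 10)}.wav"
    )


run_1 = output_filename("conv001", "Alice", "Hello there.", "voice_A")
run_2 = output_filename("conv001", "Alice", "Hello there.", "voice_B")
```

A run that picks a different reference then misses the cache and regenerates, so the file on disk always matches the recorded `reference_voice`.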
| if self.language: | ||
| os.environ["TRANSFORMERS_ATTN_IMPLEMENTATION"] = "eager" | ||
| try: | ||
| import chatterbox.models.t3.llama_configs as _llama_cfgs |
| """Load the TTS model and discover reference audio files.""" | ||
| os.makedirs(self.output_audio_dir, exist_ok=True) | ||
| self._init_temp_dir() | ||
| self._load_model() |
Don't we need setup_on_node to download the relevant model first?
| def test_stage_properties(self, output_dir: str, ref_dataset: str) -> None: | ||
| stage = _build_stage(output_dir, ref_dataset) | ||
| assert stage.name == "ChatterboxTTSStage" | ||
| assert stage.resources.gpus == 1 | ||
|
|
||
| def test_invalid_language_raises(self, output_dir: str, ref_dataset: str) -> None: | ||
| with pytest.raises(ValueError, match="Unsupported language"): | ||
| _build_stage(output_dir, ref_dataset, language="xx") | ||
|
|
||
| def test_process_raises_not_implemented(self, output_dir: str, ref_dataset: str) -> None: | ||
| stage = _build_stage(output_dir, ref_dataset) | ||
| with pytest.raises(NotImplementedError, match="process_batch"): | ||
| stage.process(_make_task()) | ||
|
|
||
| def test_setup_raises_when_no_reference_audio( | ||
| self, output_dir: str, tmp_path: Path | ||
| ) -> None: | ||
| empty_dir = tmp_path / "empty" | ||
| empty_dir.mkdir() | ||
| stage = _build_stage(output_dir, str(empty_dir)) | ||
| with patch.object(ChatterboxTTSStage, "_load_model", _inject_model): | ||
| with pytest.raises(ValueError, match="No reference audio found"): | ||
| stage.setup() |
These tests can be removed.
Description
Add a ChatterboxTTS-based speech synthesis stage (ChatterboxTTSStage) to the NeMo Curator audio pipeline for generating multi-speaker conversation audio from text.
New stage:
ChatterboxTTSStage: Synthesises conversation-turn audio using Chatterbox TTS. Supports both the English-only model (ChatterboxTTS) and the multilingual model (ChatterboxMultilingualTTS, 23 languages). Speaker voices are automatically assigned from a reference audio dataset and stay consistent within each conversation.
Key features:
wavs/<dialog>/<speaker>.wav (with optional RTTM silence stripping) and MLS <spk>/<book>/<seg>.flac (auto-concatenated to target duration).
Resources(gpus=1)).
New files:
nemo_curator/stages/audio/tts/__init__.py
nemo_curator/stages/audio/tts/chatterbox_tts.py
tests/stages/audio/tts/__init__.py
tests/stages/audio/tts/test_chatterbox_tts.py (55 tests)
Usage
Supported languages (multilingual mode):
ar, da, de, el, en, es, fi, fr, he, hi, it, ja, ko, ms, nl, no, pl, pt, ru, sv, sw, tr, zh
Checklist