Chatterbox tts #1976

Open
Ssofja wants to merge 7 commits into NVIDIA-NeMo:main from Ssofja:chatterbox_tts

@Ssofja Ssofja commented May 13, 2026

Description

Add a ChatterboxTTS-based speech synthesis stage (ChatterboxTTSStage) to the NeMo Curator audio pipeline for generating multi-speaker conversation audio from text.

New stage:

  • ChatterboxTTSStage — Synthesises conversation-turn audio using Chatterbox TTS. Supports both the English-only model (ChatterboxTTS) and the multilingual model (ChatterboxMultilingualTTS, 23 languages). Speaker voices are automatically assigned from a reference audio dataset and stay consistent within each conversation.

Key features:

  • Two reference audio layouts: wavs/<dialog>/<speaker>.wav (with optional RTTM silence stripping) and MLS <spk>/<book>/<seg>.flac (auto-concatenated to target duration).
  • Per-conversation random exaggeration range for voice style variation.
  • RMS-based audio normalisation with clipping protection.
  • Deterministic filenames (MD5-based) enabling idempotent re-runs — existing output files are reused without re-generation.
  • Graceful failure handling: TTS errors produce a silence placeholder instead of crashing the pipeline.
  • Requires 1 GPU (Resources(gpus=1)).
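
The RMS-based normalisation with clipping protection listed above could look roughly like the sketch below. The function name `normalize_rms`, the -20 dBFS default, and the 0.99 peak ceiling are illustrative assumptions, not the stage's actual implementation.

```python
import numpy as np


def normalize_rms(
    audio: np.ndarray,
    target_dbfs: float = -20.0,   # assumed default target level
    peak_ceiling: float = 0.99,   # assumed headroom below full scale
) -> np.ndarray:
    """Scale audio to a target RMS level, then cap the peak to avoid clipping."""
    rms = float(np.sqrt(np.mean(np.square(audio)))) if audio.size else 0.0
    if rms == 0.0:
        # Pure silence has no level to normalise; return it unchanged.
        return audio.copy()

    # Convert the dBFS target to a linear RMS value and apply the gain.
    target_rms = 10.0 ** (target_dbfs / 20.0)
    scaled = audio * (target_rms / rms)

    # Clipping protection: if the gain pushed any sample past the ceiling,
    # scale the whole signal back down so the peak sits at the ceiling.
    peak = float(np.max(np.abs(scaled)))
    if peak > peak_ceiling:
        scaled = scaled * (peak_ceiling / peak)
    return scaled
```

Scaling the whole signal down, rather than hard-limiting individual samples, keeps the waveform shape intact at the cost of ending up below the target level for very peaky audio.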

New files:

  • nemo_curator/stages/audio/tts/__init__.py
  • nemo_curator/stages/audio/tts/chatterbox_tts.py
  • tests/stages/audio/tts/__init__.py
  • tests/stages/audio/tts/test_chatterbox_tts.py (55 tests)

Usage

from nemo_curator.stages.audio import ChatterboxTTSStage

# English TTS with wavs/ reference layout
stage = ChatterboxTTSStage(
    output_audio_dir="/data/tts_output",
    reference_voices_dataset="/data/reference_voices",
    cfg_weight=0.5,
    exaggeration=0.5,
    temperature=0.8,
)

# Multilingual TTS (e.g. Russian) with MLS reference layout
stage_ru = ChatterboxTTSStage(
    output_audio_dir="/data/tts_output_ru",
    reference_voices_dataset="/data/mls_russian",
    language="ru",
    exaggeration=[0.3, 0.7],  # random per conversation
    max_reference_duration=30.0,
)

# Process conversation turns
from nemo_curator.tasks import AudioTask

tasks = [
    AudioTask(
        data={"utterance": "Hello, how are you?", "speaker": "Alice", "conversation_id": "conv001"},
        task_id="t1",
        dataset_name="my_dataset",
    ),
    AudioTask(
        data={"utterance": "I'm doing well, thanks!", "speaker": "Bob", "conversation_id": "conv001"},
        task_id="t2",
        dataset_name="my_dataset",
    ),
]

results = stage.process_batch(tasks)
# Each result.data now contains "audio_filepath", "duration", and "reference_voice"

Supported languages (multilingual mode):
ar, da, de, el, en, es, fi, fr, he, hi, it, ja, ko, ms, nl, no, pl, pt, ru, sv, sw, tr, zh

Checklist

  • I am familiar with the Contributing Guide.
  • New or Existing tests cover these changes.
  • The documentation is up to date with these changes.

Ssofja added 2 commits May 13, 2026 17:05
Signed-off-by: Ssofja <sofiakostandian@gmail.com>
Signed-off-by: Ssofja <sofiakostandian@gmail.com>
@Ssofja Ssofja requested a review from a team as a code owner May 13, 2026 13:13
@Ssofja Ssofja requested review from sarahyurick and removed request for a team May 13, 2026 13:13

copy-pr-bot Bot commented May 13, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR introduces ChatterboxTTSStage, a new audio pipeline stage that synthesises multi-speaker conversation audio from text using the Chatterbox TTS library, supporting both an English-only model and a 23-language multilingual variant with reference-voice cloning.

  • New stage (chatterbox_tts.py): handles reference-audio discovery in two layouts (wavs/RTTM and MLS), deterministic output filenames for idempotent re-runs, per-conversation exaggeration ranges, RMS normalisation, and silence-placeholder fallback on inference failure.
  • Public API wired up: ChatterboxTTSStage is exported from nemo_curator/stages/audio/__init__.py via a new tts/ sub-package.
  • 55-test suite covers both reference layouts, multi-turn speaker consistency, idempotency, empty-text passthrough, and generation-failure fallback — all using device="cpu" mock models.

Confidence Score: 2/5

Not safe to merge as-is: inference failures from non-OSError/RuntimeError exceptions will abort the entire batch rather than being isolated per-turn, and cross-run cache hits produce output tasks where the reference_voice metadata does not match the audio actually on disk.

Two defects affect correctness on the main code path. First, _generate_turn_audio swallows only OSError and RuntimeError, so any other exception from the Chatterbox library escapes both the inner handler and the outer except OSError in process_batch, terminating the batch loop mid-flight. Second, the deterministic output filename encodes only conversation_id + speaker + text, not which reference voice was selected; a fresh run may assign a different reference and return cached audio with incorrect reference_voice metadata.

nemo_curator/stages/audio/tts/chatterbox_tts.py — both defects are concentrated here, in _generate_turn_audio (exception handling) and process_batch (cache-hit metadata path).

Important Files Changed

| Filename | Overview |
| --- | --- |
| nemo_curator/stages/audio/tts/chatterbox_tts.py | Core TTS stage implementation; has two defects: narrow exception handling in _generate_turn_audio breaks batch-level graceful failure, and the deterministic filename scheme doesn't encode the selected reference voice, producing metadata/audio mismatches on re-runs. |
| nemo_curator/stages/audio/tts/__init__.py | New package init exposing ChatterboxTTSStage; straightforward re-export, no issues. |
| nemo_curator/stages/audio/__init__.py | Adds ChatterboxTTSStage to the audio stage public API; import and __all__ are consistent. |
| tests/stages/audio/tts/test_chatterbox_tts.py | 55-test suite covering setup, teardown, both reference layouts, idempotency, and failure modes; all tests use device='cpu', which sidesteps GPU-specific paths and doesn't exercise cross-run cache re-assignment. |
| tests/stages/audio/tts/__init__.py | Empty test package marker; no issues. |

Sequence Diagram

sequenceDiagram
    participant Caller
    participant Stage as ChatterboxTTSStage
    participant Model as ChatterboxTTS / Multilingual
    participant RefAudio as Reference Dataset
    participant Disk as Output Dir / Temp Dir

    Caller->>Stage: setup()
    Stage->>Disk: makedirs(output_audio_dir)
    Stage->>Disk: "mkdtemp(chatterbox_ref_*)"
    Stage->>Model: from_pretrained(device)
    Stage->>RefAudio: "glob wavs/*/*.wav or */*/*.flac"
    Stage-->>Caller: ready

    Caller->>Stage: process_batch(tasks)
    loop for each AudioTask
        Stage->>Stage: _assign_reference(speaker, conv_id)
        alt new (conv_id, speaker) pair
            Stage->>RefAudio: pick reference file
            alt wavs layout + RTTM exists
                Stage->>Disk: strip silences → temp_dir/ref.wav
            else MLS layout
                Stage->>Disk: concatenate segments → temp_dir/ref_spk.wav
            end
        end
        Stage->>Stage: _output_filename(conv_id, speaker, text)
        alt file already exists on disk
            Stage->>Disk: sf.read(audio_path)
        else
            Stage->>Model: generate(text, audio_prompt_path, ...)
            Model-->>Stage: wav tensor
            Stage->>Stage: _normalize_audio(wav)
            Stage->>Disk: sf.write(audio_path, audio_data)
        end
        Stage-->>Caller: AudioTask + audio_filepath, duration, reference_voice
    end

    Caller->>Stage: teardown()
    Stage->>Disk: shutil.rmtree(temp_dir)

Reviews (6): Last reviewed commit: "change chatterbox tests structure"

Comment on lines +118 to +122
if language is not None and language.lower() not in SUPPORTED_LANGUAGES:
    raise ValueError(
        f"Unsupported language '{language}'. "
        f"Supported: {', '.join(sorted(SUPPORTED_LANGUAGES))}"
    )

P1 The language code is validated with .lower() but stored as-is and later passed directly to the model as language_id. If a caller passes "RU" or "FR", it passes the SUPPORTED_LANGUAGES check (because "ru" is in the set), but the raw uppercase string is forwarded to ChatterboxMultilingualTTS.generate. The Chatterbox API expects lowercase ISO 639-1 codes, so inference would either fail or silently produce wrong-language output.

Suggested change
- if language is not None and language.lower() not in SUPPORTED_LANGUAGES:
-     raise ValueError(
-         f"Unsupported language '{language}'. "
-         f"Supported: {', '.join(sorted(SUPPORTED_LANGUAGES))}"
-     )
+ if language is not None and language.lower() not in SUPPORTED_LANGUAGES:
+     raise ValueError(
+         f"Unsupported language '{language}'. "
+         f"Supported: {', '.join(sorted(SUPPORTED_LANGUAGES))}"
+     )
+ if language is not None:
+     language = language.lower()

Comment on lines +265 to +267
out_path = os.path.join(
    self.temp_dir, os.path.basename(audio_filepath)
)

P1 Both _get_reference_audio_wavs and _get_reference_audio_mls write their RTTM-processed / concatenated output into self.temp_dir using os.path.basename(audio_filepath). When different dialogs contain a speaker file with the same name (e.g. dialog001/Alice.wav and dialog002/Alice.wav), the second write silently overwrites the first temp file. Any speaker already assigned temp_dir/Alice.wav then synthesises audio with the wrong voice without any warning.

Suggested change
- out_path = os.path.join(
-     self.temp_dir, os.path.basename(audio_filepath)
- )
+ unique_name = hashlib.md5(audio_filepath.encode()).hexdigest()[:8] + "_" + os.path.basename(audio_filepath)
+ out_path = os.path.join(self.temp_dir, unique_name)

Comment on lines +417 to +422
@staticmethod
def _output_filename(conversation_id: str, speaker: str, text: str) -> str:
    """Deterministic filename: ``{conv_id_short}_{speaker}_{text_hash}.wav``."""
    conv_short = conversation_id[:12] if len(conversation_id) > 12 else conversation_id
    text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10]
    return f"{conv_short}_{speaker}_{text_hash}.wav"

P1 Truncating the conversation ID to 12 characters means two conversations whose IDs share a 12-character prefix (common with structured IDs such as session1_conv001 / session1_conv002) generate the same filename for the same speaker and text. On a subsequent run the cached file from the first conversation is reused for the second even though a different reference voice may have been assigned, producing a silent audio/metadata mismatch.

Suggested change
- @staticmethod
- def _output_filename(conversation_id: str, speaker: str, text: str) -> str:
-     """Deterministic filename: ``{conv_id_short}_{speaker}_{text_hash}.wav``."""
-     conv_short = conversation_id[:12] if len(conversation_id) > 12 else conversation_id
-     text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10]
-     return f"{conv_short}_{speaker}_{text_hash}.wav"
+ @staticmethod
+ def _output_filename(conversation_id: str, speaker: str, text: str) -> str:
+     """Deterministic filename: ``{conv_id_hash}_{speaker}_{text_hash}.wav``."""
+     conv_hash = hashlib.md5(conversation_id.encode("utf-8")).hexdigest()[:12]
+     text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10]
+     return f"{conv_hash}_{speaker}_{text_hash}.wav"
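
To make the prefix collision concrete, here is a small standalone sketch that compares the truncating scheme with the hashed one, using the two example IDs from the comment. The function names `filename_truncated` and `filename_hashed` are hypothetical stand-ins for the two versions of `_output_filename`.

```python
import hashlib


def filename_truncated(conversation_id: str, speaker: str, text: str) -> str:
    """Original scheme: keep only the first 12 characters of the conversation ID."""
    conv_short = conversation_id[:12]
    text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10]
    return f"{conv_short}_{speaker}_{text_hash}.wav"


def filename_hashed(conversation_id: str, speaker: str, text: str) -> str:
    """Suggested scheme: hash the full conversation ID before shortening it."""
    conv_hash = hashlib.md5(conversation_id.encode("utf-8")).hexdigest()[:12]
    text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10]
    return f"{conv_hash}_{speaker}_{text_hash}.wav"


ids = ("session1_conv001", "session1_conv002")

# Both IDs truncate to "session1_con", so the truncated names are identical.
truncated_collide = (
    filename_truncated(ids[0], "Alice", "Hello")
    == filename_truncated(ids[1], "Alice", "Hello")
)

# The hashed names depend on the whole ID, so they are expected to differ.
hashed_collide = (
    filename_hashed(ids[0], "Alice", "Hello")
    == filename_hashed(ids[1], "Alice", "Hello")
)
```

A 12-hex-character prefix still leaves a small theoretical collision chance, but it no longer depends on how the IDs happen to be structured.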

Comment on lines +460 to +478
reference_wav = self._assign_reference(speaker, conversation_id)

filename = self._output_filename(conversation_id, speaker, text)
audio_path = os.path.join(self.output_audio_dir, filename)

if os.path.exists(audio_path):
    audio_data, _ = sf.read(audio_path)
else:
    audio_data = self._generate_turn_audio(
        text, reference_wav, conversation_id
    )
    sf.write(audio_path, audio_data, self.sample_rate)

duration = len(audio_data) / self.sample_rate

out_data = dict(data)
out_data["audio_filepath"] = audio_path
out_data["duration"] = duration
out_data["reference_voice"] = Path(reference_wav).parent.name

P1 Path(reference_wav).parent.name returns the temp-directory name (e.g. chatterbox_ref_abc123) whenever the reference has been RTTM-processed or comes from the MLS layout, because both code paths write to self.temp_dir/<filename>. Only the raw wavs path (no RTTM) has a meaningful parent (the dialog ID). The emitted reference_voice value should be the MLS speaker ID or the dialog/speaker tag, not an ephemeral temp-dir name.

Suggested change
- reference_wav = self._assign_reference(speaker, conversation_id)
- filename = self._output_filename(conversation_id, speaker, text)
- audio_path = os.path.join(self.output_audio_dir, filename)
- if os.path.exists(audio_path):
-     audio_data, _ = sf.read(audio_path)
- else:
-     audio_data = self._generate_turn_audio(
-         text, reference_wav, conversation_id
-     )
-     sf.write(audio_path, audio_data, self.sample_rate)
- duration = len(audio_data) / self.sample_rate
- out_data = dict(data)
- out_data["audio_filepath"] = audio_path
- out_data["duration"] = duration
- out_data["reference_voice"] = Path(reference_wav).parent.name
+ reference_wav, ref_id = self._assign_reference(speaker, conversation_id)
+ filename = self._output_filename(conversation_id, speaker, text)
+ audio_path = os.path.join(self.output_audio_dir, filename)
+ if os.path.exists(audio_path):
+     audio_data, _ = sf.read(audio_path)
+ else:
+     audio_data = self._generate_turn_audio(
+         text, reference_wav, conversation_id
+     )
+     sf.write(audio_path, audio_data, self.sample_rate)
+ duration = len(audio_data) / self.sample_rate
+ out_data = dict(data)
+ out_data["audio_filepath"] = audio_path
+ out_data["duration"] = duration
+ out_data["reference_voice"] = ref_id

top_p: float = 1.0,
normalize_audio: bool = True,
normalize_level: float = -20.0,
**kwargs,

Can this be removed?

Comment thread nemo_curator/stages/audio/tts/chatterbox_tts.py
    text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10]
    return f"{conv_short}_{speaker}_{text_hash}.wav"

def _ensure_ready(self) -> None:

Let's remove this.

if not tasks:
    return []

self._ensure_ready()

Remove. We should never call setup in process_batch/process.

    return []

self._ensure_ready()
os.makedirs(self.output_audio_dir, exist_ok=True)

Should this be in setup?

if self.normalize_audio:
    wav = self._normalize_audio(wav)

return wav.squeeze(0).numpy()

P1 Missing .cpu() — CUDA inference always silently fails

wav.squeeze(0).numpy() is called directly on the tensor returned by self.model.generate(). The default device is "cuda", so the model lives on the GPU and its output is a CUDA tensor. Calling .numpy() on a CUDA tensor raises TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first. That error is swallowed by the surrounding except Exception block, so every GPU inference call silently falls through and returns 2 seconds of zeros. The tests never catch this because _build_stage hard-codes device="cpu".

Suggested change
- return wav.squeeze(0).numpy()
+ return wav.squeeze(0).cpu().numpy()

Signed-off-by: Ssofja <sofiakostandian@gmail.com>
Comment on lines +468 to +489
if os.path.exists(audio_path):
    audio_data, _ = sf.read(audio_path)
else:
    audio_data = self._generate_turn_audio(
        text, reference_wav, conversation_id
    )
    sf.write(audio_path, audio_data, self.sample_rate)

duration = len(audio_data) / self.sample_rate

out_data = dict(data)
out_data["audio_filepath"] = audio_path
out_data["duration"] = duration
out_data["reference_voice"] = ref_id

output_tasks.append(
    AudioTask(
        data=out_data,
        task_id=task.task_id,
        dataset_name=task.dataset_name,
    )
)

P1 The PR's stated design goal is "graceful failure handling: TTS errors produce a silence placeholder instead of crashing the pipeline," but only _generate_turn_audio has a try/except. The sf.write call (and the preceding os.path.exists + sf.read for cache hits) sits outside any guard. Any OSError here — disk full, wrong permissions, or a speaker name that contains / or \ generating a path whose parent directory doesn't exist — will raise an unhandled exception and abort the entire remaining batch. Wrapping the file I/O block in its own try/except keeps the one-turn failure contained and matches the silence-fallback contract already in place for inference.

Suggested change
- if os.path.exists(audio_path):
-     audio_data, _ = sf.read(audio_path)
- else:
-     audio_data = self._generate_turn_audio(
-         text, reference_wav, conversation_id
-     )
-     sf.write(audio_path, audio_data, self.sample_rate)
- duration = len(audio_data) / self.sample_rate
- out_data = dict(data)
- out_data["audio_filepath"] = audio_path
- out_data["duration"] = duration
- out_data["reference_voice"] = ref_id
- output_tasks.append(
-     AudioTask(
-         data=out_data,
-         task_id=task.task_id,
-         dataset_name=task.dataset_name,
-     )
- )
+ try:
+     if os.path.exists(audio_path):
+         audio_data, _ = sf.read(audio_path)
+     else:
+         audio_data = self._generate_turn_audio(
+             text, reference_wav, conversation_id
+         )
+         sf.write(audio_path, audio_data, self.sample_rate)
+ except Exception as e:
+     logger.error(f"File I/O failed for task {task.task_id}: {e}")
+     output_tasks.append(task)
+     continue
+ duration = len(audio_data) / self.sample_rate
+ out_data = dict(data)
+ out_data["audio_filepath"] = audio_path
+ out_data["duration"] = duration
+ out_data["reference_voice"] = ref_id
+ output_tasks.append(
+     AudioTask(
+         data=out_data,
+         task_id=task.task_id,
+         dataset_name=task.dataset_name,
+     )
+ )

Comment on lines +335 to +341
if not chunks:
    return files[0], chosen

concatenated = torch.cat(chunks, dim=1)
out_path = os.path.join(self.temp_dir, f"ref_{chosen}.wav")
ta.save(out_path, concatenated, last_sr)
return out_path, chosen

P1 The torch.cat call and the ta.save that follows sit outside the per-file try/except. If any two FLAC files in chunks have a different number of channels (e.g., one mono, one stereo), torch.cat(chunks, dim=1) raises a RuntimeError that propagates all the way out of process_batch uncaught, aborting the batch entirely. Wrapping the concatenation and save in their own guard — and falling back to the raw first file, matching the existing empty-chunks fallback — keeps the failure isolated to the one speaker assignment.

Suggested change
- if not chunks:
-     return files[0], chosen
- concatenated = torch.cat(chunks, dim=1)
- out_path = os.path.join(self.temp_dir, f"ref_{chosen}.wav")
- ta.save(out_path, concatenated, last_sr)
- return out_path, chosen
+ if not chunks:
+     return files[0], chosen
+ try:
+     concatenated = torch.cat(chunks, dim=1)
+     out_path = os.path.join(self.temp_dir, f"ref_{chosen}.wav")
+     ta.save(out_path, concatenated, last_sr)
+     return out_path, chosen
+ except Exception as e:
+     logger.warning(f"MLS concatenation failed for speaker {chosen}: {e}")
+     return files[0], chosen

Signed-off-by: Ssofja <sofiakostandian@gmail.com>
Comment on lines +479 to +489
    if os.path.exists(audio_path):
        audio_data, _ = sf.read(audio_path)
    else:
        audio_data = self._generate_turn_audio(
            text, reference_wav, conversation_id
        )
        sf.write(audio_path, audio_data, self.sample_rate)
except OSError as e:
    logger.error(f"File I/O failed for task {task.task_id}: {e}")
    output_tasks.append(task)
    continue

P1 Silence files permanently cached, blocking retry on transient failures

When _generate_turn_audio returns silence (after a GPU OOM or other runtime error), the silence array is written to disk via sf.write(audio_path, audio_data, self.sample_rate) at line 485. On a subsequent re-run, os.path.exists(audio_path) returns True and the cached silence is served back as valid output without any retry. There is no way to distinguish a legitimately synthesised near-silence from a failure-generated silence, so transient failures permanently poison the idempotent cache for that task. If the intent is to allow retries, the silence fallback should not be written to disk — the task should be returned without audio_filepath, matching the empty-text passthrough pattern already used at line 516.
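
One way to keep failures retryable is to write a cache entry only after a successful synthesis. The sketch below is a simplified illustration under that assumption: `synthesize_turn` and its `generate` callback are hypothetical, and raw bytes stand in for the audio that the real stage writes with `sf.write`.

```python
import os
import tempfile
from typing import Callable, Optional


def synthesize_turn(
    generate: Callable[[str], bytes], text: str, out_path: str
) -> Optional[str]:
    """Return out_path if audio is available, or None if synthesis failed.

    Nothing is written on failure, so a later run retries the turn.
    """
    if os.path.exists(out_path):
        return out_path  # cache hit from an earlier successful run
    try:
        audio = generate(text)
    except Exception:
        # No placeholder on disk; the caller can emit the task without
        # an audio_filepath, as it does for empty text.
        return None
    # Write to a temporary name first so an interrupted write never
    # leaves a partial file that looks like a valid cache entry.
    tmp_path = out_path + ".tmp"
    with open(tmp_path, "wb") as f:
        f.write(audio)
    os.replace(tmp_path, out_path)
    return out_path


def always_fails(text: str) -> bytes:
    raise RuntimeError("simulated GPU out-of-memory")


with tempfile.TemporaryDirectory() as workdir:
    path = os.path.join(workdir, "turn.wav")
    first = synthesize_turn(always_fails, "Hello", path)        # failure: nothing cached
    second = synthesize_turn(lambda t: b"audio", "Hello", path)  # retry succeeds and caches
```

If a silence placeholder is still wanted in the returned task, it can be kept in memory and flagged in the task data rather than persisted.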

Comment on lines +359 to +369
key = f"{conversation_id}_{speaker}"

if key in self.speaker_to_reference:
    return self.speaker_to_reference[key], self.speaker_to_ref_id[key]

if self._reference_layout == "mls":
    already_taken_ids = {
        self.speaker_to_ref_id[k]
        for k in self.speaker_to_ref_id
        if k.startswith(f"{conversation_id}_")
    }

P1 Key collision in _assign_reference when conversation IDs or speaker names share an underscore boundary

The lookup key f"{conversation_id}_{speaker}" is ambiguous: (conv="A_B", speaker="C") and (conv="A", speaker="B_C") both produce key "A_B_C". When both pairs appear in the same batch (or across sessions), the second pair silently reuses the reference cached for the first, assigning it the wrong voice without any warning. The k.startswith(f"{conversation_id}_") deduplication check also has a symmetric false-positive: keys from a conversation whose ID starts with conversation_id + "_" are incorrectly counted as "already taken" voices. Using a tuple key (conversation_id, speaker) or a separator that cannot appear in either field (e.g. \x00) would eliminate both problems.
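
The ambiguity is easy to reproduce in isolation. The following sketch uses plain dictionaries as stand-ins for `speaker_to_ref_id`; the voice labels and the helper `taken_in_conversation` are made up for illustration.

```python
from typing import Dict, Set, Tuple

pairs = [("A_B", "C"), ("A", "B_C")]  # (conversation_id, speaker)

# String keys: both pairs collapse to "A_B_C", so the second overwrites the first.
string_keyed: Dict[str, str] = {}
for conv_id, speaker in pairs:
    string_keyed[f"{conv_id}_{speaker}"] = f"voice_for_{conv_id}|{speaker}"

# Tuple keys keep the two fields separate, so both assignments survive.
tuple_keyed: Dict[Tuple[str, str], str] = {}
for conv_id, speaker in pairs:
    tuple_keyed[(conv_id, speaker)] = f"voice_for_{conv_id}|{speaker}"


def taken_in_conversation(
    assignments: Dict[Tuple[str, str], str], conversation_id: str
) -> Set[str]:
    """Voices already used in exactly this conversation (no prefix matching)."""
    return {ref for (conv, _spk), ref in assignments.items() if conv == conversation_id}
```

With tuple keys the "already taken" check becomes an equality test on the conversation ID, which also removes the `startswith` false positive.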

Signed-off-by: Ssofja <sofiakostandian@gmail.com>

def process(self, task: AudioTask) -> AudioTask:
    """Generate audio for a single conversation turn."""
    return self.process_batch([task])[0]

We can just raise as not implemented instead.

return stage


class TestConstruction:

None of the tests in this class are needed IMO.

assert stage.resources.gpus == 1


class TestModelLoading:

Same here, I don't think these tests are needed.

ref_path, _ref_id = stage._assign_reference("Alice", "conv001")
assert os.path.exists(ref_path)

class TestEdgeCases:

Let's avoid the overly-specific class TestXYZ in favor of a single test class for each ProcessingStage.

Signed-off-by: Ssofja <sofiakostandian@gmail.com>
Comment on lines +420 to +422
except (OSError, RuntimeError) as e:
    logger.error(f"TTS generation failed: {e}")
    return np.zeros(self.sample_rate * 2)

P1 Narrow exception catch breaks graceful-failure contract

_generate_turn_audio only swallows OSError and RuntimeError. Chatterbox can raise ValueError (e.g. invalid audio_prompt_path format), AttributeError (model not fully initialised), TypeError (unexpected kwarg), or any other library-specific exception type. Those exceptions escape this handler, propagate into process_batch's outer try/except OSError (line 487) which also won't catch them, and then abort the rest of the batch loop entirely. Changing the inner handler to except Exception would keep every inference failure isolated to the one turn and preserve the silence-fallback contract across the board.
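
The difference between the two handlers can be shown with a stub generator that raises ValueError. This is a standalone sketch: `generate_narrow`, `generate_broad`, `failing_generate`, and the 24000 Hz sample rate are assumptions for illustration, not the stage's code.

```python
import logging

import numpy as np

logger = logging.getLogger(__name__)


def generate_narrow(generate, text, sample_rate=24000):
    """Mirrors the current handler: only OSError and RuntimeError become silence."""
    try:
        return generate(text)
    except (OSError, RuntimeError) as e:
        logger.error("TTS generation failed: %s", e)
        return np.zeros(sample_rate * 2)


def generate_broad(generate, text, sample_rate=24000):
    """Mirrors the suggested handler: any Exception becomes 2 seconds of silence."""
    try:
        return generate(text)
    except Exception as e:
        logger.error("TTS generation failed: %s", e)
        return np.zeros(sample_rate * 2)


def failing_generate(text):
    # Stand-in for a library error type the narrow handler does not list.
    raise ValueError("invalid audio_prompt_path")


fallback = generate_broad(failing_generate, "Hello")  # silence placeholder

try:
    generate_narrow(failing_generate, "Hello")
    escaped = False
except ValueError:
    escaped = True  # the exception left the handler and would abort the batch
```

`except Exception` still lets KeyboardInterrupt and SystemExit through, so a broad handler here does not block a deliberate shutdown.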

Comment on lines +480 to +497
    if os.path.exists(audio_path):
        audio_data, _ = sf.read(audio_path)
    else:
        audio_data = self._generate_turn_audio(
            text, reference_wav, conversation_id
        )
        sf.write(audio_path, audio_data, self.sample_rate)
except OSError as e:
    logger.error(f"File I/O failed for task {task.task_id}: {e}")
    output_tasks.append(task)
    continue

duration = len(audio_data) / self.sample_rate

out_data = dict(data)
out_data["audio_filepath"] = audio_path
out_data["duration"] = duration
out_data["reference_voice"] = ref_id

P1 Cross-run cache hit causes reference_voice metadata to disagree with audio content

The output filename is derived solely from conversation_id + speaker + text (no reference voice component). self._rng is seeded from random.Random() (no fixed seed), and speaker_to_ref_id is cleared on teardown. On a fresh run the RNG assigns a different reference voice for the same (conversation_id, speaker) pair, say voice B instead of voice A. The file from the previous run already exists on disk (generated with A), so line 481 reads it back and line 497 records reference_voice = B. The audio content the caller receives sounds like voice A, but the metadata claims voice B — a silent discrepancy that poisons any downstream training pipeline that depends on reference_voice for speaker labelling. Including a stable hash of the selected ref_id in the filename, or recording the chosen reference ID alongside the cached file (e.g. a .json sidecar), would keep the audio and metadata aligned.
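
A minimal sketch of the first option, folding the selected reference ID into the cache key, might look like this. The function `output_filename`, the helper `_short_hash`, and the chosen hash lengths are hypothetical; hashing the speaker field as well is an extra assumption that keeps path separators out of the filename.

```python
import hashlib


def _short_hash(value: str, length: int) -> str:
    """Stable hex digest prefix; used for cache naming only, not for security."""
    return hashlib.md5(value.encode("utf-8")).hexdigest()[:length]


def output_filename(conversation_id: str, speaker: str, text: str, ref_id: str) -> str:
    """Filename that changes whenever the selected reference voice changes."""
    return (
        f"{_short_hash(conversation_id, 12)}_{_short_hash(speaker, 8)}"
        f"_{_short_hash(ref_id, 8)}_{_short_hash(text, 10)}.wav"
    )


name_voice_a = output_filename("conv001", "Alice", "Hello", "voice_A")
name_voice_b = output_filename("conv001", "Alice", "Hello", "voice_B")
```

The trade-off is that a run which picks a different voice regenerates the audio instead of reusing it. Seeding the voice assignment deterministically per (conversation_id, speaker) would restore cache hits across runs.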

if self.language:
os.environ["TRANSFORMERS_ATTN_IMPLEMENTATION"] = "eager"
try:
import chatterbox.models.t3.llama_configs as _llama_cfgs

Top-level import?

"""Load the TTS model and discover reference audio files."""
os.makedirs(self.output_audio_dir, exist_ok=True)
self._init_temp_dir()
self._load_model()

Don't we need setup_on_node to download the relevant model first?

Comment on lines +126 to +148
def test_stage_properties(self, output_dir: str, ref_dataset: str) -> None:
    stage = _build_stage(output_dir, ref_dataset)
    assert stage.name == "ChatterboxTTSStage"
    assert stage.resources.gpus == 1

def test_invalid_language_raises(self, output_dir: str, ref_dataset: str) -> None:
    with pytest.raises(ValueError, match="Unsupported language"):
        _build_stage(output_dir, ref_dataset, language="xx")

def test_process_raises_not_implemented(self, output_dir: str, ref_dataset: str) -> None:
    stage = _build_stage(output_dir, ref_dataset)
    with pytest.raises(NotImplementedError, match="process_batch"):
        stage.process(_make_task())

def test_setup_raises_when_no_reference_audio(
    self, output_dir: str, tmp_path: Path
) -> None:
    empty_dir = tmp_path / "empty"
    empty_dir.mkdir()
    stage = _build_stage(output_dir, str(empty_dir))
    with patch.object(ChatterboxTTSStage, "_load_model", _inject_model):
        with pytest.raises(ValueError, match="No reference audio found"):
            stage.setup()

These tests can be removed.
