diff --git a/nemo_curator/stages/audio/__init__.py b/nemo_curator/stages/audio/__init__.py index df541d370e..f0f455e4a0 100644 --- a/nemo_curator/stages/audio/__init__.py +++ b/nemo_curator/stages/audio/__init__.py @@ -24,6 +24,7 @@ """ from nemo_curator.stages.audio.advanced_pipelines import AudioDataFilterStage +from nemo_curator.stages.audio.alignment import MFAAlignmentStage from nemo_curator.stages.audio.alm import ALMDataBuilderStage, ALMDataOverlapStage from nemo_curator.stages.audio.common import ( GetAudioDurationStage, @@ -56,6 +57,7 @@ "GetAudioDurationStage", "ManifestReader", "ManifestWriterStage", + "MFAAlignmentStage", "MonoConversionStage", "PreserveByValueStage", "SIGMOSFilterStage", diff --git a/nemo_curator/stages/audio/alignment/__init__.py b/nemo_curator/stages/audio/alignment/__init__.py new file mode 100644 index 0000000000..3ec814adeb --- /dev/null +++ b/nemo_curator/stages/audio/alignment/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""MFA forced-alignment stage for audio curation.""" + +from nemo_curator.stages.audio.alignment.mfa_alignment import MFAAlignmentStage + +__all__ = ["MFAAlignmentStage"] diff --git a/nemo_curator/stages/audio/alignment/mfa_alignment.py b/nemo_curator/stages/audio/alignment/mfa_alignment.py new file mode 100644 index 0000000000..f0f524a9e0 --- /dev/null +++ b/nemo_curator/stages/audio/alignment/mfa_alignment.py @@ -0,0 +1,620 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MFA Batch Alignment Stage for NeMo Curator. + +A ``ProcessingStage`` that runs `Montreal Forced Aligner (MFA) +<https://montreal-forced-aligner.readthedocs.io/>`_ in batch mode on a set +of ``AudioTask`` entries, producing TextGrid, RTTM, and/or CTM output files. + +The stage operates via ``process_batch``: it collects all tasks in a batch, +prepares a temporary MFA corpus (symlinked WAVs + ``.txt`` transcript files), +runs a single ``mfa align`` subprocess, and converts the resulting TextGrid +files to RTTM and/or CTM format depending on configuration. + +Node-level isolation + ``setup_on_node()`` copies MFA models from shared storage to a node-local + directory. This avoids NFS/Lustre race conditions and Kaldi errors when + multiple distributed nodes share the same model directory. + +Worker scheduling + ``xenna_stage_spec()`` returns ``{"num_workers_per_node": 1}`` to + guarantee exactly one MFA worker per node. 
+""" + +from __future__ import annotations + +import os +import shlex +import shutil +import socket +import subprocess +import tempfile +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import soundfile as sf +from loguru import logger +from praatio import textgrid as praatio_textgrid + +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import AudioTask + +_DEFAULT_SILENCE_MARKERS = ("", "sp", "sil", "spn", "") +_WORD_TIER_NAMES = ("words", "word") +_PHONE_TIER_NAMES = frozenset({ + "phones", + "phone", + "phonemes", + "phoneme", + "phons", +}) + + +@dataclass +class MFAAlignmentStage(ProcessingStage[AudioTask, AudioTask]): + """Batch forced alignment using the Montreal Forced Aligner (MFA). + + This stage only supports :meth:`process_batch`; calling :meth:`process` + raises ``NotImplementedError``. Use ``.with_(batch_size=N)`` to control + how many tasks are aligned per ``mfa align`` invocation. + + Args: + mfa_command: Shell command (or absolute path) to the ``mfa`` binary. + acoustic_model: MFA acoustic model name or path. + dictionary: MFA dictionary name or path. + g2p_model: MFA G2P model for out-of-vocabulary words. Set to ``""`` + to disable G2P. + output_dir: Root directory for all output files. Sub-directories + ``textgrids/``, ``rttms/``, and ``ctms/`` are created beneath it. + audio_filepath_key: Key in ``task.data`` pointing to the WAV file. + text_key: Key in ``task.data`` containing the transcript text. + speaker_key: Key in ``task.data`` for the speaker label (used in + RTTM output). + duration_key: Key in ``task.data`` for audio duration. Computed + automatically if missing. + max_gap_for_merge: Maximum gap (seconds) between speech intervals + before they are merged in the RTTM output. + num_jobs: Number of parallel MFA jobs (``-j`` flag passed to MFA). + Set explicitly for your deployment; not inferred by the stage. + beam: MFA beam size for alignment search. 
+ retry_beam: MFA retry beam when initial alignment fails. + single_speaker: Pass ``--single_speaker`` to MFA. + clean: Pass ``--clean`` to MFA (remove temp files after alignment). + use_mp: Pass ``--use_mp`` to MFA (use multiprocessing). + output_format: MFA output format (``long_textgrid`` or + ``short_textgrid``). + mfa_root_dir: MFA root directory containing pretrained models, or + ``None`` to use ``MFA_ROOT_DIR`` / ``~/.mfa``. + local_mfa_base_dir: Base directory for node-local model copies, or + ``None`` to use ``tempfile.gettempdir()`` (typically ``/tmp``). + copy_models_to_local: Whether ``setup_on_node`` should copy models + to node-local storage. + silence_markers: Labels to treat as silence when converting TextGrids. + create_rttm: Whether to convert TextGrids to RTTM files. + create_ctm: Whether to convert TextGrids to CTM files. + """ + + output_dir: str + name: str = "MFAAlignmentStage" + mfa_command: str = "mfa" + acoustic_model: str = "english_us_arpa" + dictionary: str = "english_us_arpa" + g2p_model: str = "english_us_arpa" + audio_filepath_key: str = "audio_filepath" + text_key: str = "text" + speaker_key: str = "speaker" + duration_key: str = "duration" + max_gap_for_merge: float = 0.3 + num_jobs: int = 1 + beam: int = 100 + retry_beam: int = 400 + single_speaker: bool = True + clean: bool = True + use_mp: bool = True + output_format: str = "long_textgrid" + mfa_root_dir: str | None = None + local_mfa_base_dir: str | None = None + copy_models_to_local: bool = True + silence_markers: tuple[str, ...] 
= _DEFAULT_SILENCE_MARKERS + create_rttm: bool = True + create_ctm: bool = True + + # Set during lifecycle hooks -- not user-configurable + _mfa_root: str = field(default="", init=False, repr=False) + _textgrid_mod: Any = field(default=praatio_textgrid, init=False, repr=False) + + def __post_init__(self) -> None: + self._effective_mfa_root = self.mfa_root_dir or os.environ.get( + "MFA_ROOT_DIR", os.path.expanduser("~/.mfa") + ) + self._effective_local_base = ( + self.local_mfa_base_dir or tempfile.gettempdir() + ) + self._textgrid_dir = Path(self.output_dir) / "textgrids" + self._rttm_dir = Path(self.output_dir) / "rttms" + self._ctm_dir = Path(self.output_dir) / "ctms" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [self.audio_filepath_key, self.text_key] + + def outputs(self) -> tuple[list[str], list[str]]: + data_keys = ["textgrid_filepath"] + if self.create_rttm: + data_keys.append("rttm_filepath") + if self.create_ctm: + data_keys.append("ctm_filepath") + return [], data_keys + + def xenna_stage_spec(self) -> dict[str, Any]: + # The current implementation is meant to run with one worker per node, because the MFA library has issues when running in parallel. + # The MFA models are copied to node-local storage to avoid race conditions and Kaldi errors when multiple distributed nodes share the same model directory. 
+ return {"num_workers_per_node": 1} + + def setup_on_node( + self, + node_info: Any = None, # noqa: ARG002, ANN401 + worker_metadata: Any = None, # noqa: ARG002, ANN401 + ) -> None: + """Copy MFA models from shared storage to node-local directory.""" + if not self.copy_models_to_local: + self._mfa_root = self._effective_mfa_root + return + hostname = socket.gethostname() + self._mfa_root = self._setup_local_mfa( + self._effective_mfa_root, hostname + ) + logger.info( + f"[setup_on_node] MFA root set to {self._mfa_root} on {hostname}" + ) + + def setup( + self, + worker_metadata: Any = None, # noqa: ARG002, ANN401 + ) -> None: + """Resolve the MFA root and create output directories.""" + if not self._mfa_root: + if self.copy_models_to_local: + hostname = socket.gethostname() + local_candidate = ( + Path(self._effective_local_base) / f"mfa_models_{hostname}" + ) + if local_candidate.exists(): + self._mfa_root = str(local_candidate) + logger.info( + f"[setup] Re-using local MFA root: {self._mfa_root}" + ) + else: + self._mfa_root = self._effective_mfa_root + logger.info( + f"[setup] Local copy not found; using shared MFA root: " + f"{self._mfa_root}" + ) + else: + self._mfa_root = self._effective_mfa_root + + self._textgrid_dir.mkdir(parents=True, exist_ok=True) + if self.create_rttm: + self._rttm_dir.mkdir(parents=True, exist_ok=True) + if self.create_ctm: + self._ctm_dir.mkdir(parents=True, exist_ok=True) + + def process(self, task: AudioTask) -> AudioTask: + msg = "MFAAlignmentStage only supports process_batch" + raise NotImplementedError(msg) + + def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: + """Align all tasks in a single ``mfa align`` invocation.""" + if not tasks: + return [] + + stem_to_task: dict[str, AudioTask] = {} + for task in tasks: + if not self.validate_input(task): + msg = f"Task {task!s} failed validation for stage {self}" + raise ValueError(msg) + audio_filepath = task.data[self.audio_filepath_key] + text = 
task.data[self.text_key].strip() + if not text: + msg = ( + f"Empty text for {audio_filepath} " + f"(key={self.text_key!r})" + ) + raise ValueError(msg) + audio_path = Path(audio_filepath) + if not audio_path.exists(): + msg = f"Audio file not found: {audio_path}" + raise FileNotFoundError(msg) + + file_stem = audio_path.stem + if file_stem in stem_to_task: + original_stem = file_stem + file_stem = f"{file_stem}_{uuid.uuid4().hex[:8]}" + logger.warning( + f"Duplicate stem '{original_stem}' — renamed to " + f"'{file_stem}' to avoid silent data loss" + ) + if not task.data.get(self.duration_key): + task.data[self.duration_key] = self._get_audio_duration( + str(audio_path) + ) + stem_to_task[file_stem] = task + + batch_uuid = uuid.uuid4().hex[:12] + tg_out_path = self._textgrid_dir / batch_uuid + tg_out_path.mkdir(parents=True, exist_ok=True) + + results: list[AudioTask] = [] + + with tempfile.TemporaryDirectory(prefix="mfa_corpus_") as corpus_dir: + corpus_path = Path(corpus_dir) + for corpus_stem, task in stem_to_task.items(): + audio_path = Path(task.data[self.audio_filepath_key]) + corpus_wav = corpus_path / f"{corpus_stem}.wav" + if not corpus_wav.exists() and not corpus_wav.is_symlink(): + try: + corpus_wav.symlink_to(audio_path.resolve()) + except OSError: + shutil.copy2(audio_path, corpus_wav) + corpus_txt = corpus_path / f"{corpus_stem}.txt" + corpus_txt.write_text( + task.data[self.text_key].strip(), encoding="utf-8" + ) + + self._run_mfa_align(corpus_path, tg_out_path) + + all_tg = { + tg.stem: tg for tg in tg_out_path.rglob("*.TextGrid") + } + missing = {s for s in stem_to_task if s not in all_tg} + + if missing: + logger.warning( + f"MFA silently dropped {len(missing)}/{len(stem_to_task)} " + f"files (exit code was 0). Creating fallback outputs." 
+ ) + + for file_stem, task in stem_to_task.items(): + if file_stem in missing: + self._handle_missing_textgrid(file_stem, task) + else: + self._handle_successful_textgrid( + file_stem, task, all_tg[file_stem] + ) + results.append(task) + + return results + + def _handle_successful_textgrid( + self, file_stem: str, task: AudioTask, tg_path: Path + ) -> None: + task.data["textgrid_filepath"] = str(tg_path) + speaker = task.data.get(self.speaker_key, "unknown") + + if self.create_rttm: + rttm_path = self._rttm_dir / f"{file_stem}.rttm" + self._textgrid_to_rttm(tg_path, file_stem, speaker, rttm_path) + task.data["rttm_filepath"] = str(rttm_path) + + if self.create_ctm: + ctm_path = self._ctm_dir / f"{file_stem}.ctm" + self._textgrid_to_ctm(tg_path, file_stem, ctm_path) + task.data["ctm_filepath"] = str(ctm_path) + + def _handle_missing_textgrid( + self, file_stem: str, task: AudioTask + ) -> None: + duration = task.data.get(self.duration_key, 0.0) + text = task.data.get(self.text_key, "").strip() + speaker = task.data.get(self.speaker_key, "unknown") + + logger.warning( + f" MFA dropped '{file_stem}': duration={duration:.2f}s, " + f"text='{text[:120]}'" + ) + + task.data["textgrid_filepath"] = "" + task.data["mfa_skipped"] = True + + if self.create_rttm: + rttm_path = self._rttm_dir / f"{file_stem}.rttm" + self._create_duration_fallback_rttm( + file_stem, speaker, duration, rttm_path + ) + task.data["rttm_filepath"] = str(rttm_path) + + if self.create_ctm: + ctm_path = self._ctm_dir / f"{file_stem}.ctm" + self._create_duration_fallback_ctm( + file_stem, text, duration, ctm_path + ) + task.data["ctm_filepath"] = str(ctm_path) + + def _run_mfa_align( + self, corpus_dir: Path, textgrid_output_dir: Path + ) -> None: + env = os.environ.copy() + env["MFA_ROOT_DIR"] = self._mfa_root + + mfa_cmd_parts = shlex.split(self.mfa_command) + mfa_bin_dir = ( + os.path.dirname(mfa_cmd_parts[0]) + if os.path.isabs(mfa_cmd_parts[0]) + else None + ) + if mfa_bin_dir: + env["PATH"] = 
f"{mfa_bin_dir}:{env.get('PATH', '')}" + + history_file = Path(self._mfa_root) / "command_history.yaml" + if history_file.exists() and self._is_node_local_mfa_root(): + try: + history_file.unlink() + except OSError: + logger.debug( + f"Could not remove MFA history file: {history_file}" + ) + + cmd = mfa_cmd_parts + [ + "align", + str(corpus_dir), + self.dictionary, + self.acoustic_model, + str(textgrid_output_dir), + "--output_format", + self.output_format, + "-j", + str(self.num_jobs), + "--beam", + str(self.beam), + "--retry_beam", + str(self.retry_beam), + ] + if self.single_speaker: + cmd.append("--single_speaker") + if self.use_mp: + cmd.append("--use_mp") + if self.clean: + cmd.append("--clean") + + if self.g2p_model: + g2p_path = ( + Path(self._mfa_root) + / "pretrained_models" + / "g2p" + / f"{self.g2p_model}.zip" + ) + if g2p_path.exists(): + cmd.extend(["--g2p_model_path", str(g2p_path)]) + else: + g2p_alt = ( + Path(self._mfa_root) + / "pretrained_models" + / "g2p" + / self.g2p_model + ) + if g2p_alt.exists(): + cmd.extend(["--g2p_model_path", str(g2p_alt)]) + else: + logger.warning( + f"G2P model '{self.g2p_model}' not found at " + f"{g2p_path} or {g2p_alt}. MFA will run without " + f"G2P — OOV words may fail alignment." 
+ ) + + logger.info(f"Running MFA align: {' '.join(cmd)}") + + result = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=False, + ) + + if result.stdout and result.stdout.strip(): + logger.info( + f"MFA stdout (last 5000 chars):\n{result.stdout[-5000:]}" + ) + if result.stderr and result.stderr.strip(): + logger.warning( + f"MFA stderr (last 5000 chars):\n{result.stderr[-5000:]}" + ) + + if result.returncode != 0: + raise RuntimeError( + f"mfa align failed (exit code {result.returncode}).\n" + f"STDOUT:\n{result.stdout}\n" + f"STDERR:\n{result.stderr}" + ) + + def _get_word_alignment_tier(self, tg: Any, textgrid_path: Path) -> Any: # noqa: ANN401 + """Select the word-level tier, avoiding phone-level tiers when possible.""" + for tier_name in _WORD_TIER_NAMES: + if tier_name in tg.tierNames: + return tg.getTier(tier_name) + + if not tg.tierNames: + msg = f"No tiers found in TextGrid: {textgrid_path}" + raise ValueError(msg) + + non_phone_tiers = [ + name + for name in tg.tierNames + if name.lower() not in _PHONE_TIER_NAMES + ] + if non_phone_tiers: + fallback_name = non_phone_tiers[0] + logger.warning( + f"No 'words' tier in {textgrid_path}; " + f"available tiers: {list(tg.tierNames)}. " + f"Using '{fallback_name}'." + ) + return tg.getTier(fallback_name) + + msg = ( + f"No word alignment tier in {textgrid_path}; " + f"available tiers: {list(tg.tierNames)}. " + "Refusing to parse phone-level tiers as words." 
+ ) + raise ValueError(msg) + + def _parse_textgrid_words(self, textgrid_path: Path) -> list[tuple]: + """Return ``[(start, end, label), ...]`` from the word alignment tier.""" + tg = self._textgrid_mod.openTextgrid( + str(textgrid_path), includeEmptyIntervals=False + ) + tier = self._get_word_alignment_tier(tg, textgrid_path) + return [(e.start, e.end, e.label) for e in tier.entries] + + def _is_node_local_mfa_root(self) -> bool: + """True when MFA root is the per-node local copy (safe to mutate).""" + if self.copy_models_to_local: + return True + try: + mfa_root = Path(self._mfa_root).resolve() + local_mfa = ( + Path(self._effective_local_base) + / f"mfa_models_{socket.gethostname()}" + ).resolve() + except OSError: + return False + return mfa_root == local_mfa or local_mfa in mfa_root.parents + + def _textgrid_to_rttm( + self, + textgrid_path: Path, + file_stem: str, + speaker: str, + rttm_path: Path, + ) -> None: + intervals = self._parse_textgrid_words(textgrid_path) + silence = set(self.silence_markers) + speech_intervals: list[dict] = [] + for start, end, label in intervals: + if label.strip() and label.strip() not in silence: + speech_intervals.append( + {"start": start, "duration": end - start} + ) + + merged = self._merge_intervals(speech_intervals) + + with open(rttm_path, "w", encoding="utf-8") as f: + for iv in merged: + f.write( + f"SPEAKER {file_stem} 1 " + f"{iv['start']:.3f} {iv['duration']:.3f} " + f" {speaker} \n" + ) + + def _textgrid_to_ctm( + self, + textgrid_path: Path, + file_stem: str, + ctm_path: Path, + ) -> None: + intervals = self._parse_textgrid_words(textgrid_path) + silence = set(self.silence_markers) + + with open(ctm_path, "w", encoding="utf-8") as f: + for start, end, label in intervals: + word = label.strip() + if word and word not in silence: + f.write( + f"{file_stem} 1 {start:.3f} {end - start:.3f} {word}\n" + ) + + def _merge_intervals(self, intervals: list[dict]) -> list[dict]: + if not intervals: + return [] + sorted_ivs = 
sorted(intervals, key=lambda x: x["start"]) + merged: list[dict] = [] + cur_start = sorted_ivs[0]["start"] + cur_end = cur_start + sorted_ivs[0]["duration"] + + for iv in sorted_ivs[1:]: + iv_start = iv["start"] + iv_end = iv_start + iv["duration"] + if iv_start - cur_end <= self.max_gap_for_merge: + cur_end = max(cur_end, iv_end) + else: + merged.append( + {"start": cur_start, "duration": cur_end - cur_start} + ) + cur_start = iv_start + cur_end = iv_end + + merged.append({"start": cur_start, "duration": cur_end - cur_start}) + return merged + + @staticmethod + def _get_audio_duration(audio_path: str) -> float: + with sf.SoundFile(audio_path) as f: + return len(f) / f.samplerate + + @staticmethod + def _create_duration_fallback_rttm( + file_stem: str, speaker: str, duration: float, rttm_path: Path + ) -> None: + with open(rttm_path, "w", encoding="utf-8") as f: + f.write( + f"SPEAKER {file_stem} 1 0.000 {duration:.3f} " + f" {speaker} \n" + ) + + @staticmethod + def _create_duration_fallback_ctm( + file_stem: str, text: str, duration: float, ctm_path: Path + ) -> None: + words = text.strip().split() + if not words: + ctm_path.write_text("", encoding="utf-8") + return + word_dur = duration / len(words) + with open(ctm_path, "w", encoding="utf-8") as f: + for i, word in enumerate(words): + f.write( + f"{file_stem} 1 {i * word_dur:.3f} {word_dur:.3f} {word}\n" + ) + + def _setup_local_mfa(self, shared_mfa_root: str, hostname: str) -> str: + local_mfa_root = Path(self._effective_local_base) / f"mfa_models_{hostname}" + + if local_mfa_root.exists(): + has_models = ( + (local_mfa_root / "pretrained_models").exists() + or (local_mfa_root / "extracted_models").exists() + ) + if has_models: + logger.info( + f"Using existing local MFA root: {local_mfa_root}" + ) + return str(local_mfa_root) + + logger.info(f"Copying MFA models to local storage: {local_mfa_root}") + local_mfa_root.mkdir(parents=True, exist_ok=True) + + src = Path(shared_mfa_root) + for subdir in 
("pretrained_models", "extracted_models"): + src_path = src / subdir + dst_path = local_mfa_root / subdir + if src_path.exists() and not dst_path.exists(): + logger.info(f" Copying {subdir}...") + shutil.copytree(src_path, dst_path) + + logger.info(f"Local MFA setup complete: {local_mfa_root}") + return str(local_mfa_root) diff --git a/pyproject.toml b/pyproject.toml index 707c6e0b67..2a698bdc7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ audio_common = [ "librosa", "scipy", "pydub>=0.25.1", + "praatio>=6.0", "transformers", "accelerate", "pyannote-audio>=4.0.0; platform_machine == 'x86_64' and platform_system != 'Darwin'", diff --git a/tests/stages/audio/alignment/__init__.py b/tests/stages/audio/alignment/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/stages/audio/alignment/test_mfa_alignment.py b/tests/stages/audio/alignment/test_mfa_alignment.py new file mode 100644 index 0000000000..c207ce6329 --- /dev/null +++ b/tests/stages/audio/alignment/test_mfa_alignment.py @@ -0,0 +1,424 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for MFAAlignmentStage.""" + +from __future__ import annotations + +import subprocess +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from nemo_curator.stages.audio.alignment.mfa_alignment import MFAAlignmentStage +from nemo_curator.tasks import AudioTask + +MODULE = "nemo_curator.stages.audio.alignment.mfa_alignment" + + +def _make_stage(tmp_path: Path, **overrides: object) -> MFAAlignmentStage: + defaults: dict[str, object] = { + "output_dir": str(tmp_path / "output"), + "mfa_root_dir": str(tmp_path / "mfa_root"), + "copy_models_to_local": False, + } + defaults.update(overrides) + return MFAAlignmentStage(**defaults) # type: ignore[arg-type] + + +def _make_wav(tmp_path: Path, name: str = "sample.wav") -> Path: + wav = tmp_path / name + wav.write_bytes(b"RIFF" + b"\x00" * 100) + return wav + + +def _make_task( + wav: Path, + text: str = "hello world", + *, + text_key: str = "text", + **extra: object, +) -> AudioTask: + data: dict[str, object] = { + "audio_filepath": str(wav), + text_key: text, + "speaker": "spk1", + "duration": 1.0, + **extra, + } + return AudioTask(data=data) + + +def _fake_textgrid_entry(start: float, end: float, label: str) -> SimpleNamespace: + return SimpleNamespace(start=start, end=end, label=label) + + +def _fake_tier(entries: list[SimpleNamespace]) -> SimpleNamespace: + return SimpleNamespace(entries=entries) + + +def _fake_textgrid( + entries: list[SimpleNamespace], tier_name: str = "words" +) -> SimpleNamespace: + tier = _fake_tier(entries) + return SimpleNamespace( + tierNames=[tier_name], + getTier=lambda _name: tier, # noqa: ARG005 + ) + + +def _fake_textgrid_multi(tier_entries: dict[str, list[SimpleNamespace]]) -> SimpleNamespace: + tiers = {name: _fake_tier(entries) for name, entries in tier_entries.items()} + return SimpleNamespace( + tierNames=list(tier_entries.keys()), + getTier=lambda name: tiers[name], + ) + + +def 
_align_textgrid_output_dir(cmd: list[str]) -> Path: + align_idx = cmd.index("align") + return Path(cmd[align_idx + 4]) + + +def _setup_stage( + stage: MFAAlignmentStage, + *, + textgrid: SimpleNamespace | None = None, +) -> MagicMock: + """Run setup() and inject a mock TextGrid parser.""" + stage.setup() + fake_tg_mod = MagicMock() + if textgrid is not None: + fake_tg_mod.openTextgrid.return_value = textgrid + stage._textgrid_mod = fake_tg_mod + return fake_tg_mod + + +def _mock_mfa_writes_textgrid(wav: Path): + def _run(cmd: list[str], **kwargs: object) -> subprocess.CompletedProcess: # noqa: ARG001 + tg_dir = _align_textgrid_output_dir(cmd) + (tg_dir / f"{wav.stem}.TextGrid").write_text("fake textgrid") + return subprocess.CompletedProcess(cmd, returncode=0, stdout="", stderr="") + + return _run + + +class TestMFAAlignmentStage: + """Test suite for MFAAlignmentStage.""" + + def test_outputs_reflect_create_flags(self, tmp_path: Path) -> None: + _, data_no_rttm = _make_stage(tmp_path, create_rttm=False).outputs() + assert "rttm_filepath" not in data_no_rttm + assert "ctm_filepath" in data_no_rttm + + _, data_no_ctm = _make_stage(tmp_path, create_ctm=False).outputs() + assert "rttm_filepath" in data_no_ctm + assert "ctm_filepath" not in data_no_ctm + + _, data_tg_only = _make_stage( + tmp_path, create_rttm=False, create_ctm=False + ).outputs() + assert data_tg_only == ["textgrid_filepath"] + + def test_process_batch_empty(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path) + _setup_stage(stage) + assert stage.process_batch([]) == [] + + def test_process_batch_success(self, tmp_path: Path) -> None: + wav = _make_wav(tmp_path) + stage = _make_stage(tmp_path) + entries = [ + _fake_textgrid_entry(0.0, 0.5, "hello"), + _fake_textgrid_entry(0.5, 1.0, "world"), + ] + _setup_stage(stage, textgrid=_fake_textgrid(entries)) + task = _make_task(wav) + + with patch(f"{MODULE}.subprocess.run", side_effect=_mock_mfa_writes_textgrid(wav)): + results = 
stage.process_batch([task]) + + assert len(results) == 1 + assert "textgrid_filepath" in results[0].data + assert "rttm_filepath" in results[0].data + assert "ctm_filepath" in results[0].data + assert Path(results[0].data["rttm_filepath"]).exists() + assert Path(results[0].data["ctm_filepath"]).exists() + + def test_process_batch_mfa_failure_raises(self, tmp_path: Path) -> None: + wav = _make_wav(tmp_path) + stage = _make_stage(tmp_path) + _setup_stage(stage) + task = _make_task(wav) + + failed = subprocess.CompletedProcess( + ["mfa"], returncode=1, stdout="error out", stderr="error err" + ) + with ( + patch(f"{MODULE}.subprocess.run", return_value=failed), + pytest.raises(RuntimeError, match="mfa align failed"), + ): + stage.process_batch([task]) + + def test_process_batch_missing_textgrid_fallback(self, tmp_path: Path) -> None: + wav = _make_wav(tmp_path) + stage = _make_stage(tmp_path) + _setup_stage(stage) + task = _make_task(wav, duration=2.0) + + ok = subprocess.CompletedProcess(["mfa"], returncode=0, stdout="", stderr="") + with patch(f"{MODULE}.subprocess.run", return_value=ok): + results = stage.process_batch([task]) + + assert len(results) == 1 + assert results[0].data.get("mfa_skipped") is True + assert results[0].data["textgrid_filepath"] == "" + assert Path(results[0].data["rttm_filepath"]).exists() + assert Path(results[0].data["ctm_filepath"]).exists() + ctm_lines = Path(results[0].data["ctm_filepath"]).read_text().strip().split("\n") + assert len(ctm_lines) == 2 + assert "hello" in ctm_lines[0] + assert "world" in ctm_lines[1] + + def test_process_batch_create_rttm_false(self, tmp_path: Path) -> None: + wav = _make_wav(tmp_path) + stage = _make_stage(tmp_path, create_rttm=False) + _setup_stage(stage, textgrid=_fake_textgrid([_fake_textgrid_entry(0.0, 1.0, "hello")])) + task = _make_task(wav, text="hello") + + with patch(f"{MODULE}.subprocess.run", side_effect=_mock_mfa_writes_textgrid(wav)): + results = stage.process_batch([task]) + + assert 
"rttm_filepath" not in results[0].data + assert "ctm_filepath" in results[0].data + assert "textgrid_filepath" in results[0].data + + def test_process_batch_create_ctm_false(self, tmp_path: Path) -> None: + wav = _make_wav(tmp_path) + stage = _make_stage(tmp_path, create_ctm=False) + _setup_stage(stage, textgrid=_fake_textgrid([_fake_textgrid_entry(0.0, 1.0, "hello")])) + task = _make_task(wav, text="hello") + + with patch(f"{MODULE}.subprocess.run", side_effect=_mock_mfa_writes_textgrid(wav)): + results = stage.process_batch([task]) + + assert "rttm_filepath" in results[0].data + assert "ctm_filepath" not in results[0].data + + def test_process_batch_textgrid_only(self, tmp_path: Path) -> None: + wav = _make_wav(tmp_path) + stage = _make_stage(tmp_path, create_rttm=False, create_ctm=False) + _setup_stage(stage) + task = _make_task(wav, text="hello") + + with patch(f"{MODULE}.subprocess.run", side_effect=_mock_mfa_writes_textgrid(wav)): + results = stage.process_batch([task]) + + assert "textgrid_filepath" in results[0].data + assert "rttm_filepath" not in results[0].data + assert "ctm_filepath" not in results[0].data + + def test_process_batch_prefers_words_tier_over_phones(self, tmp_path: Path) -> None: + wav = _make_wav(tmp_path) + stage = _make_stage(tmp_path) + textgrid = _fake_textgrid_multi({ + "phones": [_fake_textgrid_entry(0.0, 0.1, "AH")], + "words": [_fake_textgrid_entry(0.1, 0.5, "hello")], + }) + _setup_stage(stage, textgrid=textgrid) + task = _make_task(wav, text="hello") + + with patch(f"{MODULE}.subprocess.run", side_effect=_mock_mfa_writes_textgrid(wav)): + results = stage.process_batch([task]) + + ctm_words = [ + line.split()[-1] + for line in Path(results[0].data["ctm_filepath"]).read_text().strip().split("\n") + if line + ] + assert ctm_words == ["hello"] + + def test_process_batch_raises_when_only_phone_tiers(self, tmp_path: Path) -> None: + wav = _make_wav(tmp_path) + stage = _make_stage(tmp_path) + textgrid = _fake_textgrid_multi({ + 
"phones": [_fake_textgrid_entry(0.0, 0.1, "AH")], + }) + _setup_stage(stage, textgrid=textgrid) + task = _make_task(wav, text="hello") + + with ( + patch(f"{MODULE}.subprocess.run", side_effect=_mock_mfa_writes_textgrid(wav)), + pytest.raises(ValueError, match="Refusing to parse phone-level tiers"), + ): + stage.process_batch([task]) + + def test_process_batch_filters_silence_markers(self, tmp_path: Path) -> None: + wav = _make_wav(tmp_path) + stage = _make_stage(tmp_path) + entries = [ + _fake_textgrid_entry(0.0, 0.1, "sp"), + _fake_textgrid_entry(0.1, 0.3, "hello"), + _fake_textgrid_entry(0.3, 0.4, "sil"), + _fake_textgrid_entry(0.4, 0.6, "world"), + _fake_textgrid_entry(0.6, 0.7, ""), + ] + _setup_stage(stage, textgrid=_fake_textgrid(entries)) + task = _make_task(wav) + + with patch(f"{MODULE}.subprocess.run", side_effect=_mock_mfa_writes_textgrid(wav)): + results = stage.process_batch([task]) + + ctm_words = [ + line.split()[-1] + for line in Path(results[0].data["ctm_filepath"]).read_text().strip().split("\n") + if line + ] + assert ctm_words == ["hello", "world"] + + def test_process_batch_custom_silence_markers(self, tmp_path: Path) -> None: + wav = _make_wav(tmp_path) + stage = _make_stage(tmp_path, silence_markers=("", "PAUSE")) + entries = [ + _fake_textgrid_entry(0.0, 0.2, "PAUSE"), + _fake_textgrid_entry(0.2, 0.5, "sp"), + _fake_textgrid_entry(0.5, 0.8, "hello"), + ] + _setup_stage(stage, textgrid=_fake_textgrid(entries)) + task = _make_task(wav, text="hello") + + with patch(f"{MODULE}.subprocess.run", side_effect=_mock_mfa_writes_textgrid(wav)): + results = stage.process_batch([task]) + + ctm_words = [ + line.split()[-1] + for line in Path(results[0].data["ctm_filepath"]).read_text().strip().split("\n") + if line + ] + assert "PAUSE" not in ctm_words + assert "sp" in ctm_words + assert "hello" in ctm_words + + def test_custom_text_key(self, tmp_path: Path) -> None: + wav = _make_wav(tmp_path) + stage = _make_stage(tmp_path, text_key="utterance") + 
_setup_stage(stage, textgrid=_fake_textgrid([_fake_textgrid_entry(0.0, 1.0, "hello")])) + task = _make_task(wav, text="hello", text_key="utterance") + + _, data_keys = stage.inputs() + assert "utterance" in data_keys + + with patch(f"{MODULE}.subprocess.run", side_effect=_mock_mfa_writes_textgrid(wav)): + results = stage.process_batch([task]) + + assert len(results) == 1 + assert "textgrid_filepath" in results[0].data + + def test_mfa_command_construction(self, tmp_path: Path) -> None: + wav = _make_wav(tmp_path) + stage = _make_stage( + tmp_path, + mfa_command="conda run -n mfa mfa", + beam=200, + retry_beam=800, + single_speaker=False, + clean=False, + use_mp=False, + output_format="short_textgrid", + ) + _setup_stage(stage, textgrid=_fake_textgrid([_fake_textgrid_entry(0.0, 1.0, "test")])) + captured_cmd: list[str] = [] + + def capture_run(cmd: list[str], **kwargs: object) -> subprocess.CompletedProcess: # noqa: ARG001 + captured_cmd.extend(cmd) + tg_dir = _align_textgrid_output_dir(cmd) + (tg_dir / f"{wav.stem}.TextGrid").write_text("fake") + return subprocess.CompletedProcess(cmd, returncode=0, stdout="", stderr="") + + task = _make_task(wav, text="test") + with patch(f"{MODULE}.subprocess.run", side_effect=capture_run): + stage.process_batch([task]) + + assert "align" in captured_cmd + assert "--beam" in captured_cmd and "200" in captured_cmd + assert "--retry_beam" in captured_cmd and "800" in captured_cmd + assert "--output_format" in captured_cmd and "short_textgrid" in captured_cmd + assert "--single_speaker" not in captured_cmd + assert "--clean" not in captured_cmd + assert "--use_mp" not in captured_cmd + + def test_setup_on_node_copies_models(self, tmp_path: Path) -> None: + shared_root = tmp_path / "shared_mfa" + (shared_root / "pretrained_models").mkdir(parents=True) + (shared_root / "pretrained_models" / "model.bin").write_bytes(b"data") + (shared_root / "extracted_models").mkdir(parents=True) + (shared_root / "extracted_models" / 
"ext.bin").write_bytes(b"data") + + stage = _make_stage( + tmp_path, + mfa_root_dir=str(shared_root), + local_mfa_base_dir=str(tmp_path / "local"), + copy_models_to_local=True, + ) + stage.setup_on_node() + + local_root = Path(stage._mfa_root) + assert local_root.exists() + assert (local_root / "pretrained_models" / "model.bin").exists() + assert (local_root / "extracted_models" / "ext.bin").exists() + + def test_setup_on_node_reuses_existing_local_root(self, tmp_path: Path) -> None: + import socket + + local_base = tmp_path / "local" + local_mfa = local_base / f"mfa_models_{socket.gethostname()}" + (local_mfa / "pretrained_models").mkdir(parents=True) + + stage = _make_stage( + tmp_path, + local_mfa_base_dir=str(local_base), + copy_models_to_local=True, + ) + stage.setup_on_node() + assert stage._mfa_root == str(local_mfa) + + def test_shared_mfa_root_does_not_delete_command_history(self, tmp_path: Path) -> None: + shared_root = tmp_path / "shared_mfa" + shared_root.mkdir() + history = shared_root / "command_history.yaml" + history.write_text("history: []\n") + + stage = _make_stage( + tmp_path, + mfa_root_dir=str(shared_root), + copy_models_to_local=False, + ) + stage._mfa_root = str(shared_root) + assert stage._is_node_local_mfa_root() is False + + corpus = tmp_path / "corpus" + corpus.mkdir() + tg_out = tmp_path / "tg_out" + tg_out.mkdir() + + with patch( + f"{MODULE}.subprocess.run", + return_value=subprocess.CompletedProcess([], 0, "", ""), + ): + stage._run_mfa_align(corpus, tg_out) + + assert history.exists() + assert history.read_text() == "history: []\n" diff --git a/tutorials/audio/README.md b/tutorials/audio/README.md index 928f5d359d..067999ab2c 100644 --- a/tutorials/audio/README.md +++ b/tutorials/audio/README.md @@ -23,6 +23,7 @@ sudo apt-get install -y ffmpeg | **[FLEURS Dataset](fleurs/)** | Complete pipeline for multilingual speech data | `pipeline.py`, `run.py`, `pipeline.yaml` | | **[Audio Tagging](tagging/)** | Label raw audio for TTS/ASR via 
diarization, alignment, and quality metrics | `main.py`, `tts_pipeline.yaml`, `asr_pipeline.yaml` | | **[ALM Data Pipeline](alm/)** | Create training windows for Audio Language Models | `main.py`, `pipeline.yaml` | +| **[MFA Forced Alignment](alignment/)** | Word-level alignment with Montreal Forced Aligner | `pipeline.py`, `run.py`, `pipeline.yaml` | ## Documentation Links diff --git a/tutorials/audio/alignment/README.md b/tutorials/audio/alignment/README.md new file mode 100644 index 0000000000..2727080b79 --- /dev/null +++ b/tutorials/audio/alignment/README.md @@ -0,0 +1,272 @@ +# MFA Forced Alignment Pipeline + +Forced alignment of audio with transcripts using the [Montreal Forced Aligner (MFA)](https://montreal-forced-aligner.readthedocs.io/). + +This pipeline takes a JSONL audio manifest (audio files + transcripts), runs MFA batch alignment, and produces word-level TextGrid files with optional RTTM (speech activity) and CTM (word timing) outputs. + +## What is MFA? + +Montreal Forced Aligner is a tool that aligns orthographic transcriptions to audio recordings, producing **word-level** and **phone-level** time boundaries stored in Praat TextGrid files. + +The `MFAAlignmentStage` wraps MFA as a NeMo Curator processing stage, enabling: + +- **Batch alignment** -- groups of audio files are aligned in a single `mfa align` call for efficiency +- **TextGrid output** -- the native MFA alignment format +- **RTTM output** -- speech activity segments derived from word boundaries (useful for diarization pipelines) +- **CTM output** -- word-level timing in NIST CTM format (useful for ASR evaluation) + +## Prerequisites + +### 1. Install NeMo Curator with alignment dependencies + +```bash +uv sync --extra audio_cuda12 +``` + +This installs `praatio` (for TextGrid parsing) and other audio dependencies via `audio_common`. + +### 2. Install Montreal Forced Aligner + +MFA is distributed via conda/micromamba (not pip). 
Install it in a separate environment: + +```bash +# Using micromamba +micromamba create -n mfa -c conda-forge montreal-forced-aligner +micromamba activate mfa + +# Or using conda +conda create -n mfa -c conda-forge montreal-forced-aligner +conda activate mfa +``` + +If MFA is in a separate conda environment, provide the full path to the binary via `--mfa-command`: + +```bash +--mfa-command /path/to/micromamba/envs/mfa/bin/mfa +``` + +### 3. Download MFA models + +```bash +# Acoustic model + pronunciation dictionary (English example) +mfa model download acoustic english_us_arpa +mfa model download dictionary english_us_arpa + +# Optional: G2P model for out-of-vocabulary words +mfa model download g2p english_us_arpa +``` + +Models are stored under `~/.mfa/pretrained_models/` by default. Override with `--mfa-root-dir` or the `MFA_ROOT_DIR` environment variable. + +## Quick Start + +```bash +# Basic alignment with RTTM + CTM output +python tutorials/audio/alignment/pipeline.py \ + --input-manifest /data/manifest.jsonl \ + --output-dir /data/aligned + +# TextGrid-only output (no RTTM/CTM conversion) +python tutorials/audio/alignment/pipeline.py \ + --input-manifest /data/manifest.jsonl \ + --output-dir /data/aligned \ + --no-rttm --no-ctm + +# Custom MFA binary and models +python tutorials/audio/alignment/pipeline.py \ + --input-manifest /data/manifest.jsonl \ + --output-dir /data/aligned \ + --mfa-command /opt/micromamba/envs/mfa/bin/mfa \ + --mfa-root-dir /shared/mfa_models \ + --acoustic-model english_us_arpa \ + --dictionary english_us_arpa +``` + +## Input Format + +The pipeline expects a JSONL manifest where each line is a JSON object with at least: + +```json +{"audio_filepath": "/data/audio/utt001.wav", "text": "hello world", "speaker": "speaker_a"} +``` + +| Key | Required | Description | +|-----|----------|-------------| +| `audio_filepath` | Yes | Path to the WAV audio file | +| `text` | Yes | Transcript text for alignment | +| `speaker` | No | Speaker label 
(used in RTTM output; defaults to `"unknown"`) | +| `duration` | No | Audio duration in seconds (computed automatically if missing) | + +The key names are configurable via `--text-key`, `--audio-filepath-key`, and `--speaker-key`. + +## Pipeline Architecture + +``` +Input JSONL Manifest + | + v +MFAAlignmentStage (process_batch) + |-- Prepares temporary corpus (symlinked WAVs + .txt files) + |-- Runs single `mfa align` subprocess + |-- Parses resulting TextGrid files + |-- Converts to RTTM (if create_rttm=True) + |-- Converts to CTM (if create_ctm=True) + |-- Adds output paths to task.data + | + v +AudioToDocumentStage + | + v +JsonlWriter -> Output JSONL Manifest +``` + +## CLI Options + +| Argument | Default | Description | +|----------|---------|-------------| +| `--input-manifest` | *required* | Path to input JSONL manifest | +| `--output-dir` | *required* | Root output directory | +| `--mfa-command` | `mfa` | Path to the MFA binary | +| `--mfa-root-dir` | `~/.mfa` | MFA root directory with pretrained models | +| `--acoustic-model` | `english_us_arpa` | MFA acoustic model name or path | +| `--dictionary` | `english_us_arpa` | MFA dictionary name or path | +| `--g2p-model` | `english_us_arpa` | MFA G2P model (empty string to disable) | +| `--text-key` | `text` | Manifest key for transcript text | +| `--audio-filepath-key` | `audio_filepath` | Manifest key for audio file path | +| `--speaker-key` | `speaker` | Manifest key for speaker label | +| `--beam` | `100` | MFA beam size | +| `--retry-beam` | `400` | MFA retry beam for failed alignments | +| `--num-jobs` | `1` | Parallel MFA jobs passed to MFA ``-j`` | +| `--batch-size` | `256` | Files per `mfa align` invocation | +| `--no-rttm` | `false` | Skip RTTM generation | +| `--no-ctm` | `false` | Skip CTM generation | +| `--backend` | `ray_data` | Execution backend (`ray_data` or `xenna`) | +| `--clean` | `false` | Overwrite existing result directory | +| `--verbose` | `false` | Enable DEBUG logging | + +## Output 
Format + +The output manifest JSONL contains all original fields plus: + +```json +{ + "audio_filepath": "/data/audio/utt001.wav", + "text": "hello world", + "speaker": "speaker_a", + "duration": 1.23, + "textgrid_filepath": "/data/aligned/textgrids/abc123/utt001.TextGrid", + "rttm_filepath": "/data/aligned/rttms/utt001.rttm", + "ctm_filepath": "/data/aligned/ctms/utt001.ctm" +} +``` + +### Output directory structure + +``` +output_dir/ +├── textgrids/ # MFA TextGrid alignments (subdirs per batch) +│ └── <batch_id>/ +│ └── utt001.TextGrid +├── rttms/ # RTTM speech activity files (if enabled) +│ └── utt001.rttm +├── ctms/ # CTM word timing files (if enabled) +│ └── utt001.ctm +└── result/ # Output JSONL manifest + └── *.jsonl +``` + +### RTTM format + +``` +SPEAKER utt001 1 0.120 0.890 <NA> <NA> speaker_a <NA> <NA> +``` + +Fields: `SPEAKER <file_id> <channel> <start> <duration> <NA> <NA> <speaker> <NA> <NA>` + +Nearby speech intervals are merged when separated by less than `max_gap_for_merge` seconds (default 0.3s). + +### CTM format + +``` +utt001 1 0.120 0.380 hello +utt001 1 0.510 0.390 world +``` + +Fields: `<file_id> <channel> <start> <duration> <word>` + +## Using with Hydra + +```bash +python tutorials/audio/alignment/run.py \ + --config-path=. --config-name=pipeline \ + input_manifest=/data/manifest.jsonl output_dir=/data/aligned +``` + +See `pipeline.yaml` for all configurable parameters. Override any field from the command line (`processors.1` is the `MFAAlignmentStage` entry; `processors.0` is the `ManifestReader`): + +```bash +python tutorials/audio/alignment/run.py \ + --config-path=. --config-name=pipeline \ + input_manifest=/data/manifest.jsonl output_dir=/data/aligned \ + processors.1.acoustic_model=english_mfa \ + processors.1.dictionary=english_mfa \ + batch_size=512 +``` + +## Multi-Node / Distributed Execution + +When running on multiple nodes (e.g., via Xenna or Ray cluster), `MFAAlignmentStage` handles distributed MFA gracefully: + +- **`setup_on_node()`** copies MFA pretrained models from shared storage (NFS/Lustre) to each node's local storage (e.g., `/tmp`). This avoids file-locking issues that Kaldi (used internally by MFA) has with network filesystems.
+- **`xenna_stage_spec()`** requests exactly 1 MFA worker per node, since MFA itself uses internal parallelism via ``num_jobs`` (MFA ``-j``). +- Set `copy_models_to_local=False` if MFA models are already on local storage. + +## Non-English Languages + +MFA supports [many languages](https://mfa-models.readthedocs.io/en/latest/). To align a different language: + +1. Download the appropriate models: + +```bash +mfa model download acoustic german_mfa +mfa model download dictionary german_mfa +mfa model download g2p german_mfa +``` + +2. Pass them to the pipeline: + +```bash +python tutorials/audio/alignment/pipeline.py \ + --input-manifest /data/german_manifest.jsonl \ + --output-dir /data/aligned_de \ + --acoustic-model german_mfa \ + --dictionary german_mfa \ + --g2p-model german_mfa +``` + +## MFA-Skipped Files + +MFA may silently skip files it cannot align (out-of-vocabulary words, acoustic mismatch, very short audio, etc.). When this happens: + +- The stage creates **fallback** RTTM/CTM files (duration-based: one segment spanning the full audio) +- The entry is marked with `"mfa_skipped": true` in the output manifest +- `"textgrid_filepath"` is set to an empty string + +You can filter these entries downstream or audit them separately. 
+ +## Troubleshooting + +| Issue | Solution | +|-------|----------| +| `mfa: command not found` | Provide the full path via `--mfa-command /path/to/mfa` | +| `praatio` import error | Run `uv sync --extra audio_cuda12` (or `audio_cpu`) from the Curator repo root | +| `Kaldi error: cannot lock file` | Enable `copy_models_to_local=True` (default) or use local storage for `--mfa-root-dir` | +| Many files silently skipped | Check for OOV words; provide a G2P model or expand the dictionary | +| `mfa align` OOM | Reduce `--batch-size` to process fewer files per invocation | +| Slow alignment | Increase `--num-jobs` or ensure MFA has access to all CPU cores | + +## License + +This tutorial and the `MFAAlignmentStage` are licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0). + +MFA itself is licensed under the [MIT License](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/blob/main/LICENSE). diff --git a/tutorials/audio/alignment/pipeline.py b/tutorials/audio/alignment/pipeline.py new file mode 100644 index 0000000000..ddb58c4b81 --- /dev/null +++ b/tutorials/audio/alignment/pipeline.py @@ -0,0 +1,231 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +MFA Forced Alignment Pipeline +============================== + +Reads a JSONL audio manifest, runs Montreal Forced Aligner (MFA) on each +batch of entries, and writes the enriched manifest with TextGrid, RTTM, +and CTM file paths back to JSONL. + +Example +------- +:: + + python pipeline.py \\ + --input-manifest /data/manifest.jsonl \\ + --output-dir /data/aligned \\ + --acoustic-model english_us_arpa \\ + --dictionary english_us_arpa +""" + +import argparse +import os +import shutil +import sys + +from loguru import logger + +from nemo_curator.backends.ray_data import RayDataExecutor +from nemo_curator.backends.xenna import XennaExecutor +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.audio.alignment import MFAAlignmentStage +from nemo_curator.stages.audio.common import ManifestReader +from nemo_curator.stages.audio.io.convert import AudioToDocumentStage +from nemo_curator.stages.text.io.writer import JsonlWriter + + +def create_pipeline(args: argparse.Namespace) -> Pipeline: + pipeline = Pipeline( + name="mfa_alignment", + description="Forced alignment with Montreal Forced Aligner", + ) + + pipeline.add_stage(ManifestReader(manifest_path=args.input_manifest)) + + pipeline.add_stage( + MFAAlignmentStage( + output_dir=args.output_dir, + mfa_command=args.mfa_command, + acoustic_model=args.acoustic_model, + dictionary=args.dictionary, + g2p_model=args.g2p_model, + audio_filepath_key=args.audio_filepath_key, + text_key=args.text_key, + speaker_key=args.speaker_key, + num_jobs=args.num_jobs, + beam=args.beam, + retry_beam=args.retry_beam, + create_rttm=not args.no_rttm, + create_ctm=not args.no_ctm, + mfa_root_dir=args.mfa_root_dir, + ).with_(batch_size=args.batch_size) + ) + + pipeline.add_stage(AudioToDocumentStage().with_(batch_size=1)) + + result_dir = os.path.join(args.output_dir, "result") + if args.clean and os.path.isdir(result_dir): + shutil.rmtree(result_dir) + elif not args.clean and os.path.exists(result_dir): + msg = f"Result 
directory {result_dir} already exists. Use --clean to overwrite." + raise ValueError(msg) + + pipeline.add_stage( + JsonlWriter( + path=result_dir, + write_kwargs={"force_ascii": False}, + ) + ) + + return pipeline + + +def main(args: argparse.Namespace) -> None: + logger.remove() + logger.add(sys.stderr, level="DEBUG" if args.verbose else "INFO") + + pipeline = create_pipeline(args) + logger.info(pipeline.describe()) + logger.info("\n" + "=" * 50 + "\n") + + executor = ( + RayDataExecutor() if args.backend == "ray_data" else XennaExecutor() + ) + + logger.info("Starting MFA alignment pipeline...") + pipeline.run(executor) + logger.info("\nPipeline completed!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="MFA forced alignment pipeline for audio manifests", + ) + + parser.add_argument( + "--input-manifest", + type=str, + required=True, + help="Path to input JSONL manifest", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Root output directory for TextGrids, RTTMs, CTMs, and result manifest", + ) + parser.add_argument( + "--mfa-command", + type=str, + default="mfa", + help="Path to the mfa binary (default: mfa)", + ) + parser.add_argument( + "--mfa-root-dir", + type=str, + default="", + help="MFA root directory with models (default: MFA_ROOT_DIR env or ~/.mfa)", + ) + parser.add_argument( + "--acoustic-model", + type=str, + default="english_us_arpa", + help="MFA acoustic model name or path", + ) + parser.add_argument( + "--dictionary", + type=str, + default="english_us_arpa", + help="MFA dictionary name or path", + ) + parser.add_argument( + "--g2p-model", + type=str, + default="english_us_arpa", + help="MFA G2P model for OOV words (set empty to disable)", + ) + parser.add_argument( + "--text-key", + type=str, + default="text", + help="Key in manifest entries for transcript text", + ) + parser.add_argument( + "--audio-filepath-key", + type=str, + default="audio_filepath", + help="Key in manifest 
entries for audio file path", + ) + parser.add_argument( + "--speaker-key", + type=str, + default="speaker", + help="Key in manifest entries for speaker label", + ) + parser.add_argument( + "--beam", + type=int, + default=100, + help="MFA beam size for alignment search", + ) + parser.add_argument( + "--retry-beam", + type=int, + default=400, + help="MFA retry beam size for failed alignments", + ) + parser.add_argument( + "--num-jobs", + type=int, + default=1, + help="Number of parallel MFA jobs passed to MFA -j", + ) + parser.add_argument( + "--batch-size", + type=int, + default=256, + help="Number of audio files per MFA alignment batch", + ) + parser.add_argument( + "--no-rttm", + action="store_true", + help="Skip RTTM generation (only produce TextGrids)", + ) + parser.add_argument( + "--no-ctm", + action="store_true", + help="Skip CTM generation (only produce TextGrids)", + ) + parser.add_argument( + "--clean", + action="store_true", + help="Delete existing result directory before writing outputs", + ) + parser.add_argument( + "--backend", + type=str, + choices=["xenna", "ray_data"], + default="ray_data", + help="Execution backend: 'ray_data' (default) or 'xenna'", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose (DEBUG) logging", + ) + + args = parser.parse_args() + main(args) diff --git a/tutorials/audio/alignment/pipeline.yaml b/tutorials/audio/alignment/pipeline.yaml new file mode 100644 index 0000000000..e9493a3e77 --- /dev/null +++ b/tutorials/audio/alignment/pipeline.yaml @@ -0,0 +1,79 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - _self_ + - override hydra/job_logging: none + - override hydra/hydra_logging: none + +hydra: + run: + dir: . + output_subdir: null + +documentation: | + MFA Forced Alignment + #################### + This config runs Montreal Forced Aligner (MFA) on an audio manifest to + produce word-level TextGrid alignments, with optional RTTM and CTM output. + + **Required arguments**. + + * **input_manifest**: Path to the input JSONL manifest with ``audio_filepath`` and ``text`` keys. + * **output_dir**: Root directory for TextGrid/RTTM/CTM output and result manifest. + + The pipeline reads AudioTask entries, aligns them with MFA, converts + TextGrids to RTTM and CTM files, then writes the enriched manifest to JSONL. + + Note that you can customize any part of this config either directly or + from the command-line. + + **Output format** + + Output manifest contains the following additional keys: + + * **textgrid_filepath (str)**: path to the TextGrid alignment file. + * **rttm_filepath (str)**: path to the RTTM speech-activity file (if enabled). + * **ctm_filepath (str)**: path to the CTM word-timing file (if enabled). + +input_manifest: ??? +output_dir: ??? 
+backend: xenna +batch_size: 256 + +processors: + - _target_: nemo_curator.stages.audio.common.ManifestReader + manifest_path: ${input_manifest} + + - _target_: nemo_curator.stages.audio.alignment.MFAAlignmentStage + output_dir: ${output_dir} + mfa_command: "mfa" + acoustic_model: "english_us_arpa" + dictionary: "english_us_arpa" + g2p_model: "english_us_arpa" + audio_filepath_key: "audio_filepath" + text_key: "text" + speaker_key: "speaker" + num_jobs: 1 + beam: 100 + retry_beam: 400 + create_rttm: true + create_ctm: true + + - _target_: nemo_curator.stages.audio.io.convert.AudioToDocumentStage + + - _target_: nemo_curator.stages.text.io.writer.JsonlWriter + path: ${output_dir}/result + write_kwargs: + "force_ascii": false diff --git a/tutorials/audio/alignment/run.py b/tutorials/audio/alignment/run.py new file mode 100644 index 0000000000..710da0af19 --- /dev/null +++ b/tutorials/audio/alignment/run.py @@ -0,0 +1,79 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Hydra-based MFA alignment pipeline runner. + +Usage:: + + python run.py --config-path=. 
--config-name=pipeline +""" + +import importlib + +import hydra +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.audio.alignment import MFAAlignmentStage + +_EXECUTOR_FACTORIES = { + "xenna": "nemo_curator.backends.xenna:XennaExecutor", + "ray_data": "nemo_curator.backends.ray_data:RayDataExecutor", +} + + +def _create_executor(backend: str) -> object: + module_path, class_name = _EXECUTOR_FACTORIES[backend].rsplit(":", 1) + mod = importlib.import_module(module_path) + return getattr(mod, class_name)() + + +def create_pipeline_from_yaml(cfg: DictConfig) -> Pipeline: + pipeline = Pipeline( + name="mfa_alignment", + description="MFA forced alignment pipeline (YAML config)", + ) + batch_size = cfg.get("batch_size", None) + for p in cfg.processors: + stage = hydra.utils.instantiate(p) + if batch_size and isinstance(stage, MFAAlignmentStage): + stage = stage.with_(batch_size=batch_size) + pipeline.add_stage(stage) + return pipeline + + +@hydra.main(version_base=None) +def main(cfg: DictConfig) -> None: + logger.info(f"Hydra config:\n{OmegaConf.to_yaml(cfg)}") + pipeline = create_pipeline_from_yaml(cfg) + + logger.info(pipeline.describe()) + logger.info("\n" + "=" * 50 + "\n") + + backend = cfg.get("backend", "ray_data") + if backend not in _EXECUTOR_FACTORIES: + msg = f"Unknown backend '{backend}'. Choose from: {list(_EXECUTOR_FACTORIES)}" + raise ValueError(msg) + logger.info(f"Using backend: {backend}") + executor = _create_executor(backend) + + logger.info("Starting MFA alignment pipeline...") + pipeline.run(executor) + logger.info("\nPipeline completed!") + + +if __name__ == "__main__": + main()