diff --git a/noisekit/dataset.py b/noisekit/dataset.py index edeb7a6..5386e9a 100644 --- a/noisekit/dataset.py +++ b/noisekit/dataset.py @@ -69,6 +69,10 @@ def extract_audio_and_text(sample: dict) -> tuple[np.ndarray, int, str]: else: raise ValueError("Audio sample has neither 'bytes' nor 'path'.") + array = np.asarray(array, dtype=np.float32) + if array.ndim == 2: + array = array.mean(axis=1) # (samples, channels) → (samples,) for mono-only metrics + text = ( sample.get("text") or sample.get("sentence") @@ -76,4 +80,4 @@ def extract_audio_and_text(sample: dict) -> tuple[np.ndarray, int, str]: or sample.get("normalized_text") or "" ) - return np.asarray(array, dtype=np.float32), int(sr), str(text).strip() + return array, int(sr), str(text).strip()