diff --git a/frontend/src-tauri/src/lib.rs b/frontend/src-tauri/src/lib.rs index a9cee65f..d8d98e47 100644 --- a/frontend/src-tauri/src/lib.rs +++ b/frontend/src-tauri/src/lib.rs @@ -50,7 +50,9 @@ pub fn run() { tts::tts_get_status, tts::tts_download_models, tts::tts_load_models, + tts::tts_chunk_text, tts::tts_synthesize, + tts::tts_synthesize_chunk, tts::tts_unload_models, tts::tts_delete_models, ]) @@ -296,7 +298,9 @@ pub fn run() { tts::tts_get_status, tts::tts_download_models, tts::tts_load_models, + tts::tts_chunk_text, tts::tts_synthesize, + tts::tts_synthesize_chunk, tts::tts_unload_models, tts::tts_delete_models, ]) diff --git a/frontend/src-tauri/src/tts.rs b/frontend/src-tauri/src/tts.rs index 602cd8e0..1eba9b0d 100644 --- a/frontend/src-tauri/src/tts.rs +++ b/frontend/src-tauri/src/tts.rs @@ -1,19 +1,20 @@ -use anyhow::{Context, Result}; +use anyhow::{bail, Context, Result}; use base64::{engine::general_purpose::STANDARD as BASE64, Engine}; use futures_util::StreamExt; use hound::{SampleFormat, WavSpec, WavWriter}; use ndarray::{Array, Array3}; use once_cell::sync::Lazy; -use ort::{session::Session, value::Value}; +use ort::{logging::LogLevel, session::Session, value::Value}; use rand::thread_rng; use rand_distr::{Distribution, Normal}; use regex::Regex; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use std::fs::{self, File}; -use std::io::{BufReader, Cursor, Write}; +use std::io::{BufReader, Cursor, Read, Write}; use std::path::{Path, PathBuf}; use std::sync::Mutex; +use std::time::Instant; use tauri::{AppHandle, Emitter}; use unicode_normalization::UnicodeNormalization; @@ -53,6 +54,8 @@ static RE_STRIKE: Lazy = Lazy::new(|| Regex::new(r"~~([^~]+)~~").unwrap() static RE_CODE: Lazy = Lazy::new(|| Regex::new(r"`([^`]+)`").unwrap()); static RE_CODEBLOCK: Lazy = Lazy::new(|| Regex::new(r"(?s)```[^`]*```").unwrap()); static RE_HEADER: Lazy = Lazy::new(|| Regex::new(r"(?m)^#{1,6}\s*").unwrap()); +static RE_XML_TAG: Lazy = + Lazy::new(|| Regex::new(r"]*)?>").unwrap()); static RE_EMOJI: Lazy = Lazy::new(|| { Regex::new(r"[\x{1F600}-\x{1F64F}\x{1F300}-\x{1F5FF}\x{1F680}-\x{1F6FF}\x{1F700}-\x{1F77F}\x{1F780}-\x{1F7FF}\x{1F800}-\x{1F8FF}\x{1F900}-\x{1F9FF}\x{1FA00}-\x{1FA6F}\x{1FA70}-\x{1FAFF}\x{2600}-\x{26FF}\x{2700}-\x{27BF}\x{1F1E6}-\x{1F1FF}]+").unwrap() }); @@ -76,74 +79,180 @@ static RE_ENDS_PUNCT: Lazy = Lazy::new(|| { static RE_SENTENCE: Lazy = Lazy::new(|| Regex::new(r"([.!?])\s+").unwrap()); // Pin model downloads to a specific repo revision to ensure integrity and reproducibility. -const HUGGINGFACE_REVISION: &str = "b6856d033f622c63ea29441795be266a1133e227"; -const HUGGINGFACE_BASE_URL: &str = "https://huggingface.co/Supertone/supertonic/resolve"; +const HUGGINGFACE_REVISION: &str = "3cadd1ee6394adea1bd021217a0e650ede09a323"; +const HUGGINGFACE_BASE_URL: &str = "https://huggingface.co/Supertone/supertonic-3/resolve"; +const MODEL_REVISION_FILE: &str = "supertonic_revision.txt"; +const SUPERTONIC3_CACHE_DIR: &str = "supertonic-3"; +const DEFAULT_TTS_LANGUAGE: &str = "en"; +const DEFAULT_VOICE_STYLE: &str = "F2.json"; +const TTS_TOTAL_STEPS: usize = 10; +const TTS_CHUNK_MAX_CHARS: usize = 450; +const SUPERTONIC3_TTS_SPEED: f32 = 1.0; +const LEGACY_TTS_SPEED: f32 = 1.2; +const MIN_TTS_SPEED: f32 = 0.5; +const MAX_TTS_SPEED: f32 = 2.0; + +const AVAILABLE_LANGS: &[&str] = &[ + "en", "ko", "ja", "ar", "bg", "cs", "da", "de", "el", "es", "et", "fi", "fr", "hi", "hr", "hu", + "id", "it", "lt", "lv", "nl", "pl", "pt", "ro", "ru", "sk", "sl", "sv", "tr", "uk", "vi", "na", +]; // (file_name, url_path, expected_size_bytes, expected_sha256_hex) -const MODEL_FILES: &[(&str, &str, u64, &str)] = &[ +const SUPERTONIC3_MODEL_FILES: &[(&str, &str, u64, &str)] = &[ ( "duration_predictor.onnx", "onnx/duration_predictor.onnx", - 1_500_789, - "b861580c56a0cba2a2b82aa697ecb3c5a163c3240c60a0ddfac369d21d054092", + 3_700_147, + "c3eb91414d5ff8a7a239b7fe9e34e7e2bf8a8140d8375ffb14718b1c639325db", ), ( "text_encoder.onnx", "onnx/text_encoder.onnx", - 27_348_373, - "ba0c8ea74aeb5df00d21a89b8d47c71317f47120232e3deef95024dba37dbd88", + 36_416_150, + "c7befd5ea8c3119769e8a6c1486c4edc6a3bc8365c67621c881bbb774b9902ff", ), ( "vector_estimator.onnx", "onnx/vector_estimator.onnx", - 132_471_364, - "b3f82ecd2e9decc4e2236048b03628a1c1d5f14a792ba274a59b7325107aa6a6", + 256_534_781, + "883ac868ea0275ef0e991524dc64f16b3c0376efd7c320af6b53f5b780d7c61c", ), ( "vocoder.onnx", "onnx/vocoder.onnx", - 101_405_066, - "19bd51f47a186069c752403518a40f7ea4c647455056d2511f7249691ecddf7c", + 101_424_195, + "085de76dd8e8d5836d6ca66826601f615939218f90e519f70ee8a36ed2a4c4ba", ), ( "tts.json", "onnx/tts.json", - 8_645, - "4dac5f986698a3ace9a97ea2545d43f6c8ba120d25e005f8c905128281be9b6d", + 8_253, + "42078d3aef1cd43ab43021f3c54f47d2d75ceb4e75f627f118890128b06a0d09", ), ( "unicode_indexer.json", "onnx/unicode_indexer.json", - 262_134, - "0c3800ba4fb1fc760c9070eb43a0ad5a68279ec165742591a68ea3edca452978", + 277_676, + "9bf7346e43883a81f8645c81224f786d43c5b57f3641f6e7671a7d6c493cb24f", ), ( "F1.json", "voice_styles/F1.json", - 420_622, - "1450bcad84a2790eaf73f85e763dd5bae7c399f55d692c4835cf4f7686b5a10f", + 292_046, + "bbdec6ee00231c2c742ad05483df5334cab3b52fda3ba38e6a07059c4563dbc2", ), ( "F2.json", "voice_styles/F2.json", - 420_905, - "47c8d44445ef8ac8aae8ef5806feca21903483cbd4f1232e405184a40520a549", + 292_423, + "7c722c6a72707b1a77f035d67f0d1351ba187738e06f7683e8c72b1df3477fc6", + ), + ( + "F3.json", + "voice_styles/F3.json", + 290_794, + "12f6ef2573baa2defa1128069cb59f203e3ab67c92af77b42df8a0e3a2f7c6ab", + ), + ( + "F4.json", + "voice_styles/F4.json", + 291_808, + "c2fa764c1225a76dfc3e2c73e8aa4f70d9ee48793860eb34c295fff01c2e032b", + ), + ( + "F5.json", + "voice_styles/F5.json", + 291_479, + "45966e73316415626cf41a7d1c6f3b4c70dbc1ba2bee5c1978ef0ce33244fc8d", ), ( "M1.json", "voice_styles/M1.json", - 421_053, - "273c9ba6582d2e00383d8fbe2f5d660d86e8fba849c91ff695384d1a6e2e02f1", + 291_748, + "e35604687f5d23694b8e91593a93eec0e4eca6c0b02bb8ed69139ab2ea6b0a5b", ), ( "M2.json", "voice_styles/M2.json", - 421_027, - "26898a9ec3de1b5bf8cc3f6cbf41930543ca0403f2201e12aad849691ff315dd", + 292_055, + "b76cbf62bac707c710cf0ae5aba5e31eea1a6339a9734bfae33ab98499534a50", ), + ( + "M3.json", + "voice_styles/M3.json", + 290_198, + "ea1ac35ccb91b0d7ecad533a2fbd0eec10c91513d8951e3b25fbba99954e159b", + ), + ( + "M4.json", + "voice_styles/M4.json", + 291_522, + "ca8eefad4fcd989c9379032ff3e50738adc547eeb5e221b82593a6d7b3bac303", + ), + ( + "M5.json", + "voice_styles/M5.json", + 291_469, + "dd22b92740314321f8ae11c5e87f8dd60d060f15dd3a632b5adf77f471f77af2", + ), +]; + +const SUPERTONIC3_TOTAL_MODEL_SIZE: u64 = 401_276_744; // bytes + +// Supertonic 1 assets were downloaded directly into the root tts_models +// directory. Keep detecting them so existing users can keep read-aloud working +// until they explicitly delete and download Supertonic 3. +const LEGACY_MODEL_FILES: &[(&str, u64)] = &[ + ("duration_predictor.onnx", 1_500_789), + ("text_encoder.onnx", 27_348_373), + ("vector_estimator.onnx", 132_471_364), + ("vocoder.onnx", 101_405_066), + ("tts.json", 8_645), + ("unicode_indexer.json", 262_134), + ("F1.json", 420_622), + ("F2.json", 420_905), + ("M1.json", 421_053), + ("M2.json", 421_027), ]; -const TOTAL_MODEL_SIZE: u64 = 264_679_978; // bytes +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ModelVersion { + Supertonic3, + Legacy, +} + +fn default_tts_speed(model_version: ModelVersion) -> f32 { + match model_version { + ModelVersion::Supertonic3 => SUPERTONIC3_TTS_SPEED, + ModelVersion::Legacy => LEGACY_TTS_SPEED, + } +} + +fn sanitize_tts_speed(speed: f32) -> f32 { + if speed.is_finite() { + speed.clamp(MIN_TTS_SPEED, MAX_TTS_SPEED) + } else { + SUPERTONIC3_TTS_SPEED + } +} + +fn resolve_tts_speed(model_version: ModelVersion, requested_speed: Option) -> f32 { + requested_speed + .map(sanitize_tts_speed) + .unwrap_or_else(|| default_tts_speed(model_version)) +} + +fn tts_trace_enabled() -> bool { + std::env::var("MAPLE_TTS_TRACE") + .map(|value| { + matches!( + value.to_ascii_lowercase().as_str(), + "1" | "true" | "yes" | "on" + ) + }) + .unwrap_or(false) +} fn bytes_to_hex(bytes: &[u8]) -> String { const HEX: &[u8; 16] = b"0123456789abcdef"; @@ -195,11 +304,15 @@ pub struct Style { struct UnicodeProcessor { indexer: Vec, + unknown_token_id: i64, } impl UnicodeProcessor { - fn new(indexer: Vec) -> Self { - UnicodeProcessor { indexer } + fn new(indexer: Vec, unknown_token_id: i64) -> Self { + UnicodeProcessor { + indexer, + unknown_token_id, + } } fn call(&self, text_list: &[String]) -> (Vec>, Array3) { @@ -215,8 +328,7 @@ impl UnicodeProcessor { if val < self.indexer.len() { row[j] = self.indexer[val]; } else { - // Use 0 (padding token) for out-of-vocabulary characters - row[j] = 0; + row[j] = self.unknown_token_id; } } text_ids.push(row); @@ -227,7 +339,11 @@ impl UnicodeProcessor { } } -fn preprocess_text(text: &str) -> String { +fn is_valid_lang(lang: &str) -> bool { + AVAILABLE_LANGS.contains(&lang) +} + +fn normalize_text_for_tts(text: &str) -> String { let mut text: String = text.nfkd().collect(); // Remove markdown formatting (using pre-compiled regexes) @@ -239,6 +355,7 @@ fn preprocess_text(text: &str) -> String { text = RE_CODE.replace_all(&text, "$1").to_string(); text = RE_CODEBLOCK.replace_all(&text, "").to_string(); text = RE_HEADER.replace_all(&text, "").to_string(); + text = RE_XML_TAG.replace_all(&text, " ").to_string(); text = RE_EMOJI.replace_all(&text, "").to_string(); // Replace various dashes and symbols @@ -302,6 +419,19 @@ fn preprocess_text(text: &str) -> String { text } +fn preprocess_text(text: &str, lang: &str) -> Result { + let text = normalize_text_for_tts(text); + if text.is_empty() { + return Ok(text); + } + + if !is_valid_lang(lang) { + bail!("Invalid TTS language: {lang}. Available: {AVAILABLE_LANGS:?}"); + } + + Ok(format!("<{lang}>{text}")) +} + fn length_to_mask(lengths: &[usize], max_len: Option) -> Array3 { let bsz = lengths.len(); let max_len = max_len.unwrap_or_else(|| *lengths.iter().max().unwrap_or(&0)); @@ -391,75 +521,65 @@ fn chunk_text(text: &str, max_len: usize) -> Vec { } static RE_PARA: Lazy = Lazy::new(|| Regex::new(r"\n\s*\n").unwrap()); - let paragraphs: Vec<&str> = RE_PARA.split(text).collect(); let mut chunks = Vec::new(); + let mut current = String::new(); + + fn push_unit(chunks: &mut Vec, current: &mut String, unit: &str, max_len: usize) { + let unit = unit.trim(); + if unit.is_empty() { + return; + } - for para in paragraphs { + if unit.len() > max_len { + for part in split_by_words(unit, max_len) { + push_unit(chunks, current, &part, max_len); + } + return; + } + + if current.is_empty() { + current.push_str(unit); + return; + } + + if current.len() + 1 + unit.len() <= max_len { + current.push(' '); + current.push_str(unit); + } else { + chunks.push(current.trim().to_string()); + current.clear(); + current.push_str(unit); + } + } + + for para in RE_PARA.split(text) { let para = para.trim(); if para.is_empty() { continue; } if para.len() <= max_len { - chunks.push(para.to_string()); + push_unit(&mut chunks, &mut current, para, max_len); continue; } // Split by sentence boundaries, keeping punctuation - let mut current = String::new(); let mut last_end = 0; for m in RE_SENTENCE.find_iter(para) { let sentence = para[last_end..m.start() + 1].trim(); // +1 to include punctuation last_end = m.end(); - if sentence.is_empty() { - continue; - } - - // If single sentence exceeds max_len, split by words - if sentence.len() > max_len { - if !current.is_empty() { - chunks.push(current.trim().to_string()); - current.clear(); - } - chunks.extend(split_by_words(sentence, max_len)); - continue; - } - - if current.len() + sentence.len() + 1 > max_len && !current.is_empty() { - chunks.push(current.trim().to_string()); - current.clear(); - } - - if !current.is_empty() { - current.push(' '); - } - current.push_str(sentence); + push_unit(&mut chunks, &mut current, sentence, max_len); } // Remaining text after last sentence boundary let remaining = para[last_end..].trim(); - if !remaining.is_empty() { - // If remaining exceeds max_len, split by words - if remaining.len() > max_len { - if !current.is_empty() { - chunks.push(current.trim().to_string()); - } - chunks.extend(split_by_words(remaining, max_len)); - } else if current.len() + remaining.len() + 1 > max_len && !current.is_empty() { - chunks.push(current.trim().to_string()); - chunks.push(remaining.to_string()); - } else { - if !current.is_empty() { - current.push(' '); - } - current.push_str(remaining); - chunks.push(current.trim().to_string()); - } - } else if !current.is_empty() { - chunks.push(current.trim().to_string()); - } + push_unit(&mut chunks, &mut current, remaining, max_len); + } + + if !current.is_empty() { + chunks.push(current.trim().to_string()); } if chunks.is_empty() { @@ -472,6 +592,7 @@ fn chunk_text(text: &str, max_len: usize) -> Vec { pub struct TTSState { tts: Option, style: Option