From 98ecaa883fa4f1430c6f405451fcd955d4f82b17 Mon Sep 17 00:00:00 2001 From: Carina Peng Date: Fri, 12 Jun 2026 18:14:22 -0700 Subject: [PATCH 1/7] Speech runner --- Package.swift | 11 + .../speech-runner/SpeechRunnerMain.swift | 256 ++++++++++++++++++ 2 files changed, 267 insertions(+) create mode 100644 swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift diff --git a/Package.swift b/Package.swift index 3a23d4c..3ba2f25 100644 --- a/Package.swift +++ b/Package.swift @@ -155,6 +155,17 @@ let package = Package( .enableUpcomingFeature("MemberImportVisibility") ] ), + .executableTarget( + name: "speech-runner", + dependencies: [ + "CoreAIShared", + .product(name: "Transformers", package: "swift-transformers"), + ], + path: "swift/Sources/Tools/speech-runner", + swiftSettings: [ + .enableUpcomingFeature("MemberImportVisibility") + ] + ), // Public LLM Benchmark CLI (based on mlx-lm benchmark) .executableTarget( diff --git a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift new file mode 100644 index 0000000..cb3123a --- /dev/null +++ b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift @@ -0,0 +1,256 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import CoreAI +import CoreAIShared +import Foundation +import Tokenizers + +// Whisper forced prefix: <|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> +private let forcedPrefix: [Int32] = [50258, 50259, 50360, 50364] +private let eotToken: Int32 = 50257 +private let maxTargetPositions = 448 +private let maxDecodeSteps = 50 +private let melElements = 128 * 3000 + +// MARK: - Entry point + +// Usage: speech-runner [audio-or-mel] +// +// model-path A bundle dir with encoder.aimodel + decoder.aimodel (--mode coreai), +// or a single .aimodel file (--mode legacy). +// +// audio-or-mel An audio file (wav, flac, m4a, …) or a precomputed mel .bin +// from tools/compute_mel.py. Omit for silence benchmarking. +@main +struct Main { + static func main() async { + guard CommandLine.arguments.count > 1 else { + print("Usage: speech-runner [audio-or-mel]") + exit(1) + } + let modelPath = CommandLine.arguments[1] + let audioPath = CommandLine.arguments.count > 2 ? CommandLine.arguments[2] : nil + do { + let encURL = URL(fileURLWithPath: "\(modelPath)/encoder.aimodel") + if FileManager.default.fileExists(atPath: encURL.path) { + try await runSplit(bundleDir: modelPath, audioPath: audioPath) + } else { + try await runLegacy(modelPath: modelPath, audioPath: audioPath) + } + } catch { + print("Fatal: \(error)") + exit(1) + } + } +} + +// MARK: - Mel loading + +private let audioExtensions: Set = ["wav", "flac", "m4a", "mp3", "aiff", "aif", "caf"] + +private func loadMelArray(from path: String, descriptor: NDArrayDescriptor) throws -> NDArray { + let url = URL(fileURLWithPath: path) + let floats: [Float] + if audioExtensions.contains(url.pathExtension.lowercased()) { + print("Computing mel from audio file…") + floats = try WhisperMel.fromFile(url) + } else { + let data = try Data(contentsOf: url) + let count = data.count / MemoryLayout.size + guard count == melElements else { + fatalError("mel bin has \(count) floats, expected \(melElements) (128×3000)") + } + floats = data.withUnsafeBytes { Array($0.bindMemory(to: Float.self)) } + } + var array = NDArray(descriptor: descriptor.resolvingDynamicDimensions([1, 128, 3000])) + fillNDArray(&array, as: Float.self, with: floats) + return array +} + +// MARK: - Results + +private func printResults(tokens: [Int32], stepTimesMs: [Double]) async { + let avgMs = stepTimesMs.reduce(0, +) / Double(stepTimesMs.count) + print(String(format: " steps: %d", stepTimesMs.count)) + print(String(format: " latency: %.1f ms/tok", avgMs)) + print(String(format: " speed: %.1f tok/s", 1000.0 / avgMs)) + if let lo = stepTimesMs.min(), let hi = stepTimesMs.max() { + print(String(format: " min/max: %.1f / %.1f ms", lo, hi)) + } + print("\n── Transcription ──────────────────────────────────────────────────────") + if let tokenizer = try? await AutoTokenizer.from(pretrained: "openai/whisper-large-v3-turbo") { + let ids = tokens.filter { $0 < 50257 }.map { Int($0) } + print(" \(tokenizer.decode(tokens: ids))") + } else { + print(" (tokenizer unavailable — token ids: \(tokens))") + } +} + +// MARK: - Split runner (encoder + decoder with KV cache) + +func runSplit(bundleDir: String, audioPath: String?) async throws { + print("Format: split (encoder + decoder, KV cache)") + + let encModel = try await AIModel(contentsOf: URL(fileURLWithPath: "\(bundleDir)/encoder.aimodel")) + let decModel = try await AIModel(contentsOf: URL(fileURLWithPath: "\(bundleDir)/decoder.aimodel")) + + guard let encFn = try encModel.loadFunction(named: "main"), + let decFn = try decModel.loadFunction(named: "main") + else { fatalError("No 'main' function") } + + let encDesc = encModel.functionDescriptor(for: "main")! + let decDesc = decModel.functionDescriptor(for: "main")! + + guard case .ndArray(let melNDDesc) = encDesc.inputDescriptor(of: "input_features"), + case .ndArray(let encOutNDDesc) = encDesc.outputDescriptor(of: "encoder_hidden_states") + else { fatalError("Unexpected encoder descriptors") } + + let encOutShape = encOutNDDesc.shape + + var melArray: NDArray + if let path = audioPath { + melArray = try loadMelArray(from: path, descriptor: melNDDesc) + } else { + print("No audio — using silence for benchmarking") + melArray = NDArray(descriptor: melNDDesc.resolvingDynamicDimensions([1, 128, 3000])) + fillNDArray(&melArray, as: Float.self, count: melElements) { _ in 0.0 } + } + var encOutArray = NDArray(descriptor: encOutNDDesc.resolvingDynamicDimensions(encOutShape)) + + // Warmup + do { + var out = InferenceFunction.MutableViews() + out.insert(&encOutArray, for: "encoder_hidden_states") + _ = try await encFn.run(inputs: ["input_features": melArray], + states: InferenceFunction.MutableViews(), outputViews: consume out) + } + print("\n── Encoder ────────────────────────────────────────────────────────────") + let encT0 = Date() + do { + var out = InferenceFunction.MutableViews() + out.insert(&encOutArray, for: "encoder_hidden_states") + _ = try await encFn.run(inputs: ["input_features": melArray], + states: InferenceFunction.MutableViews(), outputViews: consume out) + } + print(String(format: " latency: %.1f ms", Date().timeIntervalSince(encT0) * 1000)) + + guard case .ndArray(let inputIdsNDDesc) = decDesc.inputDescriptor(of: "input_ids"), + case .ndArray(let posIdsNDDesc) = decDesc.inputDescriptor(of: "position_ids"), + case .ndArray(let encHSNDDesc) = decDesc.inputDescriptor(of: "encoder_hidden_states"), + case .ndArray(let keyCacheNDDesc) = decDesc.stateDescriptor(of: "keyCache"), + case .ndArray(let valCacheNDDesc) = decDesc.stateDescriptor(of: "valueCache"), + case .ndArray(let logitsNDDesc) = decDesc.outputDescriptor(of: "logits") + else { fatalError("Unexpected decoder descriptors") } + + let vocabSize = logitsNDDesc.shape.last! + let kcShape = keyCacheNDDesc.shape.map { $0 < 0 ? maxTargetPositions : $0 } + let vcShape = valCacheNDDesc.shape.map { $0 < 0 ? maxTargetPositions : $0 } + var keyCache = NDArray(descriptor: keyCacheNDDesc.resolvingDynamicDimensions(kcShape)) + var valueCache = NDArray(descriptor: valCacheNDDesc.resolvingDynamicDimensions(vcShape)) + + let encFlat = readNDArray(encOutArray, as: Float.self, count: encOutShape.reduce(1, *)) + var encHSArray = NDArray(descriptor: encHSNDDesc.resolvingDynamicDimensions(encOutShape)) + fillNDArray(&encHSArray, as: Float.self, with: encFlat) + var logitsArray = NDArray(descriptor: logitsNDDesc.resolvingDynamicDimensions([1, 1, vocabSize])) + + print("\n── Decoder ────────────────────────────────────────────────────────────") + + var tokens: [Int32] = forcedPrefix + var pos = 0 + for tok in forcedPrefix { + var ids = NDArray(descriptor: inputIdsNDDesc.resolvingDynamicDimensions([1, 1])) + var posIds = NDArray(descriptor: posIdsNDDesc.resolvingDynamicDimensions([1, pos + 1])) + fillNDArray(&ids, as: Int32.self, with: [tok]) + fillNDArray(&posIds, as: Int32.self, count: pos + 1) { Int32($0) } + var st = InferenceFunction.MutableViews() + st.insert(&keyCache, for: "keyCache"); st.insert(&valueCache, for: "valueCache") + var out = InferenceFunction.MutableViews(); out.insert(&logitsArray, for: "logits") + _ = try await decFn.run( + inputs: ["input_ids": ids, "position_ids": posIds, "encoder_hidden_states": encHSArray], + states: consume st, outputViews: consume out) + pos += 1 + } + + var stepTimesMs: [Double] = [] + while stepTimesMs.count < maxDecodeSteps { + var ids = NDArray(descriptor: inputIdsNDDesc.resolvingDynamicDimensions([1, 1])) + var posIds = NDArray(descriptor: posIdsNDDesc.resolvingDynamicDimensions([1, pos + 1])) + fillNDArray(&ids, as: Int32.self, with: [tokens.last!]) + fillNDArray(&posIds, as: Int32.self, count: pos + 1) { Int32($0) } + var st = InferenceFunction.MutableViews() + st.insert(&keyCache, for: "keyCache"); st.insert(&valueCache, for: "valueCache") + var out = InferenceFunction.MutableViews(); out.insert(&logitsArray, for: "logits") + let t0 = Date() + _ = try await decFn.run( + inputs: ["input_ids": ids, "position_ids": posIds, "encoder_hidden_states": encHSArray], + states: consume st, outputViews: consume out) + stepTimesMs.append(Date().timeIntervalSince(t0) * 1000) + let logits = flattenAsFloat(logitsArray) + let next = Int32(logits.indices.max(by: { logits[$0] < logits[$1] })!) + tokens.append(next); pos += 1 + if next == eotToken { break } + } + + await printResults(tokens: tokens, stepTimesMs: stepTimesMs) +} + +// MARK: - Legacy runner (monolithic model, no KV cache) + +func runLegacy(modelPath: String, audioPath: String?) async throws { + print("Format: legacy (monolithic, no KV cache)") + + let model = try await AIModel(contentsOf: URL(fileURLWithPath: modelPath)) + guard let fn = try model.loadFunction(named: "main") else { fatalError("No 'main' function") } + let desc = model.functionDescriptor(for: "main")! + + guard case .ndArray(let melNDDesc) = desc.inputDescriptor(of: "input_features"), + case .ndArray(let idsNDDesc) = desc.inputDescriptor(of: "decoder_input_ids"), + case .ndArray(let logitsDesc) = desc.outputDescriptor(of: "logits") + else { fatalError("Unexpected model descriptors") } + + let vocabSize = logitsDesc.shape.last! + let isStaticIds = !idsNDDesc.shape.contains(where: { $0 < 0 }) + if isStaticIds { + print("decoder_input_ids exported with static shape — no past context per step") + print("Re-export with --mode legacy to get dynamic shapes") + } + + var melArray: NDArray + if let path = audioPath { + melArray = try loadMelArray(from: path, descriptor: melNDDesc) + } else { + print("No audio — using silence for benchmarking") + melArray = NDArray(descriptor: melNDDesc.resolvingDynamicDimensions([1, 128, 3000])) + fillNDArray(&melArray, as: Float.self, count: melElements) { _ in 0.0 } + } + + print("\n── Decode ─────────────────────────────────────────────────────────────") + + var tokens: [Int32] = forcedPrefix + var stepTimesMs: [Double] = [] + + while stepTimesMs.count < maxDecodeSteps { + let inputTokens: [Int32] = isStaticIds ? [tokens.last!] : tokens + let seqLen = inputTokens.count + var ids = NDArray(descriptor: idsNDDesc.resolvingDynamicDimensions([1, seqLen])) + fillNDArray(&ids, as: Int32.self, with: inputTokens) + var logitsArray = NDArray(descriptor: logitsDesc.resolvingDynamicDimensions([1, seqLen, vocabSize])) + var out = InferenceFunction.MutableViews(); out.insert(&logitsArray, for: "logits") + let t0 = Date() + _ = try await fn.run( + inputs: ["input_features": melArray, "decoder_input_ids": ids], + states: InferenceFunction.MutableViews(), outputViews: consume out) + stepTimesMs.append(Date().timeIntervalSince(t0) * 1000) + let logits = flattenAsFloat(logitsArray) + let lastStart = (seqLen - 1) * vocabSize + let lastLogits = Array(logits[lastStart ..< lastStart + vocabSize]) + let next = Int32(lastLogits.indices.max(by: { lastLogits[$0] < lastLogits[$1] })!) + tokens.append(next) + if next == eotToken { break } + } + + await printResults(tokens: tokens, stepTimesMs: stepTimesMs) +} From 08cf8f9ecdf42860773d9aecac320261a150b26c Mon Sep 17 00:00:00 2001 From: Carina Peng Date: Fri, 12 Jun 2026 23:06:35 -0700 Subject: [PATCH 2/7] Mel compute in Swift --- .../speech-runner/SpeechRunnerMain.swift | 9 +- .../Tools/speech-runner/WhisperMel.swift | 182 ++++++++++++++++++ 2 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 swift/Sources/Tools/speech-runner/WhisperMel.swift diff --git a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift index cb3123a..d5bcebe 100644 --- a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift +++ b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift @@ -81,7 +81,14 @@ private func printResults(tokens: [Int32], stepTimesMs: [Double]) async { print(String(format: " min/max: %.1f / %.1f ms", lo, hi)) } print("\n── Transcription ──────────────────────────────────────────────────────") - if let tokenizer = try? await AutoTokenizer.from(pretrained: "openai/whisper-large-v3-turbo") { + // Load tokenizer from local HF cache (no network needed) + let cacheBase = FileManager.default.homeDirectoryForCurrentUser + .appending(path: ".cache/huggingface/hub/models--openai--whisper-large-v3-turbo/snapshots") + let snapshot = (try? FileManager.default.contentsOfDirectory(atPath: cacheBase.path))?.first + let tokenizerURL = snapshot.map { cacheBase.appending(path: $0) } + + if let url = tokenizerURL, + let tokenizer = try? await AutoTokenizer.from(modelFolder: url) { let ids = tokens.filter { $0 < 50257 }.map { Int($0) } print(" \(tokenizer.decode(tokens: ids))") } else { diff --git a/swift/Sources/Tools/speech-runner/WhisperMel.swift b/swift/Sources/Tools/speech-runner/WhisperMel.swift new file mode 100644 index 0000000..a4c9ada --- /dev/null +++ b/swift/Sources/Tools/speech-runner/WhisperMel.swift @@ -0,0 +1,182 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import Accelerate +import AVFoundation +import Foundation + +// Whisper mel spectrogram: sr=16000, n_fft=400, hop=160, n_mels=128 +// Slaney-normalised filterbank, reflect-padded audio, matches WhisperFeatureExtractor. +// +// vDSP DFT only supports f×2^n sizes (f ∈ {1,3,5,15}); 400=5²×2⁴ doesn't qualify. +// We precompute 201×400 DFT basis matrices and apply them with cblas_sgemv instead. + +enum WhisperMel { + + static let sampleRate: Double = 16_000 + static let nFFT = 400 // analysis window (samples) + static let hopLength = 160 + static let nMelBins = 128 + static let nFrames = 3_000 + static let nSamples = 480_000 + + private static let nFreqs = nFFT / 2 + 1 // 201 + + // MARK: - Public + + static func fromFile(_ url: URL) throws -> [Float] { + return fromPCM(try loadAndResample(url)) + } + + // MARK: - Audio loading + resampling + + static func loadAndResample(_ url: URL) throws -> [Float] { + let file = try AVAudioFile(forReading: url) + let fmt = AVAudioFormat(commonFormat: .pcmFormatFloat32, + sampleRate: sampleRate, channels: 1, interleaved: false)! + guard let conv = AVAudioConverter(from: file.processingFormat, to: fmt) else { + throw NSError(domain: "WhisperMel", code: 1, + userInfo: [NSLocalizedDescriptionKey: + "Cannot convert \(file.processingFormat) → 16 kHz mono"]) + } + let cap = AVAudioFrameCount( + ceil(Double(file.length) * sampleRate / file.processingFormat.sampleRate) + 1) + let out = AVAudioPCMBuffer(pcmFormat: fmt, frameCapacity: cap)! + var fed = false; var err: NSError? + conv.convert(to: out, error: &err) { _, status in + guard !fed else { status.pointee = .endOfStream; return nil } + fed = true + let buf = AVAudioPCMBuffer(pcmFormat: file.processingFormat, + frameCapacity: AVAudioFrameCount(file.length))! + try? file.read(into: buf) + status.pointee = buf.frameLength > 0 ? .haveData : .endOfStream + return buf + } + if let e = err { throw e } + return Array(UnsafeBufferPointer(start: out.floatChannelData![0], + count: Int(out.frameLength))) + } + + // MARK: - Precomputed DFT basis (201 × 400) + // cos_basis[k, n] = cos(2π k n / 400) → Y[k].real = cos_basis @ x + // sin_basis[k, n] = -sin(2π k n / 400) → Y[k].imag = sin_basis @ x + + static let cosBasis: [Float] = { + var m = [Float](repeating: 0, count: (nFFT / 2 + 1) * nFFT) + for k in 0...nFFT / 2 { + for n in 0.. [Float] { + // 1. Trim / zero-pad to nSamples + var audio = raw + if audio.count > nSamples { audio = Array(audio.prefix(nSamples)) } + else if audio.count < nSamples { + audio += [Float](repeating: 0, count: nSamples - audio.count) + } + + // 2. Reflect-pad by nFFT/2 (matches np.pad(..., mode='reflect')) + let pad = nFFT / 2 // 200 + var padded = [Float](repeating: 0, count: nSamples + 2 * pad) + for i in 0.. [Float] { + let fMax: Float = Float(sampleRate) / 2 // 8000 Hz + + func hzToMel(_ f: Float) -> Float { 2595 * log10(1 + f / 700) } + func melToHz(_ m: Float) -> Float { 700 * (pow(10, m / 2595) - 1) } + + let melMin = hzToMel(0), melMax = hzToMel(fMax) + let nPts = nMelBins + 2 + let pts = (0.. Float in + melToHz(melMin + Float(i) / Float(nPts - 1) * (melMax - melMin)) + } + // FFT bin frequencies for n_fft = 400 + let fftFreqs = (0..= fL && f <= fC { fb[m * nFreqs + k] = norm * (f - fL) / (fC - fL) } + else if f > fC && f <= fR { fb[m * nFreqs + k] = norm * (fR - f) / (fR - fC) } + } + } + return fb + } +} From dab75745be1e224150f3291ee2841126bb88e56b Mon Sep 17 00:00:00 2001 From: Carina Peng Date: Wed, 17 Jun 2026 12:51:46 -0700 Subject: [PATCH 3/7] Swift formatting --- .../speech-runner/SpeechRunnerMain.swift | 57 ++++++---- .../Tools/speech-runner/WhisperMel.swift | 107 +++++++++++------- 2 files changed, 97 insertions(+), 67 deletions(-) diff --git a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift index d5bcebe..127ea1e 100644 --- a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift +++ b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift @@ -88,7 +88,8 @@ private func printResults(tokens: [Int32], stepTimesMs: [Double]) async { let tokenizerURL = snapshot.map { cacheBase.appending(path: $0) } if let url = tokenizerURL, - let tokenizer = try? await AutoTokenizer.from(modelFolder: url) { + let tokenizer = try? await AutoTokenizer.from(modelFolder: url) + { let ids = tokens.filter { $0 < 50257 }.map { Int($0) } print(" \(tokenizer.decode(tokens: ids))") } else { @@ -105,14 +106,14 @@ func runSplit(bundleDir: String, audioPath: String?) async throws { let decModel = try await AIModel(contentsOf: URL(fileURLWithPath: "\(bundleDir)/decoder.aimodel")) guard let encFn = try encModel.loadFunction(named: "main"), - let decFn = try decModel.loadFunction(named: "main") + let decFn = try decModel.loadFunction(named: "main") else { fatalError("No 'main' function") } let encDesc = encModel.functionDescriptor(for: "main")! let decDesc = decModel.functionDescriptor(for: "main")! - guard case .ndArray(let melNDDesc) = encDesc.inputDescriptor(of: "input_features"), - case .ndArray(let encOutNDDesc) = encDesc.outputDescriptor(of: "encoder_hidden_states") + guard case .ndArray(let melNDDesc) = encDesc.inputDescriptor(of: "input_features"), + case .ndArray(let encOutNDDesc) = encDesc.outputDescriptor(of: "encoder_hidden_states") else { fatalError("Unexpected encoder descriptors") } let encOutShape = encOutNDDesc.shape @@ -131,31 +132,33 @@ func runSplit(bundleDir: String, audioPath: String?) async throws { do { var out = InferenceFunction.MutableViews() out.insert(&encOutArray, for: "encoder_hidden_states") - _ = try await encFn.run(inputs: ["input_features": melArray], - states: InferenceFunction.MutableViews(), outputViews: consume out) + _ = try await encFn.run( + inputs: ["input_features": melArray], + states: InferenceFunction.MutableViews(), outputViews: consume out) } print("\n── Encoder ────────────────────────────────────────────────────────────") let encT0 = Date() do { var out = InferenceFunction.MutableViews() out.insert(&encOutArray, for: "encoder_hidden_states") - _ = try await encFn.run(inputs: ["input_features": melArray], - states: InferenceFunction.MutableViews(), outputViews: consume out) + _ = try await encFn.run( + inputs: ["input_features": melArray], + states: InferenceFunction.MutableViews(), outputViews: consume out) } print(String(format: " latency: %.1f ms", Date().timeIntervalSince(encT0) * 1000)) guard case .ndArray(let inputIdsNDDesc) = decDesc.inputDescriptor(of: "input_ids"), - case .ndArray(let posIdsNDDesc) = decDesc.inputDescriptor(of: "position_ids"), - case .ndArray(let encHSNDDesc) = decDesc.inputDescriptor(of: "encoder_hidden_states"), - case .ndArray(let keyCacheNDDesc) = decDesc.stateDescriptor(of: "keyCache"), - case .ndArray(let valCacheNDDesc) = decDesc.stateDescriptor(of: "valueCache"), - case .ndArray(let logitsNDDesc) = decDesc.outputDescriptor(of: "logits") + case .ndArray(let posIdsNDDesc) = decDesc.inputDescriptor(of: "position_ids"), + case .ndArray(let encHSNDDesc) = decDesc.inputDescriptor(of: "encoder_hidden_states"), + case .ndArray(let keyCacheNDDesc) = decDesc.stateDescriptor(of: "keyCache"), + case .ndArray(let valCacheNDDesc) = decDesc.stateDescriptor(of: "valueCache"), + case .ndArray(let logitsNDDesc) = decDesc.outputDescriptor(of: "logits") else { fatalError("Unexpected decoder descriptors") } let vocabSize = logitsNDDesc.shape.last! let kcShape = keyCacheNDDesc.shape.map { $0 < 0 ? maxTargetPositions : $0 } let vcShape = valCacheNDDesc.shape.map { $0 < 0 ? maxTargetPositions : $0 } - var keyCache = NDArray(descriptor: keyCacheNDDesc.resolvingDynamicDimensions(kcShape)) + var keyCache = NDArray(descriptor: keyCacheNDDesc.resolvingDynamicDimensions(kcShape)) var valueCache = NDArray(descriptor: valCacheNDDesc.resolvingDynamicDimensions(vcShape)) let encFlat = readNDArray(encOutArray, as: Float.self, count: encOutShape.reduce(1, *)) @@ -173,8 +176,10 @@ func runSplit(bundleDir: String, audioPath: String?) async throws { fillNDArray(&ids, as: Int32.self, with: [tok]) fillNDArray(&posIds, as: Int32.self, count: pos + 1) { Int32($0) } var st = InferenceFunction.MutableViews() - st.insert(&keyCache, for: "keyCache"); st.insert(&valueCache, for: "valueCache") - var out = InferenceFunction.MutableViews(); out.insert(&logitsArray, for: "logits") + st.insert(&keyCache, for: "keyCache") + st.insert(&valueCache, for: "valueCache") + var out = InferenceFunction.MutableViews() + out.insert(&logitsArray, for: "logits") _ = try await decFn.run( inputs: ["input_ids": ids, "position_ids": posIds, "encoder_hidden_states": encHSArray], states: consume st, outputViews: consume out) @@ -188,8 +193,10 @@ func runSplit(bundleDir: String, audioPath: String?) async throws { fillNDArray(&ids, as: Int32.self, with: [tokens.last!]) fillNDArray(&posIds, as: Int32.self, count: pos + 1) { Int32($0) } var st = InferenceFunction.MutableViews() - st.insert(&keyCache, for: "keyCache"); st.insert(&valueCache, for: "valueCache") - var out = InferenceFunction.MutableViews(); out.insert(&logitsArray, for: "logits") + st.insert(&keyCache, for: "keyCache") + st.insert(&valueCache, for: "valueCache") + var out = InferenceFunction.MutableViews() + out.insert(&logitsArray, for: "logits") let t0 = Date() _ = try await decFn.run( inputs: ["input_ids": ids, "position_ids": posIds, "encoder_hidden_states": encHSArray], @@ -197,7 +204,8 @@ func runSplit(bundleDir: String, audioPath: String?) async throws { stepTimesMs.append(Date().timeIntervalSince(t0) * 1000) let logits = flattenAsFloat(logitsArray) let next = Int32(logits.indices.max(by: { logits[$0] < logits[$1] })!) - tokens.append(next); pos += 1 + tokens.append(next) + pos += 1 if next == eotToken { break } } @@ -213,9 +221,9 @@ func runLegacy(modelPath: String, audioPath: String?) async throws { guard let fn = try model.loadFunction(named: "main") else { fatalError("No 'main' function") } let desc = model.functionDescriptor(for: "main")! - guard case .ndArray(let melNDDesc) = desc.inputDescriptor(of: "input_features"), - case .ndArray(let idsNDDesc) = desc.inputDescriptor(of: "decoder_input_ids"), - case .ndArray(let logitsDesc) = desc.outputDescriptor(of: "logits") + guard case .ndArray(let melNDDesc) = desc.inputDescriptor(of: "input_features"), + case .ndArray(let idsNDDesc) = desc.inputDescriptor(of: "decoder_input_ids"), + case .ndArray(let logitsDesc) = desc.outputDescriptor(of: "logits") else { fatalError("Unexpected model descriptors") } let vocabSize = logitsDesc.shape.last! @@ -245,7 +253,8 @@ func runLegacy(modelPath: String, audioPath: String?) async throws { var ids = NDArray(descriptor: idsNDDesc.resolvingDynamicDimensions([1, seqLen])) fillNDArray(&ids, as: Int32.self, with: inputTokens) var logitsArray = NDArray(descriptor: logitsDesc.resolvingDynamicDimensions([1, seqLen, vocabSize])) - var out = InferenceFunction.MutableViews(); out.insert(&logitsArray, for: "logits") + var out = InferenceFunction.MutableViews() + out.insert(&logitsArray, for: "logits") let t0 = Date() _ = try await fn.run( inputs: ["input_features": melArray, "decoder_input_ids": ids], @@ -253,7 +262,7 @@ func runLegacy(modelPath: String, audioPath: String?) async throws { stepTimesMs.append(Date().timeIntervalSince(t0) * 1000) let logits = flattenAsFloat(logitsArray) let lastStart = (seqLen - 1) * vocabSize - let lastLogits = Array(logits[lastStart ..< lastStart + vocabSize]) + let lastLogits = Array(logits[lastStart.. [Float] { let file = try AVAudioFile(forReading: url) - let fmt = AVAudioFormat(commonFormat: .pcmFormatFloat32, - sampleRate: sampleRate, channels: 1, interleaved: false)! + let fmt = AVAudioFormat( + commonFormat: .pcmFormatFloat32, + sampleRate: sampleRate, channels: 1, interleaved: false)! guard let conv = AVAudioConverter(from: file.processingFormat, to: fmt) else { - throw NSError(domain: "WhisperMel", code: 1, - userInfo: [NSLocalizedDescriptionKey: - "Cannot convert \(file.processingFormat) → 16 kHz mono"]) + throw NSError( + domain: "WhisperMel", code: 1, + userInfo: [ + NSLocalizedDescriptionKey: + "Cannot convert \(file.processingFormat) → 16 kHz mono" + ]) } let cap = AVAudioFrameCount( ceil(Double(file.length) * sampleRate / file.processingFormat.sampleRate) + 1) let out = AVAudioPCMBuffer(pcmFormat: fmt, frameCapacity: cap)! - var fed = false; var err: NSError? + var fed = false + var err: NSError? conv.convert(to: out, error: &err) { _, status in - guard !fed else { status.pointee = .endOfStream; return nil } + guard !fed else { + status.pointee = .endOfStream + return nil + } fed = true - let buf = AVAudioPCMBuffer(pcmFormat: file.processingFormat, - frameCapacity: AVAudioFrameCount(file.length))! + let buf = AVAudioPCMBuffer( + pcmFormat: file.processingFormat, + frameCapacity: AVAudioFrameCount(file.length))! try? file.read(into: buf) status.pointee = buf.frameLength > 0 ? .haveData : .endOfStream return buf } if let e = err { throw e } - return Array(UnsafeBufferPointer(start: out.floatChannelData![0], - count: Int(out.frameLength))) + return Array( + UnsafeBufferPointer( + start: out.floatChannelData![0], + count: Int(out.frameLength))) } // MARK: - Precomputed DFT basis (201 × 400) @@ -92,17 +102,18 @@ enum WhisperMel { static func fromPCM(_ raw: [Float]) -> [Float] { // 1. Trim / zero-pad to nSamples var audio = raw - if audio.count > nSamples { audio = Array(audio.prefix(nSamples)) } - else if audio.count < nSamples { + if audio.count > nSamples { + audio = Array(audio.prefix(nSamples)) + } else if audio.count < nSamples { audio += [Float](repeating: 0, count: nSamples - audio.count) } // 2. Reflect-pad by nFFT/2 (matches np.pad(..., mode='reflect')) let pad = nFFT / 2 // 200 var padded = [Float](repeating: 0, count: nSamples + 2 * pad) - for i in 0.. Float { 2595 * log10(1 + f / 700) } func melToHz(_ m: Float) -> Float { 700 * (pow(10, m / 2595) - 1) } - let melMin = hzToMel(0), melMax = hzToMel(fMax) + let melMin = hzToMel(0) + let melMax = hzToMel(fMax) let nPts = nMelBins + 2 let pts = (0.. Float in melToHz(melMin + Float(i) / Float(nPts - 1) * (melMax - melMin)) @@ -169,12 +185,17 @@ enum WhisperMel { var fb = [Float](repeating: 0, count: nMelBins * nFreqs) for m in 0..= fL && f <= fC { fb[m * nFreqs + k] = norm * (f - fL) / (fC - fL) } - else if f > fC && f <= fR { fb[m * nFreqs + k] = norm * (fR - f) / (fR - fC) } + if f >= fL && f <= fC { + fb[m * nFreqs + k] = norm * (f - fL) / (fC - fL) + } else if f > fC && f <= fR { + fb[m * nFreqs + k] = norm * (fR - f) / (fR - fC) + } } } return fb From eeb6992a5df9b2442d69b5839e10902b9aeeb17c Mon Sep 17 00:00:00 2001 From: Carina Peng Date: Thu, 18 Jun 2026 13:41:02 -0700 Subject: [PATCH 4/7] Updates --- Package.swift | 1 + .../speech-runner/SpeechRunnerMain.swift | 118 ++++++++++-------- 2 files changed, 67 insertions(+), 52 deletions(-) diff --git a/Package.swift b/Package.swift index 3ba2f25..2449058 100644 --- a/Package.swift +++ b/Package.swift @@ -159,6 +159,7 @@ let package = Package( name: "speech-runner", dependencies: [ "CoreAIShared", + .product(name: "ArgumentParser", package: "swift-argument-parser"), .product(name: "Transformers", package: "swift-transformers"), ], path: "swift/Sources/Tools/speech-runner", diff --git a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift index 127ea1e..f4c8e2c 100644 --- a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift +++ b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift @@ -3,6 +3,7 @@ // Use of this source code is governed by a BSD-3-clause license that can // be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause +import ArgumentParser import CoreAI import CoreAIShared import Foundation @@ -17,32 +18,25 @@ private let melElements = 128 * 3000 // MARK: - Entry point -// Usage: speech-runner [audio-or-mel] -// -// model-path A bundle dir with encoder.aimodel + decoder.aimodel (--mode coreai), -// or a single .aimodel file (--mode legacy). -// -// audio-or-mel An audio file (wav, flac, m4a, …) or a precomputed mel .bin -// from tools/compute_mel.py. Omit for silence benchmarking. @main -struct Main { - static func main() async { - guard CommandLine.arguments.count > 1 else { - print("Usage: speech-runner [audio-or-mel]") - exit(1) - } - let modelPath = CommandLine.arguments[1] - let audioPath = CommandLine.arguments.count > 2 ? CommandLine.arguments[2] : nil - do { - let encURL = URL(fileURLWithPath: "\(modelPath)/encoder.aimodel") - if FileManager.default.fileExists(atPath: encURL.path) { - try await runSplit(bundleDir: modelPath, audioPath: audioPath) - } else { - try await runLegacy(modelPath: modelPath, audioPath: audioPath) - } - } catch { - print("Fatal: \(error)") - exit(1) +struct SpeechRunner: AsyncParsableCommand { + static let configuration = CommandConfiguration( + commandName: "speech-runner", + abstract: "Transcribe audio using a CoreAI Whisper export" + ) + + @Argument(help: "Path to encoder+decoder bundle dir (--mode coreai) or single .aimodel (--mode legacy)") + var modelPath: String + + @Argument(help: "Audio file (wav, flac, m4a, …) or precomputed mel .bin. Omit for silence benchmarking.") + var audioPath: String? + + func run() async throws { + let encURL = URL(fileURLWithPath: "\(modelPath)/encoder.aimodel") + if FileManager.default.fileExists(atPath: encURL.path) { + try await runSplit(bundleDir: modelPath, audioPath: audioPath) + } else { + try await runLegacy(modelPath: modelPath, audioPath: audioPath) } } } @@ -61,7 +55,7 @@ private func loadMelArray(from path: String, descriptor: NDArrayDescriptor) thro let data = try Data(contentsOf: url) let count = data.count / MemoryLayout.size guard count == melElements else { - fatalError("mel bin has \(count) floats, expected \(melElements) (128×3000)") + throw ValidationError("mel bin has \(count) floats, expected \(melElements) (128×3000)") } floats = data.withUnsafeBytes { Array($0.bindMemory(to: Float.self)) } } @@ -72,7 +66,7 @@ private func loadMelArray(from path: String, descriptor: NDArrayDescriptor) thro // MARK: - Results -private func printResults(tokens: [Int32], stepTimesMs: [Double]) async { +private func printResults(tokens: [Int32], stepTimesMs: [Double]) async throws { let avgMs = stepTimesMs.reduce(0, +) / Double(stepTimesMs.count) print(String(format: " steps: %d", stepTimesMs.count)) print(String(format: " latency: %.1f ms/tok", avgMs)) @@ -81,20 +75,25 @@ private func printResults(tokens: [Int32], stepTimesMs: [Double]) async { print(String(format: " min/max: %.1f / %.1f ms", lo, hi)) } print("\n── Transcription ──────────────────────────────────────────────────────") - // Load tokenizer from local HF cache (no network needed) let cacheBase = FileManager.default.homeDirectoryForCurrentUser .appending(path: ".cache/huggingface/hub/models--openai--whisper-large-v3-turbo/snapshots") - let snapshot = (try? FileManager.default.contentsOfDirectory(atPath: cacheBase.path))?.first - let tokenizerURL = snapshot.map { cacheBase.appending(path: $0) } - - if let url = tokenizerURL, - let tokenizer = try? await AutoTokenizer.from(modelFolder: url) - { - let ids = tokens.filter { $0 < 50257 }.map { Int($0) } - print(" \(tokenizer.decode(tokens: ids))") - } else { - print(" (tokenizer unavailable — token ids: \(tokens))") + guard + let snapshot = (try? FileManager.default.contentsOfDirectory(atPath: cacheBase.path))?.first, + let tokenizer = try? await AutoTokenizer.from(modelFolder: cacheBase.appending(path: snapshot)) + else { + throw RuntimeError("Tokenizer not found — run the model export once to populate the HF cache") } + let ids = tokens.filter { $0 < 50257 }.map { Int($0) } + print(" \(tokenizer.decode(tokens: ids))") +} + +struct RuntimeError: Error, CustomStringConvertible { + let description: String + init(_ message: String) { description = message } +} + +extension Duration { + var inMilliseconds: Double { Double(components.seconds) * 1000 + Double(components.attoseconds) / 1e15 } } // MARK: - Split runner (encoder + decoder with KV cache) @@ -107,14 +106,14 @@ func runSplit(bundleDir: String, audioPath: String?) async throws { guard let encFn = try encModel.loadFunction(named: "main"), let decFn = try decModel.loadFunction(named: "main") - else { fatalError("No 'main' function") } + else { throw RuntimeError("No 'main' function in model") } let encDesc = encModel.functionDescriptor(for: "main")! let decDesc = decModel.functionDescriptor(for: "main")! guard case .ndArray(let melNDDesc) = encDesc.inputDescriptor(of: "input_features"), case .ndArray(let encOutNDDesc) = encDesc.outputDescriptor(of: "encoder_hidden_states") - else { fatalError("Unexpected encoder descriptors") } + else { throw RuntimeError("Unexpected encoder descriptors") } let encOutShape = encOutNDDesc.shape @@ -137,7 +136,7 @@ func runSplit(bundleDir: String, audioPath: String?) async throws { states: InferenceFunction.MutableViews(), outputViews: consume out) } print("\n── Encoder ────────────────────────────────────────────────────────────") - let encT0 = Date() + let encT0 = ContinuousClock.now do { var out = InferenceFunction.MutableViews() out.insert(&encOutArray, for: "encoder_hidden_states") @@ -145,7 +144,8 @@ func runSplit(bundleDir: String, audioPath: String?) async throws { inputs: ["input_features": melArray], states: InferenceFunction.MutableViews(), outputViews: consume out) } - print(String(format: " latency: %.1f ms", Date().timeIntervalSince(encT0) * 1000)) + let encMs = (ContinuousClock.now - encT0).inMilliseconds + print(String(format: " latency: %.1f ms", encMs)) guard case .ndArray(let inputIdsNDDesc) = decDesc.inputDescriptor(of: "input_ids"), case .ndArray(let posIdsNDDesc) = decDesc.inputDescriptor(of: "position_ids"), @@ -153,7 +153,7 @@ func runSplit(bundleDir: String, audioPath: String?) async throws { case .ndArray(let keyCacheNDDesc) = decDesc.stateDescriptor(of: "keyCache"), case .ndArray(let valCacheNDDesc) = decDesc.stateDescriptor(of: "valueCache"), case .ndArray(let logitsNDDesc) = decDesc.outputDescriptor(of: "logits") - else { fatalError("Unexpected decoder descriptors") } + else { throw RuntimeError("Unexpected decoder descriptors") } let vocabSize = logitsNDDesc.shape.last! let kcShape = keyCacheNDDesc.shape.map { $0 < 0 ? maxTargetPositions : $0 } @@ -197,11 +197,11 @@ func runSplit(bundleDir: String, audioPath: String?) async throws { st.insert(&valueCache, for: "valueCache") var out = InferenceFunction.MutableViews() out.insert(&logitsArray, for: "logits") - let t0 = Date() + let t0 = ContinuousClock.now _ = try await decFn.run( inputs: ["input_ids": ids, "position_ids": posIds, "encoder_hidden_states": encHSArray], states: consume st, outputViews: consume out) - stepTimesMs.append(Date().timeIntervalSince(t0) * 1000) + stepTimesMs.append((ContinuousClock.now - t0).inMilliseconds) let logits = flattenAsFloat(logitsArray) let next = Int32(logits.indices.max(by: { logits[$0] < logits[$1] })!) tokens.append(next) @@ -209,7 +209,7 @@ func runSplit(bundleDir: String, audioPath: String?) async throws { if next == eotToken { break } } - await printResults(tokens: tokens, stepTimesMs: stepTimesMs) + try await printResults(tokens: tokens, stepTimesMs: stepTimesMs) } // MARK: - Legacy runner (monolithic model, no KV cache) @@ -218,19 +218,33 @@ func runLegacy(modelPath: String, audioPath: String?) async throws { print("Format: legacy (monolithic, no KV cache)") let model = try await AIModel(contentsOf: URL(fileURLWithPath: modelPath)) - guard let fn = try model.loadFunction(named: "main") else { fatalError("No 'main' function") } + guard let fn = try model.loadFunction(named: "main") + else { throw RuntimeError("No 'main' function in model") } let desc = model.functionDescriptor(for: "main")! guard case .ndArray(let melNDDesc) = desc.inputDescriptor(of: "input_features"), case .ndArray(let idsNDDesc) = desc.inputDescriptor(of: "decoder_input_ids"), case .ndArray(let logitsDesc) = desc.outputDescriptor(of: "logits") - else { fatalError("Unexpected model descriptors") } + else { throw RuntimeError("Unexpected model descriptors") } let vocabSize = logitsDesc.shape.last! let isStaticIds = !idsNDDesc.shape.contains(where: { $0 < 0 }) if isStaticIds { print("decoder_input_ids exported with static shape — no past context per step") - print("Re-export with --mode legacy to get dynamic shapes") + print("Re-export with --mode legacy to fix") + } + + // Warmup pass + do { + var ids = NDArray(descriptor: idsNDDesc.resolvingDynamicDimensions([1, 1])) + fillNDArray(&ids, as: Int32.self, with: [forcedPrefix[0]]) + var logitsWarmup = NDArray(descriptor: logitsDesc.resolvingDynamicDimensions([1, 1, vocabSize])) + var out = InferenceFunction.MutableViews() + out.insert(&logitsWarmup, for: "logits") + _ = try await fn.run( + inputs: ["input_features": NDArray(descriptor: melNDDesc.resolvingDynamicDimensions([1, 128, 3000])), + "decoder_input_ids": ids], + states: InferenceFunction.MutableViews(), outputViews: consume out) } var melArray: NDArray @@ -255,11 +269,11 @@ func runLegacy(modelPath: String, audioPath: String?) async throws { var logitsArray = NDArray(descriptor: logitsDesc.resolvingDynamicDimensions([1, seqLen, vocabSize])) var out = InferenceFunction.MutableViews() out.insert(&logitsArray, for: "logits") - let t0 = Date() + let t0 = ContinuousClock.now _ = try await fn.run( inputs: ["input_features": melArray, "decoder_input_ids": ids], states: InferenceFunction.MutableViews(), outputViews: consume out) - stepTimesMs.append(Date().timeIntervalSince(t0) * 1000) + stepTimesMs.append((ContinuousClock.now - t0).inMilliseconds) let logits = flattenAsFloat(logitsArray) let lastStart = (seqLen - 1) * vocabSize let lastLogits = Array(logits[lastStart.. Date: Thu, 18 Jun 2026 14:46:52 -0700 Subject: [PATCH 5/7] Swift format --- swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift index f4c8e2c..e1006c4 100644 --- a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift +++ b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift @@ -242,8 +242,10 @@ func runLegacy(modelPath: String, audioPath: String?) async throws { var out = InferenceFunction.MutableViews() out.insert(&logitsWarmup, for: "logits") _ = try await fn.run( - inputs: ["input_features": NDArray(descriptor: melNDDesc.resolvingDynamicDimensions([1, 128, 3000])), - "decoder_input_ids": ids], + inputs: [ + "input_features": NDArray(descriptor: melNDDesc.resolvingDynamicDimensions([1, 128, 3000])), + "decoder_input_ids": ids, + ], states: InferenceFunction.MutableViews(), outputViews: consume out) } From e3805a221a536d2800c2b14bfc4811008d606e38 Mon Sep 17 00:00:00 2001 From: Carina Peng Date: Mon, 22 Jun 2026 15:19:59 -0700 Subject: [PATCH 6/7] Protocol abstraction --- Package.swift | 20 +- .../Sources/CoreAISpeech/MelSpectrogram.swift | 176 +++++++++++ swift/Sources/CoreAISpeech/SpeechBundle.swift | 125 ++++++++ .../Sources/CoreAISpeech/SpeechDecoder.swift | 97 ++++++ swift/Sources/CoreAISpeech/SpeechModel.swift | 130 ++++++++ .../speech-runner/SpeechRunnerMain.swift | 288 +++++------------- .../Tools/speech-runner/WhisperMel.swift | 203 ------------ 7 files changed, 628 insertions(+), 411 deletions(-) create mode 100644 swift/Sources/CoreAISpeech/MelSpectrogram.swift create mode 100644 swift/Sources/CoreAISpeech/SpeechBundle.swift create mode 100644 swift/Sources/CoreAISpeech/SpeechDecoder.swift create mode 100644 swift/Sources/CoreAISpeech/SpeechModel.swift delete mode 100644 swift/Sources/Tools/speech-runner/WhisperMel.swift diff --git a/Package.swift b/Package.swift index 2449058..5748c1d 100644 --- a/Package.swift +++ b/Package.swift @@ -29,6 +29,10 @@ let package = Package( "CoreAIImageSegmenter" ] ), + .library( + name: "CoreAISpeech", + targets: ["CoreAISpeech"] + ), .library( name: "CoreAIObjectDetection", targets: [ @@ -85,6 +89,19 @@ let package = Package( ] ), + // Speech recognition library + .target( + name: "CoreAISpeech", + dependencies: [ + "CoreAIShared", + .product(name: "Transformers", package: "swift-transformers"), + ], + path: "swift/Sources/CoreAISpeech", + swiftSettings: [ + .enableUpcomingFeature("MemberImportVisibility") + ] + ), + // Diffusion Pipeline .target( name: "CoreAIDiffusionPipeline", @@ -158,9 +175,8 @@ let package = Package( .executableTarget( name: "speech-runner", dependencies: [ - "CoreAIShared", + "CoreAISpeech", .product(name: "ArgumentParser", package: "swift-argument-parser"), - .product(name: "Transformers", package: "swift-transformers"), ], path: "swift/Sources/Tools/speech-runner", swiftSettings: [ diff --git a/swift/Sources/CoreAISpeech/MelSpectrogram.swift b/swift/Sources/CoreAISpeech/MelSpectrogram.swift new file mode 100644 index 0000000..40d380c --- /dev/null +++ b/swift/Sources/CoreAISpeech/MelSpectrogram.swift @@ -0,0 +1,176 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import AVFoundation +import Accelerate +import CoreAIShared +import Foundation + +// MARK: - MelConfig + +/// Parameters for mel spectrogram computation. +public struct MelConfig: Sendable { + public let sampleRate: Double + public let nFFT: Int + public let hopLength: Int + public let nMelBins: Int + public let nFrames: Int + + public var nSamples: Int { Int(sampleRate) * (nFrames * hopLength / Int(sampleRate / 100)) } + + /// Whisper / Parakeet shared parameters. + public static let whisper = MelConfig( + sampleRate: 16_000, nFFT: 400, hopLength: 160, nMelBins: 128, nFrames: 3_000) +} + +// MARK: - MelSpectrogram + +/// Computes a mel spectrogram from an audio file or raw PCM samples. +public enum MelSpectrogram { + // MARK: Public API + + public static func fromFile(_ url: URL, config: MelConfig = .whisper) throws -> [Float] { + return fromPCM(try loadAndResample(url, targetSampleRate: config.sampleRate), config: config) + } + + public static func fromPCM(_ raw: [Float], config: MelConfig = .whisper) -> [Float] { + let nSamples = config.nFrames * config.hopLength + + var audio = raw + if audio.count > nSamples { + audio = Array(audio.prefix(nSamples)) + } else if audio.count < nSamples { + audio += [Float](repeating: 0, count: nSamples - audio.count) + } + + let pad = config.nFFT / 2 + var padded = [Float](repeating: 0, count: nSamples + 2 * pad) + for i in 0.. [Float] { + let file = try AVAudioFile(forReading: url) + let fmt = AVAudioFormat( + commonFormat: .pcmFormatFloat32, + sampleRate: targetSampleRate, channels: 1, interleaved: false)! + guard let conv = AVAudioConverter(from: file.processingFormat, to: fmt) else { + throw SpeechError.invalidAudio( + "Cannot resample \(file.processingFormat) to \(targetSampleRate) Hz mono") + } + let cap = AVAudioFrameCount( + ceil(Double(file.length) * targetSampleRate / file.processingFormat.sampleRate) + 1) + let out = AVAudioPCMBuffer(pcmFormat: fmt, frameCapacity: cap)! + var fed = false + var err: NSError? + conv.convert(to: out, error: &err) { _, status in + guard !fed else { + status.pointee = .endOfStream + return nil + } + fed = true + let buf = AVAudioPCMBuffer( + pcmFormat: file.processingFormat, + frameCapacity: AVAudioFrameCount(file.length))! + try? file.read(into: buf) + status.pointee = buf.frameLength > 0 ? .haveData : .endOfStream + return buf + } + if let e = err { throw SpeechError.invalidAudio(e.localizedDescription) } + return Array( + UnsafeBufferPointer( + start: out.floatChannelData![0], + count: Int(out.frameLength))) + } + + // MARK: Precomputed basis + + private static func hannWindow(size: Int) -> [Float] { + (0.. ([Float], [Float]) { + let nFreqs = nFFT / 2 + 1 + var cos = [Float](repeating: 0, count: nFreqs * nFFT) + var sin = [Float](repeating: 0, count: nFreqs * nFFT) + for k in 0.. [Float] { + let nFreqs = config.nFFT / 2 + 1 + let fMax = Float(config.sampleRate) / 2 + func h2m(_ f: Float) -> Float { 2595 * log10(1 + f / 700) } + func m2h(_ m: Float) -> Float { 700 * (pow(10, m / 2595) - 1) } + let pts = (0.. Float in + m2h(h2m(0) + Float(i) / Float(config.nMelBins + 1) * (h2m(fMax) - h2m(0))) + } + let fftFreqs = (0..= fL && f <= fC { + fb[m * nFreqs + k] = norm * (f - fL) / (fC - fL) + } else if f > fC && f <= fR { + fb[m * nFreqs + k] = norm * (fR - f) / (fR - fC) + } + } + } + return fb + } +} diff --git a/swift/Sources/CoreAISpeech/SpeechBundle.swift b/swift/Sources/CoreAISpeech/SpeechBundle.swift new file mode 100644 index 0000000..8cfa101 --- /dev/null +++ b/swift/Sources/CoreAISpeech/SpeechBundle.swift @@ -0,0 +1,125 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import AVFoundation +import CoreAI +import CoreAIShared +import Foundation +import Tokenizers + +// MARK: - SpeechBundle + +/// Locates and loads the assets inside a CoreAISpeech model bundle directory. +/// +/// A bundle directory contains: +/// encoder.aimodel — audio features → encoder hidden states +/// decoder.aimodel — autoregressive decoder with persistent state +/// generation_config.json (optional) — prefix, EOT token, etc. +/// +/// The tokenizer is loaded from the local HF cache if it can be found there. +public struct SpeechBundle: Sendable { + public let encoder: AIModel + public let decoder: AIModel + public let tokenizer: (any Tokenizer)? + public let generationConfig: GenerationConfig + + public init(at url: URL) async throws { + let encURL = url.appending(path: "encoder.aimodel") + let decURL = url.appending(path: "decoder.aimodel") + guard FileManager.default.fileExists(atPath: encURL.path), + FileManager.default.fileExists(atPath: decURL.path) + else { + throw SpeechError.missingModel( + "bundle at \(url.lastPathComponent) must contain encoder.aimodel and decoder.aimodel") + } + encoder = try await AIModel(contentsOf: encURL) + decoder = try await AIModel(contentsOf: decURL) + + // Load generation config from bundle if present, otherwise use Whisper defaults + let cfgURL = url.appending(path: "generation_config.json") + generationConfig = (try? GenerationConfig(from: cfgURL)) ?? .whisper + + // Tokenizer — look in bundle first, then fall back to HF cache + tokenizer = try? await Self.loadTokenizer(bundleURL: url, config: generationConfig) + } + + private static func loadTokenizer( + bundleURL: URL, config: GenerationConfig + ) async throws -> (any Tokenizer)? { + // 1. Try tokenizer files in the bundle itself + if FileManager.default.fileExists(atPath: bundleURL.appending(path: "tokenizer.json").path) { + return try? await AutoTokenizer.from(modelFolder: bundleURL) + } + // 2. Fall back to local HF cache using the model name from config + if let name = config.tokenizerName { + let cacheRoot = FileManager.default.homeDirectoryForCurrentUser + .appending(path: ".cache/huggingface/hub") + let folderName = "models--" + name.replacingOccurrences(of: "/", with: "--") + let snapshotsDir = cacheRoot.appending(path: "\(folderName)/snapshots") + if let snapshot = try? FileManager.default.contentsOfDirectory( + atPath: snapshotsDir.path + ).first { + return try? await AutoTokenizer.from( + modelFolder: snapshotsDir.appending(path: snapshot)) + } + } + return nil + } +} + +// MARK: - GenerationConfig + +/// Model-specific generation parameters, read from generation_config.json in the bundle. +public struct GenerationConfig: Sendable { + /// Tokens prepended to every decode sequence before free generation. + public let forcedPrefix: [Int32] + /// Token that signals end of transcription. + public let eotToken: Int32 + /// Maximum tokens to generate per call. + public let maxDecodeSteps: Int + /// HuggingFace model name for loading the tokenizer from cache. + public let tokenizerName: String? + + /// Whisper large-v3-turbo defaults. + public static let whisper = GenerationConfig( + forcedPrefix: [50258, 50259, 50360, 50364], // BOS <|en|> <|transcribe|> <|notimestamps|> + eotToken: 50257, + maxDecodeSteps: 50, + tokenizerName: "openai/whisper-large-v3-turbo" + ) + + init(forcedPrefix: [Int32], eotToken: Int32, maxDecodeSteps: Int, tokenizerName: String?) { + self.forcedPrefix = forcedPrefix + self.eotToken = eotToken + self.maxDecodeSteps = maxDecodeSteps + self.tokenizerName = tokenizerName + } + + init(from url: URL) throws { + let data = try Data(contentsOf: url) + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] ?? [:] + forcedPrefix = (json["forced_decoder_ids"] as? [Int]).map { $0.map { Int32($0) } } ?? Self.whisper.forcedPrefix + eotToken = (json["eos_token_id"] as? Int).map { Int32($0) } ?? Self.whisper.eotToken + maxDecodeSteps = (json["max_new_tokens"] as? Int) ?? Self.whisper.maxDecodeSteps + tokenizerName = json["tokenizer_name"] as? String ?? Self.whisper.tokenizerName + } +} + +// MARK: - SpeechError + +public enum SpeechError: Error, CustomStringConvertible { + case missingModel(String) + case missingTokenizer + case invalidAudio(String) + + public var description: String { + switch self { + case .missingModel(let msg): return "Missing model: \(msg)" + case .missingTokenizer: + return "Tokenizer not found — ensure the model bundle includes a tokenizer or the HF cache is populated" + case .invalidAudio(let msg): return "Invalid audio: \(msg)" + } + } +} diff --git a/swift/Sources/CoreAISpeech/SpeechDecoder.swift b/swift/Sources/CoreAISpeech/SpeechDecoder.swift new file mode 100644 index 0000000..f8a2056 --- /dev/null +++ b/swift/Sources/CoreAISpeech/SpeechDecoder.swift @@ -0,0 +1,97 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import CoreAI +import CoreAIShared +import Foundation + +// MARK: - SpeechDecoder protocol + +/// Model-specific decode logic. +public protocol SpeechDecoder: Sendable { + func decode( + encoderOutput: NDArray, + encoderOutputShape: [Int], + decoderModel: AIModel, + config: GenerationConfig + ) async throws -> [Int32] +} + +// MARK: - WhisperDecoder + +/// Greedy decoder for Whisper (encoder-decoder, cross-attention, KV cache). +public struct WhisperDecoder: SpeechDecoder { + public init() {} + + public func decode( + encoderOutput: NDArray, + encoderOutputShape: [Int], + decoderModel: AIModel, + config: GenerationConfig + ) async throws -> [Int32] { + guard let decFn = try decoderModel.loadFunction(named: "main") else { + throw SpeechError.missingModel("No 'main' function in decoder") + } + let decDesc = decoderModel.functionDescriptor(for: "main")! + + guard case .ndArray(let inputIdsNDDesc) = decDesc.inputDescriptor(of: "input_ids"), + case .ndArray(let posIdsNDDesc) = decDesc.inputDescriptor(of: "position_ids"), + case .ndArray(let encHSNDDesc) = decDesc.inputDescriptor(of: "encoder_hidden_states"), + case .ndArray(let keyCacheNDDesc) = decDesc.stateDescriptor(of: "keyCache"), + case .ndArray(let valCacheNDDesc) = decDesc.stateDescriptor(of: "valueCache"), + case .ndArray(let logitsNDDesc) = decDesc.outputDescriptor(of: "logits") + else { throw SpeechError.missingModel("Unexpected decoder descriptors") } + + let vocabSize = logitsNDDesc.shape.last! + let maxTargetPos = 448 + let kcShape = keyCacheNDDesc.shape.map { $0 < 0 ? maxTargetPos : $0 } + let vcShape = valCacheNDDesc.shape.map { $0 < 0 ? maxTargetPos : $0 } + var keyCache = NDArray(descriptor: keyCacheNDDesc.resolvingDynamicDimensions(kcShape)) + var valueCache = NDArray(descriptor: valCacheNDDesc.resolvingDynamicDimensions(vcShape)) + + var encHSArray = NDArray(descriptor: encHSNDDesc.resolvingDynamicDimensions(encoderOutputShape)) + let encFlat = readNDArray(encoderOutput, as: Float.self, count: encoderOutputShape.reduce(1, *)) + fillNDArray(&encHSArray, as: Float.self, with: encFlat) + + var logitsArray = NDArray(descriptor: logitsNDDesc.resolvingDynamicDimensions([1, 1, vocabSize])) + + func step(_ tok: Int32, pos: Int) async throws { + var ids = NDArray(descriptor: inputIdsNDDesc.resolvingDynamicDimensions([1, 1])) + var posIds = NDArray(descriptor: posIdsNDDesc.resolvingDynamicDimensions([1, pos + 1])) + fillNDArray(&ids, as: Int32.self, with: [tok]) + fillNDArray(&posIds, as: Int32.self, count: pos + 1) { Int32($0) } + var st = InferenceFunction.MutableViews() + st.insert(&keyCache, for: "keyCache") + st.insert(&valueCache, for: "valueCache") + var out = InferenceFunction.MutableViews() + out.insert(&logitsArray, for: "logits") + _ = try await decFn.run( + inputs: [ + "input_ids": ids, "position_ids": posIds, + "encoder_hidden_states": encHSArray, + ], + states: consume st, outputViews: consume out) + } + + // Prime KV cache with forced prefix + var tokens: [Int32] = config.forcedPrefix + for (i, tok) in config.forcedPrefix.enumerated() { + try await step(tok, pos: i) + } + + // Greedy decode + var pos = config.forcedPrefix.count + while tokens.count - config.forcedPrefix.count < config.maxDecodeSteps { + try await step(tokens.last!, pos: pos) + let logits = flattenAsFloat(logitsArray) + let next = Int32(logits.indices.max(by: { logits[$0] < logits[$1] })!) + tokens.append(next) + pos += 1 + if next == config.eotToken { break } + } + + return tokens + } +} diff --git a/swift/Sources/CoreAISpeech/SpeechModel.swift b/swift/Sources/CoreAISpeech/SpeechModel.swift new file mode 100644 index 0000000..a61d28e --- /dev/null +++ b/swift/Sources/CoreAISpeech/SpeechModel.swift @@ -0,0 +1,130 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import CoreAI +import CoreAIShared +import Foundation +import Tokenizers + +// MARK: - SpeechModel + +/// On-device speech recognition model. +/// +/// Loads a CoreAISpeech bundle (encoder.aimodel + decoder.aimodel) and transcribes +/// audio files. The decoder architecture is pluggable via ``SpeechDecoder`` +public actor SpeechModel { + private let bundle: SpeechBundle + private let decoder: any SpeechDecoder + private let melConfig: MelConfig + + // Encoder function and descriptor, cached after first load + private var encFn: InferenceFunction? + private var encOutShape: [Int]? + + /// Load a model from a bundle directory. + /// + /// - Parameters: + /// - bundleURL: Directory containing encoder.aimodel and decoder.aimodel. + /// - decoder: Decode strategy. Defaults to ``WhisperDecoder``. + /// - melConfig: Mel spectrogram parameters. Defaults to ``MelConfig/whisper``. + public init( + bundleURL: URL, + decoder: any SpeechDecoder = WhisperDecoder(), + melConfig: MelConfig = .whisper + ) async throws { + self.bundle = try await SpeechBundle(at: bundleURL) + self.decoder = decoder + self.melConfig = melConfig + try await warmUp() + } + + // MARK: - Transcription + + /// Transcribe an audio file, returning the full text. + public func transcribe(audioURL: URL) async throws -> String { + let tokens = try await decodeAudio(from: audioURL) + return try detokenize(tokens) + } + + /// Transcribe raw 16 kHz mono PCM samples. + public func transcribe(pcm: [Float]) async throws -> String { + let tokens = try await decodeAudio(pcm: pcm) + return try detokenize(tokens) + } + + // MARK: - Internals + + private func warmUp() async throws { + // Run the encoder once with silence to trigger JIT compilation + guard let fn = try bundle.encoder.loadFunction(named: "main") else { + throw SpeechError.missingModel("No 'main' function in encoder") + } + encFn = fn + let encDesc = bundle.encoder.functionDescriptor(for: "main")! + guard case .ndArray(let encOutNDDesc) = encDesc.outputDescriptor(of: "encoder_hidden_states") + else { throw SpeechError.missingModel("Unexpected encoder output descriptor") } + encOutShape = encOutNDDesc.shape + + guard case .ndArray(let melNDDesc) = encDesc.inputDescriptor(of: "input_features") + else { throw SpeechError.missingModel("Unexpected encoder input descriptor") } + + var silence = NDArray( + descriptor: melNDDesc.resolvingDynamicDimensions([1, melConfig.nMelBins, melConfig.nFrames])) + fillNDArray(&silence, as: Float.self, count: melConfig.nMelBins * melConfig.nFrames) { _ in 0.0 } + var encOut = NDArray(descriptor: encOutNDDesc.resolvingDynamicDimensions(encOutNDDesc.shape)) + var out = InferenceFunction.MutableViews() + out.insert(&encOut, for: "encoder_hidden_states") + _ = try await fn.run( + inputs: ["input_features": silence], + states: InferenceFunction.MutableViews(), outputViews: consume out) + } + + private func runEncoder(_ melArray: inout NDArray) async throws -> NDArray { + guard let fn = encFn, let shape = encOutShape else { + throw SpeechError.missingModel("Encoder not initialised") + } + let encDesc = bundle.encoder.functionDescriptor(for: "main")! + guard case .ndArray(let encOutNDDesc) = encDesc.outputDescriptor(of: "encoder_hidden_states") + else { throw SpeechError.missingModel("Unexpected encoder output") } + var encOut = NDArray(descriptor: encOutNDDesc.resolvingDynamicDimensions(shape)) + var out = InferenceFunction.MutableViews() + out.insert(&encOut, for: "encoder_hidden_states") + _ = try await fn.run( + inputs: ["input_features": melArray], + states: InferenceFunction.MutableViews(), outputViews: consume out) + return encOut + } + + private func decodeAudio(from url: URL) async throws -> [Int32] { + let pcm = try MelSpectrogram.loadAndResample(url, targetSampleRate: melConfig.sampleRate) + return try await decodeAudio(pcm: pcm) + } + + private func decodeAudio(pcm: [Float]) async throws -> [Int32] { + let encDesc = bundle.encoder.functionDescriptor(for: "main")! + guard case .ndArray(let melNDDesc) = encDesc.inputDescriptor(of: "input_features") + else { throw SpeechError.missingModel("Unexpected encoder input") } + + let floats = MelSpectrogram.fromPCM(pcm, config: melConfig) + var melArray = NDArray( + descriptor: melNDDesc.resolvingDynamicDimensions( + [1, melConfig.nMelBins, melConfig.nFrames])) + fillNDArray(&melArray, as: Float.self, with: floats) + + let encoderOutput = try await runEncoder(&melArray) + let shape = encOutShape ?? [1, 1500, 1280] + return try await decoder.decode( + encoderOutput: encoderOutput, + encoderOutputShape: shape, + decoderModel: bundle.decoder, + config: bundle.generationConfig) + } + + private func detokenize(_ tokens: [Int32]) throws -> String { + guard let tokenizer = bundle.tokenizer else { throw SpeechError.missingTokenizer } + let ids = tokens.filter { $0 < bundle.generationConfig.eotToken }.map { Int($0) } + return tokenizer.decode(tokens: ids).trimmingCharacters(in: .whitespaces) + } +} diff --git a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift index e1006c4..14318d9 100644 --- a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift +++ b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift @@ -6,213 +6,60 @@ import ArgumentParser import CoreAI import CoreAIShared +import CoreAISpeech import Foundation import Tokenizers -// Whisper forced prefix: <|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> -private let forcedPrefix: [Int32] = [50258, 50259, 50360, 50364] -private let eotToken: Int32 = 50257 -private let maxTargetPositions = 448 -private let maxDecodeSteps = 50 -private let melElements = 128 * 3000 - // MARK: - Entry point @main struct SpeechRunner: AsyncParsableCommand { static let configuration = CommandConfiguration( commandName: "speech-runner", - abstract: "Transcribe audio using a CoreAI Whisper export" + abstract: "Transcribe audio using a CoreAI speech model bundle" ) - @Argument(help: "Path to encoder+decoder bundle dir (--mode coreai) or single .aimodel (--mode legacy)") + @Argument(help: "Bundle dir (encoder.aimodel + decoder.aimodel) or single .aimodel (legacy)") var modelPath: String - @Argument(help: "Audio file (wav, flac, m4a, …) or precomputed mel .bin. Omit for silence benchmarking.") + @Argument(help: "Audio file (wav, flac, m4a, …). Omit for latency benchmarking with silence.") var audioPath: String? func run() async throws { - let encURL = URL(fileURLWithPath: "\(modelPath)/encoder.aimodel") - if FileManager.default.fileExists(atPath: encURL.path) { - try await runSplit(bundleDir: modelPath, audioPath: audioPath) + let bundleURL = URL(fileURLWithPath: modelPath) + if FileManager.default.fileExists(atPath: bundleURL.appending(path: "encoder.aimodel").path) { + try await runBundle(bundleURL: bundleURL, audioPath: audioPath) } else { try await runLegacy(modelPath: modelPath, audioPath: audioPath) } } } -// MARK: - Mel loading - -private let audioExtensions: Set = ["wav", "flac", "m4a", "mp3", "aiff", "aif", "caf"] - -private func loadMelArray(from path: String, descriptor: NDArrayDescriptor) throws -> NDArray { - let url = URL(fileURLWithPath: path) - let floats: [Float] - if audioExtensions.contains(url.pathExtension.lowercased()) { - print("Computing mel from audio file…") - floats = try WhisperMel.fromFile(url) - } else { - let data = try Data(contentsOf: url) - let count = data.count / MemoryLayout.size - guard count == melElements else { - throw ValidationError("mel bin has \(count) floats, expected \(melElements) (128×3000)") - } - floats = data.withUnsafeBytes { Array($0.bindMemory(to: Float.self)) } - } - var array = NDArray(descriptor: descriptor.resolvingDynamicDimensions([1, 128, 3000])) - fillNDArray(&array, as: Float.self, with: floats) - return array -} - -// MARK: - Results - -private func printResults(tokens: [Int32], stepTimesMs: [Double]) async throws { - let avgMs = stepTimesMs.reduce(0, +) / Double(stepTimesMs.count) - print(String(format: " steps: %d", stepTimesMs.count)) - print(String(format: " latency: %.1f ms/tok", avgMs)) - print(String(format: " speed: %.1f tok/s", 1000.0 / avgMs)) - if let lo = stepTimesMs.min(), let hi = stepTimesMs.max() { - print(String(format: " min/max: %.1f / %.1f ms", lo, hi)) - } - print("\n── Transcription ──────────────────────────────────────────────────────") - let cacheBase = FileManager.default.homeDirectoryForCurrentUser - .appending(path: ".cache/huggingface/hub/models--openai--whisper-large-v3-turbo/snapshots") - guard - let snapshot = (try? FileManager.default.contentsOfDirectory(atPath: cacheBase.path))?.first, - let tokenizer = try? await AutoTokenizer.from(modelFolder: cacheBase.appending(path: snapshot)) - else { - throw RuntimeError("Tokenizer not found — run the model export once to populate the HF cache") - } - let ids = tokens.filter { $0 < 50257 }.map { Int($0) } - print(" \(tokenizer.decode(tokens: ids))") -} - -struct RuntimeError: Error, CustomStringConvertible { - let description: String - init(_ message: String) { description = message } -} - -extension Duration { - var inMilliseconds: Double { Double(components.seconds) * 1000 + Double(components.attoseconds) / 1e15 } -} +// MARK: - Split bundle via CoreAISpeech -// MARK: - Split runner (encoder + decoder with KV cache) - -func runSplit(bundleDir: String, audioPath: String?) async throws { +func runBundle(bundleURL: URL, audioPath: String?) async throws { print("Format: split (encoder + decoder, KV cache)") + let model = try await SpeechModel(bundleURL: bundleURL) - let encModel = try await AIModel(contentsOf: URL(fileURLWithPath: "\(bundleDir)/encoder.aimodel")) - let decModel = try await AIModel(contentsOf: URL(fileURLWithPath: "\(bundleDir)/decoder.aimodel")) - - guard let encFn = try encModel.loadFunction(named: "main"), - let decFn = try decModel.loadFunction(named: "main") - else { throw RuntimeError("No 'main' function in model") } - - let encDesc = encModel.functionDescriptor(for: "main")! - let decDesc = decModel.functionDescriptor(for: "main")! - - guard case .ndArray(let melNDDesc) = encDesc.inputDescriptor(of: "input_features"), - case .ndArray(let encOutNDDesc) = encDesc.outputDescriptor(of: "encoder_hidden_states") - else { throw RuntimeError("Unexpected encoder descriptors") } - - let encOutShape = encOutNDDesc.shape - - var melArray: NDArray if let path = audioPath { - melArray = try loadMelArray(from: path, descriptor: melNDDesc) + let url = URL(fileURLWithPath: path) + print("Transcribing \(url.lastPathComponent)…") + let t0 = ContinuousClock.now + let text = try await model.transcribe(audioURL: url) + let ms = (ContinuousClock.now - t0).inMilliseconds + print(String(format: " %.1f ms total", ms)) + print("\n── Transcription ──────────────────────────────────────────────────────") + print(" \(text)") } else { - print("No audio — using silence for benchmarking") - melArray = NDArray(descriptor: melNDDesc.resolvingDynamicDimensions([1, 128, 3000])) - fillNDArray(&melArray, as: Float.self, count: melElements) { _ in 0.0 } - } - var encOutArray = NDArray(descriptor: encOutNDDesc.resolvingDynamicDimensions(encOutShape)) - - // Warmup - do { - var out = InferenceFunction.MutableViews() - out.insert(&encOutArray, for: "encoder_hidden_states") - _ = try await encFn.run( - inputs: ["input_features": melArray], - states: InferenceFunction.MutableViews(), outputViews: consume out) - } - print("\n── Encoder ────────────────────────────────────────────────────────────") - let encT0 = ContinuousClock.now - do { - var out = InferenceFunction.MutableViews() - out.insert(&encOutArray, for: "encoder_hidden_states") - _ = try await encFn.run( - inputs: ["input_features": melArray], - states: InferenceFunction.MutableViews(), outputViews: consume out) - } - let encMs = (ContinuousClock.now - encT0).inMilliseconds - print(String(format: " latency: %.1f ms", encMs)) - - guard case .ndArray(let inputIdsNDDesc) = decDesc.inputDescriptor(of: "input_ids"), - case .ndArray(let posIdsNDDesc) = decDesc.inputDescriptor(of: "position_ids"), - case .ndArray(let encHSNDDesc) = decDesc.inputDescriptor(of: "encoder_hidden_states"), - case .ndArray(let keyCacheNDDesc) = decDesc.stateDescriptor(of: "keyCache"), - case .ndArray(let valCacheNDDesc) = decDesc.stateDescriptor(of: "valueCache"), - case .ndArray(let logitsNDDesc) = decDesc.outputDescriptor(of: "logits") - else { throw RuntimeError("Unexpected decoder descriptors") } - - let vocabSize = logitsNDDesc.shape.last! - let kcShape = keyCacheNDDesc.shape.map { $0 < 0 ? maxTargetPositions : $0 } - let vcShape = valCacheNDDesc.shape.map { $0 < 0 ? maxTargetPositions : $0 } - var keyCache = NDArray(descriptor: keyCacheNDDesc.resolvingDynamicDimensions(kcShape)) - var valueCache = NDArray(descriptor: valCacheNDDesc.resolvingDynamicDimensions(vcShape)) - - let encFlat = readNDArray(encOutArray, as: Float.self, count: encOutShape.reduce(1, *)) - var encHSArray = NDArray(descriptor: encHSNDDesc.resolvingDynamicDimensions(encOutShape)) - fillNDArray(&encHSArray, as: Float.self, with: encFlat) - var logitsArray = NDArray(descriptor: logitsNDDesc.resolvingDynamicDimensions([1, 1, vocabSize])) - - print("\n── Decoder ────────────────────────────────────────────────────────────") - - var tokens: [Int32] = forcedPrefix - var pos = 0 - for tok in forcedPrefix { - var ids = NDArray(descriptor: inputIdsNDDesc.resolvingDynamicDimensions([1, 1])) - var posIds = NDArray(descriptor: posIdsNDDesc.resolvingDynamicDimensions([1, pos + 1])) - fillNDArray(&ids, as: Int32.self, with: [tok]) - fillNDArray(&posIds, as: Int32.self, count: pos + 1) { Int32($0) } - var st = InferenceFunction.MutableViews() - st.insert(&keyCache, for: "keyCache") - st.insert(&valueCache, for: "valueCache") - var out = InferenceFunction.MutableViews() - out.insert(&logitsArray, for: "logits") - _ = try await decFn.run( - inputs: ["input_ids": ids, "position_ids": posIds, "encoder_hidden_states": encHSArray], - states: consume st, outputViews: consume out) - pos += 1 - } - - var stepTimesMs: [Double] = [] - while stepTimesMs.count < maxDecodeSteps { - var ids = NDArray(descriptor: inputIdsNDDesc.resolvingDynamicDimensions([1, 1])) - var posIds = NDArray(descriptor: posIdsNDDesc.resolvingDynamicDimensions([1, pos + 1])) - fillNDArray(&ids, as: Int32.self, with: [tokens.last!]) - fillNDArray(&posIds, as: Int32.self, count: pos + 1) { Int32($0) } - var st = InferenceFunction.MutableViews() - st.insert(&keyCache, for: "keyCache") - st.insert(&valueCache, for: "valueCache") - var out = InferenceFunction.MutableViews() - out.insert(&logitsArray, for: "logits") + print("No audio — silence benchmark") + let pcm = [Float](repeating: 0, count: 480_000) let t0 = ContinuousClock.now - _ = try await decFn.run( - inputs: ["input_ids": ids, "position_ids": posIds, "encoder_hidden_states": encHSArray], - states: consume st, outputViews: consume out) - stepTimesMs.append((ContinuousClock.now - t0).inMilliseconds) - let logits = flattenAsFloat(logitsArray) - let next = Int32(logits.indices.max(by: { logits[$0] < logits[$1] })!) - tokens.append(next) - pos += 1 - if next == eotToken { break } + _ = try await model.transcribe(pcm: pcm) + print(String(format: " %.1f ms (silence)", (ContinuousClock.now - t0).inMilliseconds)) } - - try await printResults(tokens: tokens, stepTimesMs: stepTimesMs) } -// MARK: - Legacy runner (monolithic model, no KV cache) +// MARK: - Legacy monolithic model func runLegacy(modelPath: String, audioPath: String?) async throws { print("Format: legacy (monolithic, no KV cache)") @@ -230,59 +77,88 @@ func runLegacy(modelPath: String, audioPath: String?) async throws { let vocabSize = logitsDesc.shape.last! let isStaticIds = !idsNDDesc.shape.contains(where: { $0 < 0 }) if isStaticIds { - print("decoder_input_ids exported with static shape — no past context per step") - print("Re-export with --mode legacy to fix") + print(" ⚠️ decoder_input_ids has static shape — no past context per step") } - // Warmup pass + var melArray: NDArray + if let path = audioPath { + let pcm = try MelSpectrogram.loadAndResample( + URL(fileURLWithPath: path), targetSampleRate: 16_000) + let floats = MelSpectrogram.fromPCM(pcm) + melArray = NDArray(descriptor: melNDDesc.resolvingDynamicDimensions([1, 128, 3000])) + fillNDArray(&melArray, as: Float.self, with: floats) + } else { + melArray = NDArray(descriptor: melNDDesc.resolvingDynamicDimensions([1, 128, 3000])) + fillNDArray(&melArray, as: Float.self, count: 128 * 3000) { _ in 0.0 } + } + + // Warmup do { var ids = NDArray(descriptor: idsNDDesc.resolvingDynamicDimensions([1, 1])) - fillNDArray(&ids, as: Int32.self, with: [forcedPrefix[0]]) - var logitsWarmup = NDArray(descriptor: logitsDesc.resolvingDynamicDimensions([1, 1, vocabSize])) + fillNDArray(&ids, as: Int32.self, with: [50258]) + var lw = NDArray(descriptor: logitsDesc.resolvingDynamicDimensions([1, 1, vocabSize])) var out = InferenceFunction.MutableViews() - out.insert(&logitsWarmup, for: "logits") + out.insert(&lw, for: "logits") _ = try await fn.run( - inputs: [ - "input_features": NDArray(descriptor: melNDDesc.resolvingDynamicDimensions([1, 128, 3000])), - "decoder_input_ids": ids, - ], + inputs: ["input_features": melArray, "decoder_input_ids": ids], states: InferenceFunction.MutableViews(), outputViews: consume out) } - var melArray: NDArray - if let path = audioPath { - melArray = try loadMelArray(from: path, descriptor: melNDDesc) - } else { - print("No audio — using silence for benchmarking") - melArray = NDArray(descriptor: melNDDesc.resolvingDynamicDimensions([1, 128, 3000])) - fillNDArray(&melArray, as: Float.self, count: melElements) { _ in 0.0 } - } + let config = GenerationConfig.whisper + var tokens: [Int32] = config.forcedPrefix + var stepTimesMs: [Double] = [] print("\n── Decode ─────────────────────────────────────────────────────────────") - var tokens: [Int32] = forcedPrefix - var stepTimesMs: [Double] = [] - - while stepTimesMs.count < maxDecodeSteps { + while stepTimesMs.count < config.maxDecodeSteps { let inputTokens: [Int32] = isStaticIds ? [tokens.last!] : tokens let seqLen = inputTokens.count var ids = NDArray(descriptor: idsNDDesc.resolvingDynamicDimensions([1, seqLen])) fillNDArray(&ids, as: Int32.self, with: inputTokens) - var logitsArray = NDArray(descriptor: logitsDesc.resolvingDynamicDimensions([1, seqLen, vocabSize])) + var la = NDArray(descriptor: logitsDesc.resolvingDynamicDimensions([1, seqLen, vocabSize])) var out = InferenceFunction.MutableViews() - out.insert(&logitsArray, for: "logits") + out.insert(&la, for: "logits") let t0 = ContinuousClock.now _ = try await fn.run( inputs: ["input_features": melArray, "decoder_input_ids": ids], states: InferenceFunction.MutableViews(), outputViews: consume out) stepTimesMs.append((ContinuousClock.now - t0).inMilliseconds) - let logits = flattenAsFloat(logitsArray) - let lastStart = (seqLen - 1) * vocabSize - let lastLogits = Array(logits[lastStart.. [Float] { - return fromPCM(try loadAndResample(url)) - } - - // MARK: - Audio loading + resampling - - static func loadAndResample(_ url: URL) throws -> [Float] { - let file = try AVAudioFile(forReading: url) - let fmt = AVAudioFormat( - commonFormat: .pcmFormatFloat32, - sampleRate: sampleRate, channels: 1, interleaved: false)! - guard let conv = AVAudioConverter(from: file.processingFormat, to: fmt) else { - throw NSError( - domain: "WhisperMel", code: 1, - userInfo: [ - NSLocalizedDescriptionKey: - "Cannot convert \(file.processingFormat) → 16 kHz mono" - ]) - } - let cap = AVAudioFrameCount( - ceil(Double(file.length) * sampleRate / file.processingFormat.sampleRate) + 1) - let out = AVAudioPCMBuffer(pcmFormat: fmt, frameCapacity: cap)! - var fed = false - var err: NSError? - conv.convert(to: out, error: &err) { _, status in - guard !fed else { - status.pointee = .endOfStream - return nil - } - fed = true - let buf = AVAudioPCMBuffer( - pcmFormat: file.processingFormat, - frameCapacity: AVAudioFrameCount(file.length))! - try? file.read(into: buf) - status.pointee = buf.frameLength > 0 ? .haveData : .endOfStream - return buf - } - if let e = err { throw e } - return Array( - UnsafeBufferPointer( - start: out.floatChannelData![0], - count: Int(out.frameLength))) - } - - // MARK: - Precomputed DFT basis (201 × 400) - // cos_basis[k, n] = cos(2π k n / 400) → Y[k].real = cos_basis @ x - // sin_basis[k, n] = -sin(2π k n / 400) → Y[k].imag = sin_basis @ x - - static let cosBasis: [Float] = { - var m = [Float](repeating: 0, count: (nFFT / 2 + 1) * nFFT) - for k in 0...nFFT / 2 { - for n in 0.. [Float] { - // 1. Trim / zero-pad to nSamples - var audio = raw - if audio.count > nSamples { - audio = Array(audio.prefix(nSamples)) - } else if audio.count < nSamples { - audio += [Float](repeating: 0, count: nSamples - audio.count) - } - - // 2. Reflect-pad by nFFT/2 (matches np.pad(..., mode='reflect')) - let pad = nFFT / 2 // 200 - var padded = [Float](repeating: 0, count: nSamples + 2 * pad) - for i in 0.. [Float] { - let fMax: Float = Float(sampleRate) / 2 // 8000 Hz - - func hzToMel(_ f: Float) -> Float { 2595 * log10(1 + f / 700) } - func melToHz(_ m: Float) -> Float { 700 * (pow(10, m / 2595) - 1) } - - let melMin = hzToMel(0) - let melMax = hzToMel(fMax) - let nPts = nMelBins + 2 - let pts = (0.. Float in - melToHz(melMin + Float(i) / Float(nPts - 1) * (melMax - melMin)) - } - // FFT bin frequencies for n_fft = 400 - let fftFreqs = (0..= fL && f <= fC { - fb[m * nFreqs + k] = norm * (f - fL) / (fC - fL) - } else if f > fC && f <= fR { - fb[m * nFreqs + k] = norm * (fR - f) / (fR - fC) - } - } - } - return fb - } -} From ff98ee3a94fa4ca562d6b9987809cdf4fbae4caf Mon Sep 17 00:00:00 2001 From: Carina Peng Date: Mon, 22 Jun 2026 18:24:30 -0700 Subject: [PATCH 7/7] Conform w other modalities --- swift/Sources/CoreAISpeech/SpeechModel.swift | 6 +++--- swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/swift/Sources/CoreAISpeech/SpeechModel.swift b/swift/Sources/CoreAISpeech/SpeechModel.swift index a61d28e..b211e2c 100644 --- a/swift/Sources/CoreAISpeech/SpeechModel.swift +++ b/swift/Sources/CoreAISpeech/SpeechModel.swift @@ -26,15 +26,15 @@ public actor SpeechModel { /// Load a model from a bundle directory. /// /// - Parameters: - /// - bundleURL: Directory containing encoder.aimodel and decoder.aimodel. + /// - url: Directory containing encoder.aimodel and decoder.aimodel. /// - decoder: Decode strategy. Defaults to ``WhisperDecoder``. /// - melConfig: Mel spectrogram parameters. Defaults to ``MelConfig/whisper``. public init( - bundleURL: URL, + resourcesAt url: URL, decoder: any SpeechDecoder = WhisperDecoder(), melConfig: MelConfig = .whisper ) async throws { - self.bundle = try await SpeechBundle(at: bundleURL) + self.bundle = try await SpeechBundle(at: url) self.decoder = decoder self.melConfig = melConfig try await warmUp() diff --git a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift index 14318d9..557ed92 100644 --- a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift +++ b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift @@ -39,7 +39,7 @@ struct SpeechRunner: AsyncParsableCommand { func runBundle(bundleURL: URL, audioPath: String?) async throws { print("Format: split (encoder + decoder, KV cache)") - let model = try await SpeechModel(bundleURL: bundleURL) + let model = try await SpeechModel(resourcesAt: bundleURL) if let path = audioPath { let url = URL(fileURLWithPath: path)