diff --git a/Package.swift b/Package.swift index 3a23d4c..5748c1d 100644 --- a/Package.swift +++ b/Package.swift @@ -29,6 +29,10 @@ let package = Package( "CoreAIImageSegmenter" ] ), + .library( + name: "CoreAISpeech", + targets: ["CoreAISpeech"] + ), .library( name: "CoreAIObjectDetection", targets: [ @@ -85,6 +89,19 @@ let package = Package( ] ), + // Speech recognition library + .target( + name: "CoreAISpeech", + dependencies: [ + "CoreAIShared", + .product(name: "Transformers", package: "swift-transformers"), + ], + path: "swift/Sources/CoreAISpeech", + swiftSettings: [ + .enableUpcomingFeature("MemberImportVisibility") + ] + ), + // Diffusion Pipeline .target( name: "CoreAIDiffusionPipeline", @@ -155,6 +172,17 @@ let package = Package( .enableUpcomingFeature("MemberImportVisibility") ] ), + .executableTarget( + name: "speech-runner", + dependencies: [ + "CoreAISpeech", + .product(name: "ArgumentParser", package: "swift-argument-parser"), + ], + path: "swift/Sources/Tools/speech-runner", + swiftSettings: [ + .enableUpcomingFeature("MemberImportVisibility") + ] + ), // Public LLM Benchmark CLI (based on mlx-lm benchmark) .executableTarget( diff --git a/swift/Sources/CoreAISpeech/MelSpectrogram.swift b/swift/Sources/CoreAISpeech/MelSpectrogram.swift new file mode 100644 index 0000000..40d380c --- /dev/null +++ b/swift/Sources/CoreAISpeech/MelSpectrogram.swift @@ -0,0 +1,176 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import AVFoundation +import Accelerate +import CoreAIShared +import Foundation + +// MARK: - MelConfig + +/// Parameters for mel spectrogram computation. +public struct MelConfig: Sendable { + public let sampleRate: Double + public let nFFT: Int + public let hopLength: Int + public let nMelBins: Int + public let nFrames: Int + + public var nSamples: Int { Int(sampleRate) * (nFrames * hopLength / Int(sampleRate / 100)) } + + /// Whisper / Parakeet shared parameters. + public static let whisper = MelConfig( + sampleRate: 16_000, nFFT: 400, hopLength: 160, nMelBins: 128, nFrames: 3_000) +} + +// MARK: - MelSpectrogram + +/// Computes a mel spectrogram from an audio file or raw PCM samples. +public enum MelSpectrogram { + // MARK: Public API + + public static func fromFile(_ url: URL, config: MelConfig = .whisper) throws -> [Float] { + return fromPCM(try loadAndResample(url, targetSampleRate: config.sampleRate), config: config) + } + + public static func fromPCM(_ raw: [Float], config: MelConfig = .whisper) -> [Float] { + let nSamples = config.nFrames * config.hopLength + + var audio = raw + if audio.count > nSamples { + audio = Array(audio.prefix(nSamples)) + } else if audio.count < nSamples { + audio += [Float](repeating: 0, count: nSamples - audio.count) + } + + let pad = config.nFFT / 2 + var padded = [Float](repeating: 0, count: nSamples + 2 * pad) + for i in 0.. [Float] { + let file = try AVAudioFile(forReading: url) + let fmt = AVAudioFormat( + commonFormat: .pcmFormatFloat32, + sampleRate: targetSampleRate, channels: 1, interleaved: false)! + guard let conv = AVAudioConverter(from: file.processingFormat, to: fmt) else { + throw SpeechError.invalidAudio( + "Cannot resample \(file.processingFormat) to \(targetSampleRate) Hz mono") + } + let cap = AVAudioFrameCount( + ceil(Double(file.length) * targetSampleRate / file.processingFormat.sampleRate) + 1) + let out = AVAudioPCMBuffer(pcmFormat: fmt, frameCapacity: cap)! + var fed = false + var err: NSError? + conv.convert(to: out, error: &err) { _, status in + guard !fed else { + status.pointee = .endOfStream + return nil + } + fed = true + let buf = AVAudioPCMBuffer( + pcmFormat: file.processingFormat, + frameCapacity: AVAudioFrameCount(file.length))! + try? file.read(into: buf) + status.pointee = buf.frameLength > 0 ? .haveData : .endOfStream + return buf + } + if let e = err { throw SpeechError.invalidAudio(e.localizedDescription) } + return Array( + UnsafeBufferPointer( + start: out.floatChannelData![0], + count: Int(out.frameLength))) + } + + // MARK: Precomputed basis + + private static func hannWindow(size: Int) -> [Float] { + (0.. ([Float], [Float]) { + let nFreqs = nFFT / 2 + 1 + var cos = [Float](repeating: 0, count: nFreqs * nFFT) + var sin = [Float](repeating: 0, count: nFreqs * nFFT) + for k in 0.. [Float] { + let nFreqs = config.nFFT / 2 + 1 + let fMax = Float(config.sampleRate) / 2 + func h2m(_ f: Float) -> Float { 2595 * log10(1 + f / 700) } + func m2h(_ m: Float) -> Float { 700 * (pow(10, m / 2595) - 1) } + let pts = (0.. Float in + m2h(h2m(0) + Float(i) / Float(config.nMelBins + 1) * (h2m(fMax) - h2m(0))) + } + let fftFreqs = (0..= fL && f <= fC { + fb[m * nFreqs + k] = norm * (f - fL) / (fC - fL) + } else if f > fC && f <= fR { + fb[m * nFreqs + k] = norm * (fR - f) / (fR - fC) + } + } + } + return fb + } +} diff --git a/swift/Sources/CoreAISpeech/SpeechBundle.swift b/swift/Sources/CoreAISpeech/SpeechBundle.swift new file mode 100644 index 0000000..8cfa101 --- /dev/null +++ b/swift/Sources/CoreAISpeech/SpeechBundle.swift @@ -0,0 +1,125 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import AVFoundation +import CoreAI +import CoreAIShared +import Foundation +import Tokenizers + +// MARK: - SpeechBundle + +/// Locates and loads the assets inside a CoreAISpeech model bundle directory. +/// +/// A bundle directory contains: +/// encoder.aimodel — audio features → encoder hidden states +/// decoder.aimodel — autoregressive decoder with persistent state +/// generation_config.json (optional) — prefix, EOT token, etc. +/// +/// The tokenizer is loaded from the local HF cache if it can be found there. +public struct SpeechBundle: Sendable { + public let encoder: AIModel + public let decoder: AIModel + public let tokenizer: (any Tokenizer)? + public let generationConfig: GenerationConfig + + public init(at url: URL) async throws { + let encURL = url.appending(path: "encoder.aimodel") + let decURL = url.appending(path: "decoder.aimodel") + guard FileManager.default.fileExists(atPath: encURL.path), + FileManager.default.fileExists(atPath: decURL.path) + else { + throw SpeechError.missingModel( + "bundle at \(url.lastPathComponent) must contain encoder.aimodel and decoder.aimodel") + } + encoder = try await AIModel(contentsOf: encURL) + decoder = try await AIModel(contentsOf: decURL) + + // Load generation config from bundle if present, otherwise use Whisper defaults + let cfgURL = url.appending(path: "generation_config.json") + generationConfig = (try? GenerationConfig(from: cfgURL)) ?? .whisper + + // Tokenizer — look in bundle first, then fall back to HF cache + tokenizer = try? await Self.loadTokenizer(bundleURL: url, config: generationConfig) + } + + private static func loadTokenizer( + bundleURL: URL, config: GenerationConfig + ) async throws -> (any Tokenizer)? { + // 1. Try tokenizer files in the bundle itself + if FileManager.default.fileExists(atPath: bundleURL.appending(path: "tokenizer.json").path) { + return try? await AutoTokenizer.from(modelFolder: bundleURL) + } + // 2. Fall back to local HF cache using the model name from config + if let name = config.tokenizerName { + let cacheRoot = FileManager.default.homeDirectoryForCurrentUser + .appending(path: ".cache/huggingface/hub") + let folderName = "models--" + name.replacingOccurrences(of: "/", with: "--") + let snapshotsDir = cacheRoot.appending(path: "\(folderName)/snapshots") + if let snapshot = try? FileManager.default.contentsOfDirectory( + atPath: snapshotsDir.path + ).first { + return try? await AutoTokenizer.from( + modelFolder: snapshotsDir.appending(path: snapshot)) + } + } + return nil + } +} + +// MARK: - GenerationConfig + +/// Model-specific generation parameters, read from generation_config.json in the bundle. +public struct GenerationConfig: Sendable { + /// Tokens prepended to every decode sequence before free generation. + public let forcedPrefix: [Int32] + /// Token that signals end of transcription. + public let eotToken: Int32 + /// Maximum tokens to generate per call. + public let maxDecodeSteps: Int + /// HuggingFace model name for loading the tokenizer from cache. + public let tokenizerName: String? + + /// Whisper large-v3-turbo defaults. + public static let whisper = GenerationConfig( + forcedPrefix: [50258, 50259, 50360, 50364], // BOS <|en|> <|transcribe|> <|notimestamps|> + eotToken: 50257, + maxDecodeSteps: 50, + tokenizerName: "openai/whisper-large-v3-turbo" + ) + + init(forcedPrefix: [Int32], eotToken: Int32, maxDecodeSteps: Int, tokenizerName: String?) { + self.forcedPrefix = forcedPrefix + self.eotToken = eotToken + self.maxDecodeSteps = maxDecodeSteps + self.tokenizerName = tokenizerName + } + + init(from url: URL) throws { + let data = try Data(contentsOf: url) + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] ?? [:] + forcedPrefix = (json["forced_decoder_ids"] as? [Int]).map { $0.map { Int32($0) } } ?? Self.whisper.forcedPrefix + eotToken = (json["eos_token_id"] as? Int).map { Int32($0) } ?? Self.whisper.eotToken + maxDecodeSteps = (json["max_new_tokens"] as? Int) ?? Self.whisper.maxDecodeSteps + tokenizerName = json["tokenizer_name"] as? String ?? Self.whisper.tokenizerName + } +} + +// MARK: - SpeechError + +public enum SpeechError: Error, CustomStringConvertible { + case missingModel(String) + case missingTokenizer + case invalidAudio(String) + + public var description: String { + switch self { + case .missingModel(let msg): return "Missing model: \(msg)" + case .missingTokenizer: + return "Tokenizer not found — ensure the model bundle includes a tokenizer or the HF cache is populated" + case .invalidAudio(let msg): return "Invalid audio: \(msg)" + } + } +} diff --git a/swift/Sources/CoreAISpeech/SpeechDecoder.swift b/swift/Sources/CoreAISpeech/SpeechDecoder.swift new file mode 100644 index 0000000..f8a2056 --- /dev/null +++ b/swift/Sources/CoreAISpeech/SpeechDecoder.swift @@ -0,0 +1,97 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import CoreAI +import CoreAIShared +import Foundation + +// MARK: - SpeechDecoder protocol + +/// Model-specific decode logic. +public protocol SpeechDecoder: Sendable { + func decode( + encoderOutput: NDArray, + encoderOutputShape: [Int], + decoderModel: AIModel, + config: GenerationConfig + ) async throws -> [Int32] +} + +// MARK: - WhisperDecoder + +/// Greedy decoder for Whisper (encoder-decoder, cross-attention, KV cache). +public struct WhisperDecoder: SpeechDecoder { + public init() {} + + public func decode( + encoderOutput: NDArray, + encoderOutputShape: [Int], + decoderModel: AIModel, + config: GenerationConfig + ) async throws -> [Int32] { + guard let decFn = try decoderModel.loadFunction(named: "main") else { + throw SpeechError.missingModel("No 'main' function in decoder") + } + let decDesc = decoderModel.functionDescriptor(for: "main")! + + guard case .ndArray(let inputIdsNDDesc) = decDesc.inputDescriptor(of: "input_ids"), + case .ndArray(let posIdsNDDesc) = decDesc.inputDescriptor(of: "position_ids"), + case .ndArray(let encHSNDDesc) = decDesc.inputDescriptor(of: "encoder_hidden_states"), + case .ndArray(let keyCacheNDDesc) = decDesc.stateDescriptor(of: "keyCache"), + case .ndArray(let valCacheNDDesc) = decDesc.stateDescriptor(of: "valueCache"), + case .ndArray(let logitsNDDesc) = decDesc.outputDescriptor(of: "logits") + else { throw SpeechError.missingModel("Unexpected decoder descriptors") } + + let vocabSize = logitsNDDesc.shape.last! + let maxTargetPos = 448 + let kcShape = keyCacheNDDesc.shape.map { $0 < 0 ? maxTargetPos : $0 } + let vcShape = valCacheNDDesc.shape.map { $0 < 0 ? maxTargetPos : $0 } + var keyCache = NDArray(descriptor: keyCacheNDDesc.resolvingDynamicDimensions(kcShape)) + var valueCache = NDArray(descriptor: valCacheNDDesc.resolvingDynamicDimensions(vcShape)) + + var encHSArray = NDArray(descriptor: encHSNDDesc.resolvingDynamicDimensions(encoderOutputShape)) + let encFlat = readNDArray(encoderOutput, as: Float.self, count: encoderOutputShape.reduce(1, *)) + fillNDArray(&encHSArray, as: Float.self, with: encFlat) + + var logitsArray = NDArray(descriptor: logitsNDDesc.resolvingDynamicDimensions([1, 1, vocabSize])) + + func step(_ tok: Int32, pos: Int) async throws { + var ids = NDArray(descriptor: inputIdsNDDesc.resolvingDynamicDimensions([1, 1])) + var posIds = NDArray(descriptor: posIdsNDDesc.resolvingDynamicDimensions([1, pos + 1])) + fillNDArray(&ids, as: Int32.self, with: [tok]) + fillNDArray(&posIds, as: Int32.self, count: pos + 1) { Int32($0) } + var st = InferenceFunction.MutableViews() + st.insert(&keyCache, for: "keyCache") + st.insert(&valueCache, for: "valueCache") + var out = InferenceFunction.MutableViews() + out.insert(&logitsArray, for: "logits") + _ = try await decFn.run( + inputs: [ + "input_ids": ids, "position_ids": posIds, + "encoder_hidden_states": encHSArray, + ], + states: consume st, outputViews: consume out) + } + + // Prime KV cache with forced prefix + var tokens: [Int32] = config.forcedPrefix + for (i, tok) in config.forcedPrefix.enumerated() { + try await step(tok, pos: i) + } + + // Greedy decode + var pos = config.forcedPrefix.count + while tokens.count - config.forcedPrefix.count < config.maxDecodeSteps { + try await step(tokens.last!, pos: pos) + let logits = flattenAsFloat(logitsArray) + let next = Int32(logits.indices.max(by: { logits[$0] < logits[$1] })!) + tokens.append(next) + pos += 1 + if next == config.eotToken { break } + } + + return tokens + } +} diff --git a/swift/Sources/CoreAISpeech/SpeechModel.swift b/swift/Sources/CoreAISpeech/SpeechModel.swift new file mode 100644 index 0000000..b211e2c --- /dev/null +++ b/swift/Sources/CoreAISpeech/SpeechModel.swift @@ -0,0 +1,130 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import CoreAI +import CoreAIShared +import Foundation +import Tokenizers + +// MARK: - SpeechModel + +/// On-device speech recognition model. +/// +/// Loads a CoreAISpeech bundle (encoder.aimodel + decoder.aimodel) and transcribes +/// audio files. The decoder architecture is pluggable via ``SpeechDecoder`` +public actor SpeechModel { + private let bundle: SpeechBundle + private let decoder: any SpeechDecoder + private let melConfig: MelConfig + + // Encoder function and descriptor, cached after first load + private var encFn: InferenceFunction? + private var encOutShape: [Int]? + + /// Load a model from a bundle directory. + /// + /// - Parameters: + /// - url: Directory containing encoder.aimodel and decoder.aimodel. + /// - decoder: Decode strategy. Defaults to ``WhisperDecoder``. + /// - melConfig: Mel spectrogram parameters. Defaults to ``MelConfig/whisper``. + public init( + resourcesAt url: URL, + decoder: any SpeechDecoder = WhisperDecoder(), + melConfig: MelConfig = .whisper + ) async throws { + self.bundle = try await SpeechBundle(at: url) + self.decoder = decoder + self.melConfig = melConfig + try await warmUp() + } + + // MARK: - Transcription + + /// Transcribe an audio file, returning the full text. + public func transcribe(audioURL: URL) async throws -> String { + let tokens = try await decodeAudio(from: audioURL) + return try detokenize(tokens) + } + + /// Transcribe raw 16 kHz mono PCM samples. + public func transcribe(pcm: [Float]) async throws -> String { + let tokens = try await decodeAudio(pcm: pcm) + return try detokenize(tokens) + } + + // MARK: - Internals + + private func warmUp() async throws { + // Run the encoder once with silence to trigger JIT compilation + guard let fn = try bundle.encoder.loadFunction(named: "main") else { + throw SpeechError.missingModel("No 'main' function in encoder") + } + encFn = fn + let encDesc = bundle.encoder.functionDescriptor(for: "main")! + guard case .ndArray(let encOutNDDesc) = encDesc.outputDescriptor(of: "encoder_hidden_states") + else { throw SpeechError.missingModel("Unexpected encoder output descriptor") } + encOutShape = encOutNDDesc.shape + + guard case .ndArray(let melNDDesc) = encDesc.inputDescriptor(of: "input_features") + else { throw SpeechError.missingModel("Unexpected encoder input descriptor") } + + var silence = NDArray( + descriptor: melNDDesc.resolvingDynamicDimensions([1, melConfig.nMelBins, melConfig.nFrames])) + fillNDArray(&silence, as: Float.self, count: melConfig.nMelBins * melConfig.nFrames) { _ in 0.0 } + var encOut = NDArray(descriptor: encOutNDDesc.resolvingDynamicDimensions(encOutNDDesc.shape)) + var out = InferenceFunction.MutableViews() + out.insert(&encOut, for: "encoder_hidden_states") + _ = try await fn.run( + inputs: ["input_features": silence], + states: InferenceFunction.MutableViews(), outputViews: consume out) + } + + private func runEncoder(_ melArray: inout NDArray) async throws -> NDArray { + guard let fn = encFn, let shape = encOutShape else { + throw SpeechError.missingModel("Encoder not initialised") + } + let encDesc = bundle.encoder.functionDescriptor(for: "main")! + guard case .ndArray(let encOutNDDesc) = encDesc.outputDescriptor(of: "encoder_hidden_states") + else { throw SpeechError.missingModel("Unexpected encoder output") } + var encOut = NDArray(descriptor: encOutNDDesc.resolvingDynamicDimensions(shape)) + var out = InferenceFunction.MutableViews() + out.insert(&encOut, for: "encoder_hidden_states") + _ = try await fn.run( + inputs: ["input_features": melArray], + states: InferenceFunction.MutableViews(), outputViews: consume out) + return encOut + } + + private func decodeAudio(from url: URL) async throws -> [Int32] { + let pcm = try MelSpectrogram.loadAndResample(url, targetSampleRate: melConfig.sampleRate) + return try await decodeAudio(pcm: pcm) + } + + private func decodeAudio(pcm: [Float]) async throws -> [Int32] { + let encDesc = bundle.encoder.functionDescriptor(for: "main")! + guard case .ndArray(let melNDDesc) = encDesc.inputDescriptor(of: "input_features") + else { throw SpeechError.missingModel("Unexpected encoder input") } + + let floats = MelSpectrogram.fromPCM(pcm, config: melConfig) + var melArray = NDArray( + descriptor: melNDDesc.resolvingDynamicDimensions( + [1, melConfig.nMelBins, melConfig.nFrames])) + fillNDArray(&melArray, as: Float.self, with: floats) + + let encoderOutput = try await runEncoder(&melArray) + let shape = encOutShape ?? [1, 1500, 1280] + return try await decoder.decode( + encoderOutput: encoderOutput, + encoderOutputShape: shape, + decoderModel: bundle.decoder, + config: bundle.generationConfig) + } + + private func detokenize(_ tokens: [Int32]) throws -> String { + guard let tokenizer = bundle.tokenizer else { throw SpeechError.missingTokenizer } + let ids = tokens.filter { $0 < bundle.generationConfig.eotToken }.map { Int($0) } + return tokenizer.decode(tokens: ids).trimmingCharacters(in: .whitespaces) + } +} diff --git a/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift new file mode 100644 index 0000000..557ed92 --- /dev/null +++ b/swift/Sources/Tools/speech-runner/SpeechRunnerMain.swift @@ -0,0 +1,164 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import ArgumentParser +import CoreAI +import CoreAIShared +import CoreAISpeech +import Foundation +import Tokenizers + +// MARK: - Entry point + +@main +struct SpeechRunner: AsyncParsableCommand { + static let configuration = CommandConfiguration( + commandName: "speech-runner", + abstract: "Transcribe audio using a CoreAI speech model bundle" + ) + + @Argument(help: "Bundle dir (encoder.aimodel + decoder.aimodel) or single .aimodel (legacy)") + var modelPath: String + + @Argument(help: "Audio file (wav, flac, m4a, …). Omit for latency benchmarking with silence.") + var audioPath: String? + + func run() async throws { + let bundleURL = URL(fileURLWithPath: modelPath) + if FileManager.default.fileExists(atPath: bundleURL.appending(path: "encoder.aimodel").path) { + try await runBundle(bundleURL: bundleURL, audioPath: audioPath) + } else { + try await runLegacy(modelPath: modelPath, audioPath: audioPath) + } + } +} + +// MARK: - Split bundle via CoreAISpeech + +func runBundle(bundleURL: URL, audioPath: String?) async throws { + print("Format: split (encoder + decoder, KV cache)") + let model = try await SpeechModel(resourcesAt: bundleURL) + + if let path = audioPath { + let url = URL(fileURLWithPath: path) + print("Transcribing \(url.lastPathComponent)…") + let t0 = ContinuousClock.now + let text = try await model.transcribe(audioURL: url) + let ms = (ContinuousClock.now - t0).inMilliseconds + print(String(format: " %.1f ms total", ms)) + print("\n── Transcription ──────────────────────────────────────────────────────") + print(" \(text)") + } else { + print("No audio — silence benchmark") + let pcm = [Float](repeating: 0, count: 480_000) + let t0 = ContinuousClock.now + _ = try await model.transcribe(pcm: pcm) + print(String(format: " %.1f ms (silence)", (ContinuousClock.now - t0).inMilliseconds)) + } +} + +// MARK: - Legacy monolithic model + +func runLegacy(modelPath: String, audioPath: String?) async throws { + print("Format: legacy (monolithic, no KV cache)") + + let model = try await AIModel(contentsOf: URL(fileURLWithPath: modelPath)) + guard let fn = try model.loadFunction(named: "main") + else { throw RuntimeError("No 'main' function in model") } + let desc = model.functionDescriptor(for: "main")! + + guard case .ndArray(let melNDDesc) = desc.inputDescriptor(of: "input_features"), + case .ndArray(let idsNDDesc) = desc.inputDescriptor(of: "decoder_input_ids"), + case .ndArray(let logitsDesc) = desc.outputDescriptor(of: "logits") + else { throw RuntimeError("Unexpected model descriptors") } + + let vocabSize = logitsDesc.shape.last! + let isStaticIds = !idsNDDesc.shape.contains(where: { $0 < 0 }) + if isStaticIds { + print(" ⚠️ decoder_input_ids has static shape — no past context per step") + } + + var melArray: NDArray + if let path = audioPath { + let pcm = try MelSpectrogram.loadAndResample( + URL(fileURLWithPath: path), targetSampleRate: 16_000) + let floats = MelSpectrogram.fromPCM(pcm) + melArray = NDArray(descriptor: melNDDesc.resolvingDynamicDimensions([1, 128, 3000])) + fillNDArray(&melArray, as: Float.self, with: floats) + } else { + melArray = NDArray(descriptor: melNDDesc.resolvingDynamicDimensions([1, 128, 3000])) + fillNDArray(&melArray, as: Float.self, count: 128 * 3000) { _ in 0.0 } + } + + // Warmup + do { + var ids = NDArray(descriptor: idsNDDesc.resolvingDynamicDimensions([1, 1])) + fillNDArray(&ids, as: Int32.self, with: [50258]) + var lw = NDArray(descriptor: logitsDesc.resolvingDynamicDimensions([1, 1, vocabSize])) + var out = InferenceFunction.MutableViews() + out.insert(&lw, for: "logits") + _ = try await fn.run( + inputs: ["input_features": melArray, "decoder_input_ids": ids], + states: InferenceFunction.MutableViews(), outputViews: consume out) + } + + let config = GenerationConfig.whisper + var tokens: [Int32] = config.forcedPrefix + var stepTimesMs: [Double] = [] + + print("\n── Decode ─────────────────────────────────────────────────────────────") + + while stepTimesMs.count < config.maxDecodeSteps { + let inputTokens: [Int32] = isStaticIds ? [tokens.last!] : tokens + let seqLen = inputTokens.count + var ids = NDArray(descriptor: idsNDDesc.resolvingDynamicDimensions([1, seqLen])) + fillNDArray(&ids, as: Int32.self, with: inputTokens) + var la = NDArray(descriptor: logitsDesc.resolvingDynamicDimensions([1, seqLen, vocabSize])) + var out = InferenceFunction.MutableViews() + out.insert(&la, for: "logits") + let t0 = ContinuousClock.now + _ = try await fn.run( + inputs: ["input_features": melArray, "decoder_input_ids": ids], + states: InferenceFunction.MutableViews(), outputViews: consume out) + stepTimesMs.append((ContinuousClock.now - t0).inMilliseconds) + let logits = flattenAsFloat(la) + let base = (seqLen - 1) * vocabSize + let next = Int32( + (0..