diff --git a/Scripts/apply-mlx-patches.sh b/Scripts/apply-mlx-patches.sh index 9aed1155..36c735b0 100755 --- a/Scripts/apply-mlx-patches.sh +++ b/Scripts/apply-mlx-patches.sh @@ -19,9 +19,9 @@ log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } log_error() { echo -e "${RED}[ERROR]${NC} $1"; } -PATCH_FILES=("Qwen3VL.swift" "Qwen3Next.swift" "GatedDelta.swift" "Qwen3_5MoE.swift" "DeepseekV3.swift" "MiniMaxM2.swift" "NemotronH.swift" "GLM4MoeLite.swift" "GLM5MoeDsa.swift" "KimiK25.swift" "Gemma4Text.swift" "Gemma4VLM.swift" "LLMModelFactory.swift" "Load.swift" "Evaluate.swift" "LanguageModel.swift" "Tokenizer.swift" "Qwen3_5MoEVL.swift" "VLMModelFactory.swift" "SamplerTests.swift" "ToolCallFormat.swift" "KVCache.swift" "SwitchLayers.swift" "BatchKVCache.swift" "SSM.swift" "Chat.swift" "Gemma4FunctionParser.swift") -TARGET_PATHS=("Libraries/MLXVLM/Models/Qwen3VL.swift" "Libraries/MLXLLM/Models/Qwen3Next.swift" "Libraries/MLXLLM/Models/GatedDelta.swift" "Libraries/MLXLLM/Models/Qwen3_5MoE.swift" "Libraries/MLXLLM/Models/DeepseekV3.swift" "Libraries/MLXLLM/Models/MiniMaxM2.swift" "Libraries/MLXLLM/Models/NemotronH.swift" "Libraries/MLXLLM/Models/GLM4MoeLite.swift" "Libraries/MLXLLM/Models/GLM5MoeDsa.swift" "Libraries/MLXLLM/Models/KimiK25.swift" "Libraries/MLXLLM/Models/Gemma4Text.swift" "Libraries/MLXVLM/Models/Gemma4VLM.swift" "Libraries/MLXLLM/LLMModelFactory.swift" "Libraries/MLXLMCommon/Load.swift" "Libraries/MLXLMCommon/Evaluate.swift" "Libraries/MLXLMCommon/LanguageModel.swift" "Libraries/MLXLMCommon/Tokenizer.swift" "Libraries/MLXVLM/Models/Qwen3_5MoEVL.swift" "Libraries/MLXVLM/VLMModelFactory.swift" "Tests/MLXLMTests/SamplerTests.swift" "Libraries/MLXLMCommon/Tool/ToolCallFormat.swift" "Libraries/MLXLMCommon/KVCache.swift" "Libraries/MLXLMCommon/SwitchLayers.swift" "Libraries/MLXLMCommon/BatchKVCache.swift" "Libraries/MLXLLM/Models/SSM.swift" "Libraries/MLXLMCommon/Chat.swift" "Libraries/MLXLMCommon/Tool/Parsers/Gemma4FunctionParser.swift") -NEW_FILES=("Qwen3Next.swift" "GatedDelta.swift" "Qwen3_5MoE.swift" "MiniMaxM2.swift" "NemotronH.swift" "GLM4MoeLite.swift" "GLM5MoeDsa.swift" "KimiK25.swift" "Gemma4Text.swift" "Gemma4VLM.swift" "Qwen3_5MoEVL.swift" "SamplerTests.swift" "BatchKVCache.swift" "Gemma4FunctionParser.swift") +PATCH_FILES=("Qwen3VL.swift" "Qwen3Next.swift" "GatedDelta.swift" "Qwen3_5MoE.swift" "DeepseekV3.swift" "MiniMaxM2.swift" "NemotronH.swift" "GLM4MoeLite.swift" "GLM5MoeDsa.swift" "KimiK25.swift" "Gemma4Text.swift" "Gemma4VLM.swift" "LLMModelFactory.swift" "Load.swift" "Evaluate.swift" "LanguageModel.swift" "Tokenizer.swift" "AttentionUtils.swift" "Qwen3_5MoEVL.swift" "VLMModelFactory.swift" "SamplerTests.swift" "ToolCallFormat.swift" "KVCache.swift" "SwitchLayers.swift" "BatchKVCache.swift" "SSM.swift" "Chat.swift" "Gemma4FunctionParser.swift") +TARGET_PATHS=("Libraries/MLXVLM/Models/Qwen3VL.swift" "Libraries/MLXLLM/Models/Qwen3Next.swift" "Libraries/MLXLLM/Models/GatedDelta.swift" "Libraries/MLXLLM/Models/Qwen3_5MoE.swift" "Libraries/MLXLLM/Models/DeepseekV3.swift" "Libraries/MLXLLM/Models/MiniMaxM2.swift" "Libraries/MLXLLM/Models/NemotronH.swift" "Libraries/MLXLLM/Models/GLM4MoeLite.swift" "Libraries/MLXLLM/Models/GLM5MoeDsa.swift" "Libraries/MLXLLM/Models/KimiK25.swift" "Libraries/MLXLLM/Models/Gemma4Text.swift" "Libraries/MLXVLM/Models/Gemma4VLM.swift" "Libraries/MLXLLM/LLMModelFactory.swift" "Libraries/MLXLMCommon/Load.swift" "Libraries/MLXLMCommon/Evaluate.swift" "Libraries/MLXLMCommon/LanguageModel.swift" "Libraries/MLXLMCommon/Tokenizer.swift" "Libraries/MLXLMCommon/AttentionUtils.swift" "Libraries/MLXVLM/Models/Qwen3_5MoEVL.swift" "Libraries/MLXVLM/VLMModelFactory.swift" "Tests/MLXLMTests/SamplerTests.swift" "Libraries/MLXLMCommon/Tool/ToolCallFormat.swift" "Libraries/MLXLMCommon/KVCache.swift" "Libraries/MLXLMCommon/SwitchLayers.swift" "Libraries/MLXLMCommon/BatchKVCache.swift" "Libraries/MLXLLM/Models/SSM.swift" "Libraries/MLXLMCommon/Chat.swift" "Libraries/MLXLMCommon/Tool/Parsers/Gemma4FunctionParser.swift") +NEW_FILES=("Qwen3Next.swift" "GatedDelta.swift" "Qwen3_5MoE.swift" "MiniMaxM2.swift" "NemotronH.swift" "GLM4MoeLite.swift" "GLM5MoeDsa.swift" "KimiK25.swift" "Gemma4Text.swift" "Gemma4VLM.swift" "AttentionUtils.swift" "Qwen3_5MoEVL.swift" "SamplerTests.swift" "BatchKVCache.swift" "Gemma4FunctionParser.swift") # --- Package-level patches (sed replacements in Package.swift) --- # Each entry: "search_pattern|replacement" diff --git a/Scripts/patches/AttentionUtils.swift b/Scripts/patches/AttentionUtils.swift new file mode 100644 index 00000000..561301aa --- /dev/null +++ b/Scripts/patches/AttentionUtils.swift @@ -0,0 +1,95 @@ +import Foundation +import MLX + +/// Attention utilities that match Python mlx-lm's interface +/// +/// This provides a single function that automatically routes to TurboQuant, quantized, or regular +/// attention based on cache type, matching Python's `scaled_dot_product_attention` + +/// Automatic attention with cache update +/// +/// This function matches Python's `scaled_dot_product_attention` in base.py: +/// - Detects if cache is `TurboQuantKVCache` or `QuantizedKVCache` using `isinstance` pattern +/// - Routes to TurboQuant, quantized, or `MLXFast.scaledDotProductAttention` +/// - Handles cache updating automatically +/// - Transparent to models - they just call this function +/// +/// **Usage in models:** +/// ```swift +/// let output = attentionWithCacheUpdate( +/// queries: queries, +/// keys: keys, +/// values: values, +/// cache: cache, +/// scale: scale, +/// mask: mask +/// ) +/// ``` +/// +/// - Parameters: +/// - queries: Query tensor [B, nHeads, L, D] +/// - keys: Raw key tensor to be cached [B, nKVHeads, L, D] +/// - values: Raw value tensor to be cached [B, nKVHeads, L, D] +/// - cache: Cache instance (any type) +/// - scale: Attention scale factor +/// - mask: Attention mask +/// - Returns: Attention output [B, nHeads, L, D] +public func attentionWithCacheUpdate( + queries: MLXArray, + keys: MLXArray, + values: MLXArray, + cache: KVCache?, + scale: Float, + mask: MLXFast.ScaledDotProductAttentionMaskMode = .none +) -> MLXArray { + guard let cache else { + return MLXFast.scaledDotProductAttention( + queries: queries, + keys: keys, + values: values, + scale: scale, + mask: mask + ) + } + if let turboQuantKVCache = cache as? TurboQuantKVCacheProtocol { + if queries.dim(2) == 1 { + return turboQuantKVCache.decodeAttention( + queries: queries, + keys: keys, + values: values, + scale: scale, + mask: mask + ) + } else { + return turboQuantKVCache.prefillAttention( + queries: queries, + keys: keys, + values: values, + scale: scale, + mask: mask + ) + } + } else if let quantizedKVCache = cache as? QuantizedKVCacheProtocol { + let (quantizedKeys, quantizedValues) = quantizedKVCache.updateQuantized( + keys: keys, values: values) + return quantizedScaledDotProductAttention( + queries: queries, + quantizedKeys: quantizedKeys, + quantizedValues: quantizedValues, + scale: scale, + mask: mask, + groupSize: quantizedKVCache.groupSize, + bits: quantizedKVCache.bits, + mode: quantizedKVCache.mode + ) + } else { + let (cachedKeys, cachedValues) = cache.update(keys: keys, values: values) + return MLXFast.scaledDotProductAttention( + queries: queries, + keys: cachedKeys, + values: cachedValues, + scale: scale, + mask: mask + ) + } +} diff --git a/Scripts/patches/KVCache.swift b/Scripts/patches/KVCache.swift index 8ab8a68c..71f8fd9e 100644 --- a/Scripts/patches/KVCache.swift +++ b/Scripts/patches/KVCache.swift @@ -330,6 +330,20 @@ public struct TurboQuantMetadataArtifact: Codable, Equatable, Sendable { public protocol TurboQuantKVCacheProtocol: KVCache { var configuration: TurboQuantConfiguration { get } + func decodeAttention( + queries: MLXArray, + keys: MLXArray, + values: MLXArray, + scale: Float, + mask: MLXFast.ScaledDotProductAttentionMaskMode + ) -> MLXArray + func prefillAttention( + queries: MLXArray, + keys: MLXArray, + values: MLXArray, + scale: Float, + mask: MLXFast.ScaledDotProductAttentionMaskMode + ) -> MLXArray } public func loadTurboQuantMetadata(url: URL) throws -> TurboQuantMetadataArtifact { @@ -344,6 +358,436 @@ public func saveTurboQuantMetadata(_ metadata: TurboQuantMetadataArtifact, url: try data.write(to: url, options: .atomic) } +private let turboQuantDefaultSeed = 0 +private let turboQuantEpsilon: Float = 1e-6 + +private struct TurboQuantSplitMix64: RandomNumberGenerator { + private var state: UInt64 + + init(seed: UInt64) { + self.state = seed &+ 0x9E37_79B9_7F4A_7C15 + } + + mutating func next() -> UInt64 { + state &+= 0x9E37_79B9_7F4A_7C15 + var z = state + z = (z ^ (z >> 30)) &* 0xBF58_476D_1CE4_E5B9 + z = (z ^ (z >> 27)) &* 0x94D0_49BB_1331_11EB + return z ^ (z >> 31) + } +} + +private func validateTurboQuantBits(_ bits: Float) -> Float { + precondition(bits >= 1, "TurboQuant requires kvBits >= 1.") + let rounded = (bits * 2).rounded() / 2 + precondition( + abs(bits - rounded) < 0.000_1, + "TurboQuant currently supports integer and .5 bit-widths." + ) + return rounded +} + +private func turboQuantKeyBits(for bits: Float) -> Int { + Int(floor(validateTurboQuantBits(bits))) +} + +private func turboQuantValueBits(for bits: Float) -> Int { + Int(ceil(validateTurboQuantBits(bits))) +} + +private func turboQuantPackedWidth(length: Int, bits: Int) -> Int { + guard length > 0, bits > 0 else { return 0 } + return (length * bits + 31) / 32 +} + +private func turboQuantIsPowerOfTwo(_ value: Int) -> Bool { + value > 0 && (value & (value - 1)) == 0 +} + +private enum TurboQuantTableCache { + static let lock = NSLock() + nonisolated(unsafe) static var signVectors: [String: MLXArray] = [:] + nonisolated(unsafe) static var codebooks: [String: MLXArray] = [:] + nonisolated(unsafe) static var midpoints: [String: [Float]] = [:] +} + +private func turboQuantSignVector(dim: Int, seed: Int) -> MLXArray { + let cacheKey = "\(dim):\(seed)" + TurboQuantTableCache.lock.lock() + if let cached = TurboQuantTableCache.signVectors[cacheKey] { + TurboQuantTableCache.lock.unlock() + return cached + } + TurboQuantTableCache.lock.unlock() + + var generator = TurboQuantSplitMix64(seed: UInt64(seed &+ dim &* 7_919)) + let values = (0 ..< dim).map { _ in ((generator.next() & 1) == 0) ? Float(-1) : Float(1) } + let signs = MLXArray(values) + + TurboQuantTableCache.lock.lock() + TurboQuantTableCache.signVectors[cacheKey] = signs + TurboQuantTableCache.lock.unlock() + return signs +} + +private func turboQuantBetaPDF(grid: [Double], dim: Int) -> [Double] { + guard dim > 1 else { + let weight = 1.0 / Double(max(grid.count, 1)) + return Array(repeating: weight, count: grid.count) + } + + let logCoeff = + Foundation.lgamma(Double(dim) / 2) + - 0.5 * Foundation.log(Double.pi) + - Foundation.lgamma(Double(dim - 1) / 2) + + var logPDF: [Double] = [] + logPDF.reserveCapacity(grid.count) + var maxLog = -Double.infinity + for value in grid { + let inner = max(1.0 - value * value, 1e-30) + let next = logCoeff + (Double(dim - 3) / 2) * Foundation.log(inner) + logPDF.append(next) + maxLog = max(maxLog, next) + } + + var weights: [Double] = [] + weights.reserveCapacity(grid.count) + var total = 0.0 + for value in logPDF { + let weight = Foundation.exp(value - maxLog) + weights.append(weight) + total += weight + } + + guard total > 0 else { + let weight = 1.0 / Double(max(grid.count, 1)) + return Array(repeating: weight, count: grid.count) + } + + return weights.map { $0 / total } +} + +private func turboQuantInterpolate(grid: [Double], cdf: [Double], quantile: Double) -> Double { + guard !grid.isEmpty else { return 0 } + if quantile <= cdf[0] { return grid[0] } + if quantile >= cdf[cdf.count - 1] { return grid[grid.count - 1] } + + var low = 0 + var high = cdf.count - 1 + while low < high { + let mid = (low + high) / 2 + if cdf[mid] < quantile { + low = mid + 1 + } else { + high = mid + } + } + + let upper = low + let lower = max(upper - 1, 0) + let lowerCDF = cdf[lower] + let upperCDF = cdf[upper] + let t = + if abs(upperCDF - lowerCDF) < 1e-12 { + 0.0 + } else { + (quantile - lowerCDF) / (upperCDF - lowerCDF) + } + return grid[lower] + (grid[upper] - grid[lower]) * t +} + +private func turboQuantBuildCodebook(dim: Int, bits: Int) -> [Float] { + guard bits > 0 else { return [] } + let levels = 1 << bits + + if dim <= 1 { + guard levels > 1 else { return [0] } + return (0 ..< levels).map { index in + -1 + 2 * Float(index) / Float(levels - 1) + } + } + + let gridCount = 4_096 + let lower = -1.0 + 1e-6 + let upper = 1.0 - 1e-6 + let step = (upper - lower) / Double(gridCount - 1) + let grid = (0 ..< gridCount).map { lower + Double($0) * step } + let weights = turboQuantBetaPDF(grid: grid, dim: dim) + + var cdf: [Double] = [] + cdf.reserveCapacity(weights.count) + var running = 0.0 + for weight in weights { + running += weight + cdf.append(running) + } + + var centroids = (0 ..< levels).map { index in + turboQuantInterpolate( + grid: grid, + cdf: cdf, + quantile: (Double(index) + 0.5) / Double(levels) + ) + } + + for _ in 0 ..< 32 { + var boundaries = Array(repeating: 0.0, count: levels + 1) + boundaries[0] = -1.0 + boundaries[levels] = 1.0 + if levels > 1 { + for index in 1 ..< levels { + boundaries[index] = 0.5 * (centroids[index - 1] + centroids[index]) + } + } + + var next = centroids + for level in 0 ..< levels { + let start = boundaries[level] + let end = boundaries[level + 1] + var weightedSum = 0.0 + var totalWeight = 0.0 + + for gridIndex in 0 ..< grid.count { + let value = grid[gridIndex] + let inBucket = + if level == levels - 1 { + value >= start && value <= end + } else { + value >= start && value < end + } + guard inBucket else { continue } + let weight = weights[gridIndex] + weightedSum += weight * value + totalWeight += weight + } + + if totalWeight > 0 { + next[level] = weightedSum / totalWeight + } + } + + let delta = zip(next, centroids).map { abs($0 - $1) }.max() ?? 0 + centroids = next + if delta < 1e-6 { + break + } + } + + return centroids.map(Float.init) +} + +private func turboQuantCodebook(dim: Int, bits: Int) -> (MLXArray, [Float]) { + let cacheKey = "\(dim):\(bits)" + TurboQuantTableCache.lock.lock() + if let codebook = TurboQuantTableCache.codebooks[cacheKey], + let midpoints = TurboQuantTableCache.midpoints[cacheKey] + { + TurboQuantTableCache.lock.unlock() + return (codebook, midpoints) + } + TurboQuantTableCache.lock.unlock() + + let codebookValues = turboQuantBuildCodebook(dim: dim, bits: bits) + let codebook = MLXArray(codebookValues) + let midpoints: [Float] = + if codebookValues.count > 1 { + zip(codebookValues, codebookValues.dropFirst()).map { ($0 + $1) * 0.5 } + } else { + [] + } + + TurboQuantTableCache.lock.lock() + TurboQuantTableCache.codebooks[cacheKey] = codebook + TurboQuantTableCache.midpoints[cacheKey] = midpoints + TurboQuantTableCache.lock.unlock() + return (codebook, midpoints) +} + +private struct TurboQuantMSEState { + var norms: MLXArray + var indices: MLXArray +} + +private func turboQuantStateLength(_ state: TurboQuantMSEState?) -> Int { + state?.norms.dim(2) ?? 0 +} + +private func turboQuantSliceState(_ state: TurboQuantMSEState, end: Int) -> TurboQuantMSEState { + TurboQuantMSEState( + norms: state.norms[.ellipsis, .. TurboQuantMSEState { + TurboQuantMSEState( + norms: MLXArray.zeros([state.norms.dim(0), state.norms.dim(1), length], dtype: state.norms.dtype), + indices: MLXArray.zeros( + [state.indices.dim(0), state.indices.dim(1), length, state.indices.dim(3)], + dtype: state.indices.dtype) + ) +} + +private func turboQuantWriteState( + _ destination: inout TurboQuantMSEState, + _ source: TurboQuantMSEState, + start: Int +) { + let end = start + source.norms.dim(2) + destination.norms[.ellipsis, start ..< end] = source.norms + destination.indices[.ellipsis, start ..< end, 0...] = source.indices +} + +private func turboQuantReserveStateCapacity( + _ state: TurboQuantMSEState, + used: Int, + needed: Int, + step: Int +) -> TurboQuantMSEState { + let capacity = turboQuantStateLength(state) + guard capacity < needed else { return state } + + let newCapacity = ((needed + step - 1) / step) * step + var grown = turboQuantAllocateStateLike(state, length: newCapacity) + if used > 0 { + turboQuantWriteState(&grown, turboQuantSliceState(state, end: used), start: 0) + } + return grown +} + +private func turboQuantPackLowBit(_ values: MLXArray, bits: Int) -> MLXArray { + guard bits > 0 else { + var outputShape = Array(values.shape.dropLast()) + outputShape.append(0) + return MLXArray.zeros(outputShape, dtype: .uint32) + } + + let values = values.asType(.uint32) + let length = values.dim(-1) + let packedWidth = turboQuantPackedWidth(length: length, bits: bits) + let flat = values.reshaped(-1, length) + let packed = MLXArray.zeros([flat.dim(0), packedWidth], dtype: .uint32) + let mask = MLXArray(UInt32((1 << bits) - 1)) + + for index in 0 ..< length { + let bitOffset = index * bits + let wordIndex = bitOffset / 32 + let offset = bitOffset % 32 + let value = flat[.ellipsis, index].asType(.uint32) & mask + packed[.ellipsis, wordIndex] = packed[.ellipsis, wordIndex] | (value << offset) + + let spill = offset + bits - 32 + if spill > 0 && wordIndex + 1 < packedWidth { + packed[.ellipsis, wordIndex + 1] = + packed[.ellipsis, wordIndex + 1] | (value >> (bits - spill)) + } + } + + var outputShape = Array(values.shape.dropLast()) + outputShape.append(packedWidth) + return packed.reshaped(outputShape) +} + +private func turboQuantUnpackLowBit(_ packed: MLXArray, bits: Int, length: Int) -> MLXArray { + guard bits > 0 else { + var outputShape = Array(packed.shape.dropLast()) + outputShape.append(0) + return MLXArray.zeros(outputShape, dtype: .uint32) + } + + let packed = packed.asType(.uint32) + let flat = packed.reshaped(-1, packed.dim(-1)) + let unpacked = MLXArray.zeros([flat.dim(0), length], dtype: .uint32) + let mask = MLXArray(UInt32((1 << bits) - 1)) + + for index in 0 ..< length { + let bitOffset = index * bits + let wordIndex = bitOffset / 32 + let offset = bitOffset % 32 + var value = flat[.ellipsis, wordIndex] >> offset + let spill = offset + bits - 32 + if spill > 0 && wordIndex + 1 < flat.dim(1) { + value = value | (flat[.ellipsis, wordIndex + 1] << (bits - spill)) + } + unpacked[.ellipsis, index] = value & mask + } + + var outputShape = Array(packed.shape.dropLast()) + outputShape.append(length) + return unpacked.reshaped(outputShape) +} + +private struct TurboQuantMSECodec { + let dim: Int + let bits: Int + let useRHT: Bool + let signs: MLXArray? + let codebook: MLXArray + let midpoints: [Float] + + init(dim: Int, bits: Int, seed: Int) { + self.dim = dim + self.bits = bits + self.useRHT = turboQuantIsPowerOfTwo(dim) + self.signs = useRHT ? turboQuantSignVector(dim: dim, seed: seed) : nil + (self.codebook, self.midpoints) = turboQuantCodebook(dim: dim, bits: bits) + } + + private func rotateForward(_ array: MLXArray) -> MLXArray { + guard useRHT, let signs else { return array } + let scale = 1.0 / Float(Foundation.sqrt(Double(dim))) + return hadamardTransform(array * signs, scale: scale) + } + + private func rotateInverse(_ array: MLXArray) -> MLXArray { + guard useRHT, let signs else { return array } + let scale = 1.0 / Float(Foundation.sqrt(Double(dim))) + return hadamardTransform(array, scale: scale) * signs + } + + func quantize(_ vectors: MLXArray) -> TurboQuantMSEState { + let vectorsF32 = vectors.asType(.float32) + let norms = sqrt(sum(square(vectorsF32), axis: -1)) + + guard dim > 0, bits > 0 else { + return TurboQuantMSEState( + norms: norms.asType(.float16), + indices: MLXArray.zeros( + [vectors.dim(0), vectors.dim(1), vectors.dim(2), 0], + dtype: .uint32) + ) + } + + let safeNorms = maximum(norms, turboQuantEpsilon)[.ellipsis, .newAxis] + let unitVectors = vectorsF32 / safeNorms + let rotated = rotateForward(unitVectors) + + var indices = MLXArray.zeros(rotated.shape, dtype: .uint32) + for midpoint in midpoints { + indices = indices + (rotated .> midpoint).asType(.uint32) + } + + return TurboQuantMSEState( + norms: norms.asType(.float16), + indices: turboQuantPackLowBit(indices, bits: bits) + ) + } + + func dequantize(_ state: TurboQuantMSEState) -> MLXArray { + guard dim > 0, bits > 0 else { + return MLXArray.zeros( + [state.norms.dim(0), state.norms.dim(1), state.norms.dim(2), dim], + dtype: .float32) + } + + let unpacked = turboQuantUnpackLowBit(state.indices, bits: bits, length: dim).asType(.int32) + let rotated = take(codebook, unpacked, axis: 0) + let unitVectors = rotateInverse(rotated) + return state.norms[.ellipsis, .newAxis].asType(.float32) * unitVectors + } +} + /// Base cache implementation providing default behaviors open class BaseKVCache: KVCache { public var offset: Int = 0 @@ -1033,13 +1477,21 @@ public class RotatingKVCache: BaseKVCache, CustomDebugStringConvertible { } } -/// TurboQuant cache scaffold. +/// TurboQuant cache with packed MSE-style state and dense shadow buffers. /// -/// This preserves a distinct cache identity and metadata surface for TurboQuant while -/// keeping dense storage until the accelerated attention/update kernels land. +/// The packed norms/indices arrays are the serialized source of truth. Dense key/value +/// tensors are maintained only as runtime shadow state for the current fallback attention +/// path until the fused TurboQuant kernels land. public class TurboQuantKVCache: BaseKVCache, TurboQuantKVCacheProtocol, CustomDebugStringConvertible { - internal var keys: MLXArray? - internal var values: MLXArray? + private var keyState: TurboQuantMSEState? + private var valueState: TurboQuantMSEState? + private var shadowKeys: MLXArray? + private var shadowValues: MLXArray? + private var legacyDenseState: (keys: MLXArray, values: MLXArray)? + private var keyCodec: TurboQuantMSECodec? + private var valueCodec: TurboQuantMSECodec? + private var keyDimension: Int? + private var valueDimension: Int? public var configuration: TurboQuantConfiguration public var step = 256 public var didGrow = false @@ -1057,27 +1509,51 @@ public class TurboQuantKVCache: BaseKVCache, TurboQuantKVCacheProtocol, CustomDe return cache } let turboCache = TurboQuantKVCache(configuration: configuration) - if cache.offset > 0 { - turboCache.state = cache.state + if cache.offset > 0, cache.state.count == 2 { + let denseState = cache.state + turboCache.rebuildFromDenseState(keys: denseState[0], values: denseState[1]) } turboCache.offset = cache.offset return turboCache } public override func innerState() -> [MLXArray] { - [self.keys, self.values].compactMap { $0 } + [ + keyState?.norms, keyState?.indices, + valueState?.norms, valueState?.indices, + shadowKeys, shadowValues, + ].compactMap { $0 } + } + + private func ensureCodecs(keyDim: Int, valueDim: Int) { + let normalizedBits = validateTurboQuantBits(configuration.bits) + let keyBits = turboQuantKeyBits(for: normalizedBits) + let valueBits = turboQuantValueBits(for: normalizedBits) + + if keyCodec?.dim != keyDim || keyCodec?.bits != keyBits { + keyCodec = TurboQuantMSECodec(dim: keyDim, bits: keyBits, seed: turboQuantDefaultSeed) + } + if valueCodec?.dim != valueDim || valueCodec?.bits != valueBits { + valueCodec = TurboQuantMSECodec( + dim: valueDim, + bits: valueBits, + seed: turboQuantDefaultSeed + 1 + ) + } + + keyDimension = keyDim + valueDimension = valueDim } - public override func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { - let previous = self.offset + private func appendShadow(keys: MLXArray, values: MLXArray, previous: Int) { let newTokens = keys.dim(2) let needed = previous + newTokens let needsGrow = - if let currentKeys = self.keys, needed > currentKeys.dim(2) { + if let currentKeys = self.shadowKeys, needed > currentKeys.dim(2) { true } else { - self.keys == nil + self.shadowKeys == nil } if needsGrow { let B = keys.dim(0) @@ -1085,7 +1561,7 @@ public class TurboQuantKVCache: BaseKVCache, TurboQuantKVCacheProtocol, CustomDe let kHeadDim = keys.dim(3) let vHeadDim = values.dim(3) - let currentCapacity = self.keys?.dim(2) ?? 0 + let currentCapacity = self.shadowKeys?.dim(2) ?? 0 let newCapacity: Int if currentCapacity == 0 { let nSteps = (newTokens + step - 1) / step @@ -1104,79 +1580,247 @@ public class TurboQuantKVCache: BaseKVCache, TurboQuantKVCacheProtocol, CustomDe let newK = MLXArray.zeros(kShape, dtype: keys.dtype) let newV = MLXArray.zeros(vShape, dtype: values.dtype) - if var currentKeys = self.keys, var currentValues = self.values { + if var currentKeys = self.shadowKeys, var currentValues = self.shadowValues { if previous != currentCapacity { currentKeys = currentKeys[.ellipsis, .. 0 { currentValues = currentValues[.ellipsis, .. 0 { - self.values = concatenated([currentValues, newV], axis: 2) + self.shadowValues = concatenated([currentValues, newV], axis: 2) } else { - self.values = MLXArray.zeros( - [self.keys!.dim(0), currentValues.dim(1), self.keys!.dim(2), 0], + self.shadowValues = MLXArray.zeros( + [self.shadowKeys!.dim(0), currentValues.dim(1), self.shadowKeys!.dim(2), 0], dtype: currentValues.dtype) } } else { - self.keys = newK - self.values = newV + self.shadowKeys = newK + self.shadowValues = newV } didGrow = true } - self.offset += newTokens - self.keys?[.ellipsis, previous ..< self.offset, 0...] = keys + self.shadowKeys?[.ellipsis, previous ..< needed, 0...] = keys if values.dim(3) > 0 { - self.values?[.ellipsis, previous ..< self.offset, 0...] = values + self.shadowValues?[.ellipsis, previous ..< needed, 0...] = values } + } - let returnedKeys = self.keys![.ellipsis, .. 0 { - returnedValues = self.values![.ellipsis, .. 0 { + values + } else { + MLXArray.zeros([keys.dim(0), values.dim(1), offset, 0], dtype: values.dtype) + } + + if offset == 0 { + keyState = nil + valueState = nil + return + } + + keyState = keyCodec?.quantize(keys) + valueState = valueCodec?.quantize(values) + } + + private func rehydrateShadowFromPackedState() { + guard let keyState, let valueState, let keyDimension, let valueDimension else { return } + ensureCodecs(keyDim: keyDimension, valueDim: valueDimension) + shadowKeys = keyCodec?.dequantize(turboQuantSliceState(keyState, end: offset)) + shadowValues = valueCodec?.dequantize(turboQuantSliceState(valueState, end: offset)) + } + + private func denseState() -> (MLXArray, MLXArray) { + if let dense = legacyDenseState { + rebuildFromDenseState(keys: dense.keys, values: dense.values) + } + if shadowKeys == nil || shadowValues == nil { + rehydrateShadowFromPackedState() + } + + guard let shadowKeys, let shadowValues else { + let keyDim = keyDimension ?? 0 + let valueDim = valueDimension ?? 0 + return ( + MLXArray.zeros([1, 0, 0, keyDim], dtype: .float32), + MLXArray.zeros([1, 0, 0, valueDim], dtype: .float32) + ) + } + + let keys = shadowKeys[.ellipsis, .. 0 { + shadowValues[.ellipsis, .. (MLXArray, MLXArray) { + let previous = offset + ensureCodecs(keyDim: keys.dim(3), valueDim: values.dim(3)) + + let quantizedKeys = keyCodec!.quantize(keys) + let quantizedValues = valueCodec!.quantize(values) + + appendPackedState(current: &keyState, update: quantizedKeys, previous: previous) + appendPackedState(current: &valueState, update: quantizedValues, previous: previous) + appendShadow(keys: keys, values: values, previous: previous) + legacyDenseState = nil + offset += keys.dim(2) + + return denseState() + } + + private func fallbackAttention( + queries: MLXArray, + keys: MLXArray, + values: MLXArray, + scale: Float, + mask: MLXFast.ScaledDotProductAttentionMaskMode + ) -> MLXArray { + let (cachedKeys, cachedValues) = update(keys: keys, values: values) + return MLXFast.scaledDotProductAttention( + queries: queries, + keys: cachedKeys, + values: cachedValues, + scale: scale, + mask: mask + ) + } + + public func decodeAttention( + queries: MLXArray, + keys: MLXArray, + values: MLXArray, + scale: Float, + mask: MLXFast.ScaledDotProductAttentionMaskMode = .none + ) -> MLXArray { + fallbackAttention( + queries: queries, + keys: keys, + values: values, + scale: scale, + mask: mask + ) + } + + public func prefillAttention( + queries: MLXArray, + keys: MLXArray, + values: MLXArray, + scale: Float, + mask: MLXFast.ScaledDotProductAttentionMaskMode = .none + ) -> MLXArray { + fallbackAttention( + queries: queries, + keys: keys, + values: values, + scale: scale, + mask: mask + ) } public override var state: [MLXArray] { get { - guard let keys = self.keys, let values = self.values else { return [] } - if offset == keys.dim(2) { - return [keys, values] - } else { - let kSlice = keys[.ellipsis, .. 0 - ? values[.ellipsis, .. 0 { - values = v[.ellipsis, .. 0 { + self.shadowValues = shadowValues[.ellipsis, .. KVCacheSimple { let simpleCache = KVCacheSimple() - simpleCache.state = state + let (keys, values) = denseState() + simpleCache.state = [keys, values] return simpleCache } public var debugDescription: String { - "\(String(describing: Self.self)) \(Unmanaged.passUnretained(self).toOpaque()), offset: \(offset), bits: \(configuration.bits), variant: \(configuration.variant.rawValue), metadataPath: \(configuration.metadataPath ?? "-")" + "\(String(describing: Self.self)) \(Unmanaged.passUnretained(self).toOpaque()), offset: \(offset), bits: \(configuration.bits), variant: \(configuration.variant.rawValue), metadataPath: \(configuration.metadataPath ?? "-"), keyDim: \(keyDimension ?? -1), valueDim: \(valueDimension ?? -1)" } } diff --git a/Tests/MacLocalAPITests/TurboQuantCacheTests.swift b/Tests/MacLocalAPITests/TurboQuantCacheTests.swift index a8eba9df..20fda95d 100644 --- a/Tests/MacLocalAPITests/TurboQuantCacheTests.swift +++ b/Tests/MacLocalAPITests/TurboQuantCacheTests.swift @@ -1,11 +1,75 @@ import Foundation import MLX +import MLXFast import MLXLMCommon import Testing @testable import MacLocalAPI +@Suite(.serialized) struct TurboQuantCacheTests { + final class FakeTurboQuantCache: TurboQuantKVCacheProtocol { + var configuration = TurboQuantConfiguration(bits: 3.5) + var offset: Int = 0 + var maxSize: Int? = nil + var state: [MLXArray] = [] + var metaState: [String] = [] + var isTrimmable: Bool = true + private(set) var decodeCalls = 0 + private(set) var prefillCalls = 0 + + init() {} + + func innerState() -> [MLXArray] { + state + } + + func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { + offset += keys.dim(2) + state = [keys, values] + return (keys, values) + } + + @discardableResult + func trim(_ n: Int) -> Int { + let trimmed = min(offset, n) + offset -= trimmed + return trimmed + } + + func truncateToOffset() {} + + func makeMask( + n: Int, + windowSize: Int?, + returnArray: Bool + ) -> MLXFast.ScaledDotProductAttentionMaskMode { + .none + } + + func decodeAttention( + queries: MLXArray, + keys: MLXArray, + values: MLXArray, + scale: Float, + mask: MLXFast.ScaledDotProductAttentionMaskMode + ) -> MLXArray { + decodeCalls += 1 + return MLXArray.ones(queries.shape) * 7 + } + + func prefillAttention( + queries: MLXArray, + keys: MLXArray, + values: MLXArray, + scale: Float, + mask: MLXFast.ScaledDotProductAttentionMaskMode + ) -> MLXArray { + prefillCalls += 1 + return MLXArray.ones(queries.shape) * 9 + } + } + init() throws { try MLXMetalLibrary.ensureAvailable(verbose: false) } @@ -159,9 +223,18 @@ struct TurboQuantCacheTests { #expect(loaded.offset == 6) let state = loaded.state - #expect(state.count == 2) - #expect(state[0].shape == [1, 2, 6, 8]) - #expect(state[1].shape == [1, 2, 6, 8]) + #expect(state.count == 4) + #expect(state[0].shape == [1, 2, 6]) + #expect(state[1].shape == [1, 2, 6, 1]) + #expect(state[2].shape == [1, 2, 6]) + #expect(state[3].shape == [1, 2, 6, 1]) + + let dense = loaded.toUnquantized().state + #expect(dense.count == 2) + #expect(dense[0].shape == [1, 2, 6, 8]) + #expect(dense[1].shape == [1, 2, 6, 8]) + #expect(dense[0].allClose(MLXArray.ones([1, 2, 6, 8]), atol: 0.75).item(Bool.self)) + #expect(dense[1].allClose(MLXArray.ones([1, 2, 6, 8]) * 2, atol: 0.75).item(Bool.self)) } @Test("MlxCommand parses fractional KV bits and quant scheme") @@ -174,4 +247,46 @@ struct TurboQuantCacheTests { #expect(command.kvBits == 3.5) #expect(command.kvQuantScheme == "turboquant") } + + @Test("Attention utils route single-token decode through TurboQuant cache") + func attentionUtilsRouteDecode() { + let cache = FakeTurboQuantCache() + let queries = MLXArray.ones([1, 2, 1, 8]) + let keys = MLXArray.ones([1, 2, 1, 8]) + let values = MLXArray.ones([1, 2, 1, 8]) * 2 + + let output = attentionWithCacheUpdate( + queries: queries, + keys: keys, + values: values, + cache: cache, + scale: 1.0 + ) + + #expect(cache.decodeCalls == 1) + #expect(cache.prefillCalls == 0) + #expect(output.shape == [1, 2, 1, 8]) + #expect(output[0, 0, 0, 0].item(Float.self) == 7) + } + + @Test("Attention utils route prefill through TurboQuant cache") + func attentionUtilsRoutePrefill() { + let cache = FakeTurboQuantCache() + let queries = MLXArray.ones([1, 2, 3, 8]) + let keys = MLXArray.ones([1, 2, 3, 8]) + let values = MLXArray.ones([1, 2, 3, 8]) * 2 + + let output = attentionWithCacheUpdate( + queries: queries, + keys: keys, + values: values, + cache: cache, + scale: 1.0 + ) + + #expect(cache.decodeCalls == 0) + #expect(cache.prefillCalls == 1) + #expect(output.shape == [1, 2, 3, 8]) + #expect(output[0, 0, 0, 0].item(Float.self) == 9) + } } diff --git a/docs/feature-codex-turboquant-attention.md b/docs/feature-codex-turboquant-attention.md new file mode 100644 index 00000000..160eb9e3 --- /dev/null +++ b/docs/feature-codex-turboquant-attention.md @@ -0,0 +1,70 @@ +# TurboQuant Attention Dispatch + +**Branch:** `feature/codex-turboquant-attention` +**Status:** Implemented + +## Overview + +This branch promotes TurboQuant from a cache-type placeholder into a first-class attention execution path. Instead of relying on generic dense attention code to treat TurboQuant like a normal cache, the attention stack can now detect TurboQuant caches and delegate decode/prefill handling to TurboQuant-specific methods. + +## Scope + +- Add TurboQuant-specific attention protocol methods +- Route attention through TurboQuant caches in the shared attention helper +- Keep the implementation safe by allowing fallback behavior internally + +## Main Design + +### 1. Attention dispatch is explicit + +The core design change is that TurboQuant is no longer "just another cache" from the perspective of attention execution. The shared attention helper checks for a TurboQuant-capable cache and calls TurboQuant-specific methods: +- `decodeAttention(...)` +- `prefillAttention(...)` + +Implemented in: +- [AttentionUtils.swift](/Volumes/edata/codex/dev/git/apr3/maclocal-api/Scripts/patches/AttentionUtils.swift) + +This mirrors `mlx-vlm`'s model-base dispatch pattern and avoids hiding TurboQuant behind generic dense attention calls. + +### 2. Decode and prefill are distinct paths + +The branch makes single-token decode and multi-token prefill separate TurboQuant entry points. That matters because decode and prefill have very different optimization opportunities and should not be forced through the same implementation contract. + +Implemented in: +- [KVCache.swift](/Volumes/edata/codex/dev/git/apr3/maclocal-api/Scripts/patches/KVCache.swift) + +### 3. Fallback remains legal + +This branch does not claim a fully accelerated TurboQuant path yet. Instead, it makes TurboQuant dispatch explicit while allowing the underlying TurboQuant cache implementation to fall back to dense behavior until later branches land the real packed-state and Metal logic. + +That keeps correctness and architecture separate: +- this branch solves dispatch +- later branches solve representation and performance + +## Files + +- [Scripts/patches/AttentionUtils.swift](/Volumes/edata/codex/dev/git/apr3/maclocal-api/Scripts/patches/AttentionUtils.swift) +- [Scripts/patches/KVCache.swift](/Volumes/edata/codex/dev/git/apr3/maclocal-api/Scripts/patches/KVCache.swift) +- [Tests/MacLocalAPITests/TurboQuantCacheTests.swift](/Volumes/edata/codex/dev/git/apr3/maclocal-api/Tests/MacLocalAPITests/TurboQuantCacheTests.swift) + +## What This Branch Does Not Do + +- No packed TurboQuant state yet +- No inline Metal kernels yet +- No batch/concurrent TurboQuant batching support +- No prefix-cache hardening for mixed TurboQuant and non-TurboQuant formats + +## Why This Slice Exists + +Without an explicit attention dispatch hook, later optimization work would be forced into generic cache code or model-specific hacks. This branch creates the right architectural seam for later TurboQuant implementations. + +## Validation + +Validation checks that: +- single-token decode calls the TurboQuant decode path +- multi-token prefill calls the TurboQuant prefill path +- non-TurboQuant caches continue using their existing dispatch + +## Next Branch + +`feature/codex-turboquant-codecs` replaces the dense placeholder representation with packed TurboQuant state suitable for serialization and later fast-path execution. diff --git a/docs/feature-codex-turboquant-codecs.md b/docs/feature-codex-turboquant-codecs.md new file mode 100644 index 00000000..800e8181 --- /dev/null +++ b/docs/feature-codex-turboquant-codecs.md @@ -0,0 +1,73 @@ +# TurboQuant Packed Codec State + +**Branch:** `feature/codex-turboquant-codecs` +**Status:** Implemented + +## Overview + +This branch replaces the earlier dense-only TurboQuant placeholder with a packed TurboQuant cache representation. The important design change is that serialized TurboQuant state is no longer just a disguised dense KV tensor; it now stores norms and packed low-bit index state in a form that matches the intended TurboQuant runtime model more closely. + +## Scope + +- Introduce packed MSE-style TurboQuant state +- Split fractional bit-widths asymmetrically across keys and values +- Preserve prompt-cache serialization and restore behavior for packed state +- Keep dense shadow state only as a runtime helper + +## Main Design + +### 1. Serialized state becomes packed source-of-truth + +`TurboQuantKVCache` now stores: +- key norms +- key packed indices +- value norms +- value packed indices + +instead of persisting dense key/value tensors as its main state. + +Implemented in: +- [KVCache.swift](/Volumes/edata/codex/dev/git/apr3/maclocal-api/Scripts/patches/KVCache.swift) + +### 2. Fractional bits are split key/value asymmetrically + +Following the `mlx-vlm` direction, fractional TurboQuant bit-widths are interpreted as: +- keys use `floor(bits)` +- values use `ceil(bits)` + +This keeps both sides on simple integer low-bit codecs while letting values receive the higher effective precision where it tends to matter more. + +### 3. Dense shadow buffers remain runtime-only + +The branch intentionally keeps dense shadow keys/values in memory as a compatibility layer for the still-incomplete attention path. The packed state is the serialized truth; the dense state is just a runtime convenience until fused or semi-fused execution paths are fully landed. + +### 4. Prompt-cache restore is now meaningful for TurboQuant + +Because TurboQuant state is serialized in packed form, prompt-cache save/load now exercises a real TurboQuant round-trip rather than merely preserving a class tag around dense arrays. + +## Files + +- [Scripts/patches/KVCache.swift](/Volumes/edata/codex/dev/git/apr3/maclocal-api/Scripts/patches/KVCache.swift) +- [Tests/MacLocalAPITests/TurboQuantCacheTests.swift](/Volumes/edata/codex/dev/git/apr3/maclocal-api/Tests/MacLocalAPITests/TurboQuantCacheTests.swift) + +## What This Branch Does Not Do + +- No real Metal decode/prefill execution yet +- No value-side packed fast path +- No batch-aware TurboQuant cache classes +- No new public AFM request surface + +## Why This Slice Exists + +Attention fast paths and prompt-cache correctness both depend on a stable packed representation. This branch creates that representation before trying to optimize execution over it. + +## Validation + +Validation covers: +- prompt-cache save/load with TurboQuant identity preserved +- packed-state shape expectations +- dequantized round-trip back to dense tensors + +## Next Branch + +`feature/codex-turboquant-metal` adds the first real inline Metal-backed decode path on top of this packed-state representation.