Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Scripts/apply-mlx-patches.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
# Print a warning message ($1) to stdout, prefixed with a "[WARN]" tag wrapped in
# ${YELLOW}...${NC}. `echo -e` makes echo interpret backslash escapes.
# NOTE(review): YELLOW/NC are presumably ANSI color codes defined earlier in the
# script — not visible here.
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
# Print an error message ($1) prefixed with an "[ERROR]" tag wrapped in ${RED}...${NC}.
# This writes to stdout (not stderr) and does not exit; callers decide what happens next.
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }

# NOTE(review): PATCH_FILES, TARGET_PATHS and NEW_FILES are each assigned twice in
# this view. The first set (immediately below) lacks AttentionUtils.swift and looks
# like the pre-change side of a diff; if both sets are really present in the script,
# the second set overrides the first.
PATCH_FILES=("Qwen3VL.swift" "Qwen3Next.swift" "GatedDelta.swift" "Qwen3_5MoE.swift" "DeepseekV3.swift" "MiniMaxM2.swift" "NemotronH.swift" "GLM4MoeLite.swift" "GLM5MoeDsa.swift" "KimiK25.swift" "Gemma4Text.swift" "Gemma4VLM.swift" "LLMModelFactory.swift" "Load.swift" "Evaluate.swift" "LanguageModel.swift" "Tokenizer.swift" "Qwen3_5MoEVL.swift" "VLMModelFactory.swift" "SamplerTests.swift" "ToolCallFormat.swift" "KVCache.swift" "SwitchLayers.swift" "BatchKVCache.swift" "SSM.swift" "Chat.swift" "Gemma4FunctionParser.swift")
TARGET_PATHS=("Libraries/MLXVLM/Models/Qwen3VL.swift" "Libraries/MLXLLM/Models/Qwen3Next.swift" "Libraries/MLXLLM/Models/GatedDelta.swift" "Libraries/MLXLLM/Models/Qwen3_5MoE.swift" "Libraries/MLXLLM/Models/DeepseekV3.swift" "Libraries/MLXLLM/Models/MiniMaxM2.swift" "Libraries/MLXLLM/Models/NemotronH.swift" "Libraries/MLXLLM/Models/GLM4MoeLite.swift" "Libraries/MLXLLM/Models/GLM5MoeDsa.swift" "Libraries/MLXLLM/Models/KimiK25.swift" "Libraries/MLXLLM/Models/Gemma4Text.swift" "Libraries/MLXVLM/Models/Gemma4VLM.swift" "Libraries/MLXLLM/LLMModelFactory.swift" "Libraries/MLXLMCommon/Load.swift" "Libraries/MLXLMCommon/Evaluate.swift" "Libraries/MLXLMCommon/LanguageModel.swift" "Libraries/MLXLMCommon/Tokenizer.swift" "Libraries/MLXVLM/Models/Qwen3_5MoEVL.swift" "Libraries/MLXVLM/VLMModelFactory.swift" "Tests/MLXLMTests/SamplerTests.swift" "Libraries/MLXLMCommon/Tool/ToolCallFormat.swift" "Libraries/MLXLMCommon/KVCache.swift" "Libraries/MLXLMCommon/SwitchLayers.swift" "Libraries/MLXLMCommon/BatchKVCache.swift" "Libraries/MLXLLM/Models/SSM.swift" "Libraries/MLXLMCommon/Chat.swift" "Libraries/MLXLMCommon/Tool/Parsers/Gemma4FunctionParser.swift")
NEW_FILES=("Qwen3Next.swift" "GatedDelta.swift" "Qwen3_5MoE.swift" "MiniMaxM2.swift" "NemotronH.swift" "GLM4MoeLite.swift" "GLM5MoeDsa.swift" "KimiK25.swift" "Gemma4Text.swift" "Gemma4VLM.swift" "Qwen3_5MoEVL.swift" "SamplerTests.swift" "BatchKVCache.swift" "Gemma4FunctionParser.swift")
# PATCH_FILES lists patch file names and TARGET_PATHS lists a destination path for
# each one. The two arrays are index-aligned: entry i of TARGET_PATHS ends in the file
# name at entry i of PATCH_FILES, so any addition must be made at the same position
# in both (as done here for AttentionUtils.swift).
# NEW_FILES is a subset of the names in PATCH_FILES — presumably the files that are
# created in the target tree rather than replacing an existing file; confirm against
# the code that consumes it.
PATCH_FILES=("Qwen3VL.swift" "Qwen3Next.swift" "GatedDelta.swift" "Qwen3_5MoE.swift" "DeepseekV3.swift" "MiniMaxM2.swift" "NemotronH.swift" "GLM4MoeLite.swift" "GLM5MoeDsa.swift" "KimiK25.swift" "Gemma4Text.swift" "Gemma4VLM.swift" "LLMModelFactory.swift" "Load.swift" "Evaluate.swift" "LanguageModel.swift" "Tokenizer.swift" "AttentionUtils.swift" "Qwen3_5MoEVL.swift" "VLMModelFactory.swift" "SamplerTests.swift" "ToolCallFormat.swift" "KVCache.swift" "SwitchLayers.swift" "BatchKVCache.swift" "SSM.swift" "Chat.swift" "Gemma4FunctionParser.swift")
TARGET_PATHS=("Libraries/MLXVLM/Models/Qwen3VL.swift" "Libraries/MLXLLM/Models/Qwen3Next.swift" "Libraries/MLXLLM/Models/GatedDelta.swift" "Libraries/MLXLLM/Models/Qwen3_5MoE.swift" "Libraries/MLXLLM/Models/DeepseekV3.swift" "Libraries/MLXLLM/Models/MiniMaxM2.swift" "Libraries/MLXLLM/Models/NemotronH.swift" "Libraries/MLXLLM/Models/GLM4MoeLite.swift" "Libraries/MLXLLM/Models/GLM5MoeDsa.swift" "Libraries/MLXLLM/Models/KimiK25.swift" "Libraries/MLXLLM/Models/Gemma4Text.swift" "Libraries/MLXVLM/Models/Gemma4VLM.swift" "Libraries/MLXLLM/LLMModelFactory.swift" "Libraries/MLXLMCommon/Load.swift" "Libraries/MLXLMCommon/Evaluate.swift" "Libraries/MLXLMCommon/LanguageModel.swift" "Libraries/MLXLMCommon/Tokenizer.swift" "Libraries/MLXLMCommon/AttentionUtils.swift" "Libraries/MLXVLM/Models/Qwen3_5MoEVL.swift" "Libraries/MLXVLM/VLMModelFactory.swift" "Tests/MLXLMTests/SamplerTests.swift" "Libraries/MLXLMCommon/Tool/ToolCallFormat.swift" "Libraries/MLXLMCommon/KVCache.swift" "Libraries/MLXLMCommon/SwitchLayers.swift" "Libraries/MLXLMCommon/BatchKVCache.swift" "Libraries/MLXLLM/Models/SSM.swift" "Libraries/MLXLMCommon/Chat.swift" "Libraries/MLXLMCommon/Tool/Parsers/Gemma4FunctionParser.swift")
NEW_FILES=("Qwen3Next.swift" "GatedDelta.swift" "Qwen3_5MoE.swift" "MiniMaxM2.swift" "NemotronH.swift" "GLM4MoeLite.swift" "GLM5MoeDsa.swift" "KimiK25.swift" "Gemma4Text.swift" "Gemma4VLM.swift" "AttentionUtils.swift" "Qwen3_5MoEVL.swift" "SamplerTests.swift" "BatchKVCache.swift" "Gemma4FunctionParser.swift")

# --- Package-level patches (sed replacements in Package.swift) ---
# Each entry: "search_pattern|replacement"
Expand Down
95 changes: 95 additions & 0 deletions Scripts/patches/AttentionUtils.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import Foundation
import MLX

/// Attention utilities that match Python mlx-lm's interface
///
/// This provides a single function that automatically routes to TurboQuant, quantized, or regular
/// attention based on cache type, matching Python's `scaled_dot_product_attention`.

/// Runs scaled dot-product attention, picking the implementation from the cache type.
///
/// Plays the same role as `scaled_dot_product_attention` in Python mlx-lm's base.py: models
/// call this one function and it selects the attention path, so they do not need to know
/// which cache implementation is in use.
///
/// Routing, checked in this order:
/// 1. `cache` is `nil`: `MLXFast.scaledDotProductAttention` on the raw `keys`/`values`.
/// 2. `cache` conforms to `TurboQuantKVCacheProtocol`: `decodeAttention` when
///    `queries.dim(2) == 1`, otherwise `prefillAttention`. The raw `keys`/`values` are
///    handed to the cache; any cache update presumably happens inside those methods.
/// 3. `cache` conforms to `QuantizedKVCacheProtocol`: `updateQuantized` followed by
///    `quantizedScaledDotProductAttention` with the cache's `groupSize`, `bits` and `mode`.
/// 4. Any other cache: `update(keys:values:)` followed by
///    `MLXFast.scaledDotProductAttention` on the keys and values it returns.
///
/// **Usage in models:**
/// ```swift
/// let output = attentionWithCacheUpdate(
///     queries: queries,
///     keys: keys,
///     values: values,
///     cache: cache,
///     scale: scale,
///     mask: mask
/// )
/// ```
///
/// - Parameters:
///   - queries: Query tensor [B, nHeads, L, D]
///   - keys: Raw key tensor to be cached [B, nKVHeads, L, D]
///   - values: Raw value tensor to be cached [B, nKVHeads, L, D]
///   - cache: Cache instance (any type), or `nil` to attend without a cache
///   - scale: Attention scale factor
///   - mask: Attention mask; defaults to `.none`
/// - Returns: Attention output [B, nHeads, L, D]
public func attentionWithCacheUpdate(
    queries: MLXArray,
    keys: MLXArray,
    values: MLXArray,
    cache: KVCache?,
    scale: Float,
    mask: MLXFast.ScaledDotProductAttentionMaskMode = .none
) -> MLXArray {
    // No cache: attend over exactly the keys/values that were passed in.
    guard let activeCache = cache else {
        return MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask)
    }

    // TurboQuant is tested before the generic quantized protocol, so a cache that
    // conforms to both takes this path.
    if let turboCache = activeCache as? TurboQuantKVCacheProtocol {
        // Axis 2 of `queries` is the sequence length; a length of 1 is a decode step.
        let isSingleTokenStep = queries.dim(2) == 1
        if isSingleTokenStep {
            return turboCache.decodeAttention(
                queries: queries, keys: keys, values: values, scale: scale, mask: mask)
        }
        return turboCache.prefillAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask)
    }

    // Quantized cache: store the new entries, then attend over the quantized contents.
    if let quantCache = activeCache as? QuantizedKVCacheProtocol {
        let (packedKeys, packedValues) = quantCache.updateQuantized(
            keys: keys, values: values)
        return quantizedScaledDotProductAttention(
            queries: queries,
            quantizedKeys: packedKeys,
            quantizedValues: packedValues,
            scale: scale,
            mask: mask,
            groupSize: quantCache.groupSize,
            bits: quantCache.bits,
            mode: quantCache.mode
        )
    }

    // Any other cache: update it and attend over the keys/values it returns.
    let (allKeys, allValues) = activeCache.update(keys: keys, values: values)
    return MLXFast.scaledDotProductAttention(
        queries: queries, keys: allKeys, values: allValues, scale: scale, mask: mask)
}
Loading