Add MLX backend for Apple Silicon neural network inference #16
Draft
ChinChangYang wants to merge 5 commits into master from
Conversation
Implements neural network inference using Apple's MLX framework:

- Add mlxbackend.cpp with full model support (conv, batchnorm, residual blocks, policy/value heads)
- Update CMakeLists.txt with MLX backend configuration (requires CMake 3.27+)
- Register MLX in backend prefixes and version info
- Compute merged batchnorm parameters for layer tests compatibility

Build with: cmake -G Ninja -DUSE_BACKEND=MLX && ninja

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use logaddexp for Mish activation (7 ops → 2 ops). logaddexp handles numerical stability internally, eliminating the need for manual clamping with minimum/where.
- Use addmm for fused matmul+bias operations (2 ops → 1 op). Applied in SGFMetadataEncoder and ValueHead for reduced memory round-trips.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When all boards are exactly nnXLen x nnYLen, all mask values are 1, so mask operations can be skipped for better performance. This follows the same optimization pattern used in the CUDA backend.

Changes:

- Store requireExactNNLen in ComputeHandle
- Add useMask parameter to layer apply() methods
- Skip mask multiplication in BatchNormLayer when useMask=false
- Use direct max instead of mask-aware max in global pooling
- Skip trunk * mask multiplication when useMask=false
- Pre-compute fixed maskSum when requireExactNNLen=true

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Use mx::compile() to JIT-compile inference functions into fused Metal kernels. The compiled functions are cached per configuration (batchSize, nnXLen, nnYLen, useMask, hasMeta) for reuse.

Key changes:

- Add applyArrays() method for mx::array-based inference
- Add createCompiledFunc() to compile the inference lambda
- Add applyCompiled() to use pre-compiled functions
- Add a thread-safe compilation cache to ComputeHandle
- Modify NeuralNet::getOutput() to use compiled execution

The compilation reduces dispatch overhead and enables MLX to fuse compatible operations (element-wise chains, matmul+bias, etc.).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add optional FP16 (half precision) computation mode to the MLX backend:

- All neural network layers now accept a useFP16 parameter
- Weights are converted to FP16 at load time when enabled
- Inputs are converted to FP16; outputs are converted back to FP32
- Cache keys include the FP16 mode to avoid mixing compiled functions

Default to FP32, since benchmarks show FP16 does not improve performance with MLX on Apple Silicon. Users can opt in via mlxUseFP16=true in the config. Also adds FP16 status reporting to the benchmark command.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Summary
This PR adds a new MLX backend for KataGo, enabling neural network inference on Apple Silicon using Apple's MLX framework. The implementation includes full model support with several performance optimizations.
Key Features
- Use mx::compile() to JIT-compile inference functions into fused Metal kernels, with thread-safe caching per configuration
- Skip mask operations when requireExactNNLen=true (all boards are exactly nnXLen × nnYLen)
- Use logaddexp for Mish activation (reduces 7 ops to 2 ops)
- Use addmm for fused matmul+bias operations (reduces 2 ops to 1 op)

Build Instructions
cd cpp
cmake -G Ninja -DUSE_BACKEND=MLX
ninja

Requires CMake 3.27+.
Configuration
Enable FP16 (optional, not recommended):
mlxUseFP16 = true

Testing
- ./katago runtests (unit tests)
- ./katago runnnlayertests (neural network layer tests)
- ./katago testgpuerror with Eigen reference

Changes
- cpp/neuralnet/mlxbackend.cpp: New 1518-line implementation
- cpp/CMakeLists.txt: MLX backend build configuration
- cpp/configs/gtp_example.cfg: MLX configuration options
- cpp/main.cpp, cpp/command/benchmark.cpp, cpp/program/setup.cpp: Backend registration

Performance Notes
🤖 Generated with Claude Code