
Add MLX backend for Apple Silicon neural network inference#16

Draft
ChinChangYang wants to merge 5 commits into master from feature/mlx-backend

Conversation

@ChinChangYang
Owner

Summary

This PR adds a new MLX backend for KataGo, enabling neural network inference on Apple Silicon using Apple's MLX framework. The implementation includes full model support with several performance optimizations.

Key Features

  • Complete neural network support: Implements all required layers (convolution, batch normalization, residual blocks, policy/value heads, SGF metadata encoding)
  • Graph compilation: Uses mx::compile() to JIT-compile inference functions into fused Metal kernels, with thread-safe caching per configuration
  • Performance optimizations:
    • Skip mask operations when requireExactNNLen=true (all boards are exactly nnXLen × nnYLen)
    • Use logaddexp for Mish activation (reduces 7 ops to 2 ops)
    • Use addmm for fused matmul+bias operations (reduces 2 ops to 1 op)
  • FP16 support: Optional half-precision mode (defaults to FP32 as FP16 doesn't improve performance on Apple Silicon MLX)
  • Batch normalization: Computes merged scale/bias parameters for compatibility with layer tests

Build Instructions

cd cpp
cmake . -G Ninja -DUSE_BACKEND=MLX
ninja

Requires CMake 3.27+

Configuration

Enable FP16 (optional, not recommended):

mlxUseFP16 = true

Testing

  • Passes ./katago runtests (unit tests)
  • Passes ./katago runnnlayertests (neural network layer tests)
  • Cross-backend validation via ./katago testgpuerror with Eigen reference
  • Benchmark testing shows functional parity with other backends

Changes

  • cpp/neuralnet/mlxbackend.cpp: New 1518-line implementation
  • cpp/CMakeLists.txt: MLX backend build configuration
  • cpp/configs/gtp_example.cfg: MLX configuration options
  • cpp/main.cpp, cpp/command/benchmark.cpp, cpp/program/setup.cpp: Backend registration

Performance Notes

  • Graph compilation reduces dispatch overhead and enables operation fusion
  • Compiled functions are cached per (batchSize, nnXLen, nnYLen, useMask, hasMeta)
  • FP16 mode available but not recommended (no performance benefit observed)

🤖 Generated with Claude Code

ChinChangYang and others added 5 commits January 15, 2026 21:12
Implements neural network inference using Apple's MLX framework:
- Add mlxbackend.cpp with full model support (conv, batchnorm, residual blocks, policy/value heads)
- Update CMakeLists.txt with MLX backend configuration (requires CMake 3.27+)
- Register MLX in backend prefixes and version info
- Compute merged batchnorm parameters for layer tests compatibility

Build with: cmake . -G Ninja -DUSE_BACKEND=MLX && ninja

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use logaddexp for Mish activation (7 ops → 2 ops)
  logaddexp handles numerical stability internally, eliminating
  the need for manual clamping with minimum/where

- Use addmm for fused matmul+bias operations (2 ops → 1 op)
  Applied in SGFMetadataEncoder and ValueHead for reduced
  memory round-trips

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When all boards are exactly nnXLen x nnYLen, all mask values are 1,
so mask operations can be skipped for better performance. This follows
the same optimization pattern used in the CUDA backend.

Changes:
- Store requireExactNNLen in ComputeHandle
- Add useMask parameter to layer apply() methods
- Skip mask multiplication in BatchNormLayer when useMask=false
- Use direct max instead of mask-aware max in global pooling
- Skip trunk * mask multiplication when useMask=false
- Pre-compute fixed maskSum when requireExactNNLen=true

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Use mx::compile() to JIT-compile inference functions into fused Metal
kernels. The compiled functions are cached per configuration
(batchSize, nnXLen, nnYLen, useMask, hasMeta) for reuse.

Key changes:
- Add applyArrays() method for mx::array-based inference
- Add createCompiledFunc() to compile inference lambda
- Add applyCompiled() to use pre-compiled functions
- Add thread-safe compilation cache to ComputeHandle
- Modify NeuralNet::getOutput() to use compiled execution

The compilation reduces dispatch overhead and enables MLX to fuse
compatible operations (element-wise chains, matmul+bias, etc.).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add optional FP16 (half precision) computation mode to MLX backend:
- All neural network layers now accept useFP16 parameter
- Weights are converted to FP16 at load time when enabled
- Inputs converted to FP16, outputs converted back to FP32
- Cache keys include FP16 mode to avoid mixing compiled functions

Default to FP32 since benchmarks show FP16 does not improve
performance on Apple Silicon MLX. Users can opt-in via
mlxUseFP16=true in config.

Also adds FP16 status reporting to benchmark command.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@ChinChangYang ChinChangYang marked this pull request as draft January 18, 2026 06:29