Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ private func milliseconds(since start: ContinuousClock.Instant) -> Double {

// MARK: - Constants

/// Maximum number of in-flight pipeline stages. Shared by the backpressure gate
/// and all buffer rotation logic to guarantee no two concurrent stages alias
/// the same memory.
private let pipelineDepth = 3
private let averageExpectedPromptSize = 256
private let temperatureTolerance: Double = 0.001

Expand All @@ -31,7 +35,7 @@ private let temperatureTolerance: Double = 0.001
/// Key features:
/// - Non-blocking GPU encoding via `InferenceFunction.encode`
/// - GPU-direct token sampling (argmax/topK) via MPSGraph compute shaders
/// - Double-buffered cache positions for CPU/GPU overlap
/// - Pipeline-depth-matched buffer rotation for CPU/GPU overlap
/// - Growing KV cache with pipelined expansion
/// - All tensors are owned MTLBuffers — Core AI never allocates/frees them
final class CoreAIPipelinedEngine: InferenceEngine, Sendable {
Expand Down Expand Up @@ -294,7 +298,7 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable {
/// sampler callback drains them (~70/s); depth grows until
/// `MPSCommandBufferImageCache` fails to allocate another private MTLBuffer.
///
/// Capacity 3 covers {logits encode + sampler commit + optional KV-cache grow};
/// Capacity matches `pipelineDepth` — covers {logits encode + sampler commit + optional KV-cache grow};
/// deeper queues only cost memory.
///
/// Class, not actor: `release()` runs synchronously from the Metal callback —
Expand Down Expand Up @@ -396,7 +400,9 @@ private struct EngineImpl: ~Copyable {

// Owned MTLBuffers
var inputTokensBuffer: MTLBuffer
var cachePositionBuffers: (MTLBuffer, MTLBuffer)
var cachePositionBuffers: [MTLBuffer]
var decodeOutputBuffers: [MTLBuffer]
var decodeLogitsBuffers: [MTLBuffer]

// KV cache — reuses CoreAIKVCache protocol from KVCache+CoreAI.swift
var kvCache: any CoreAIKVCache
Expand All @@ -413,8 +419,8 @@ private struct EngineImpl: ~Copyable {
var step: Int = 0

// Backpressure gate — see PipelineGate doc-comment for the failure mode it prevents.
// Capacity 3 covers {encode logits + sampler commit + optional KV-cache grow} in flight.
let inFlightGate = PipelineGate(capacity: 3)
// Capacity matches pipeline depth: {encode logits + sampler commit + optional KV-cache grow} in flight.
let inFlightGate = PipelineGate(capacity: pipelineDepth)

// MARK: - Init

Expand Down Expand Up @@ -486,22 +492,44 @@ private struct EngineImpl: ~Copyable {
throw InferenceRuntimeError.bufferAllocationFailed("inputTokens (\(inputTokensByteCount) bytes)")
}

// Allocate double-buffered cache positions
// Allocate pipeline-depth-matched cache position buffers
let cachePosSize = config.maxContextLength * posIdsDesc.scalarType.byteSize
guard let cachePosBuf0 = device.makeBuffer(length: cachePosSize, options: .storageModeShared),
let cachePosBuf1 = device.makeBuffer(length: cachePosSize, options: .storageModeShared)
else {
throw InferenceRuntimeError.bufferAllocationFailed("cachePositions (\(cachePosSize * 2) bytes)")
var cachePosBuffers: [MTLBuffer] = []
for _ in 0..<pipelineDepth {
guard let buf = device.makeBuffer(length: cachePosSize, options: .storageModeShared) else {
throw InferenceRuntimeError.bufferAllocationFailed("cachePositions (\(cachePosSize) bytes)")
}
cachePosBuffers.append(buf)
}

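// Pre-populate cache positions with [0, 1, ..., maxCtx-1]
// NOTE(review): the fill below binds the memory as Int32, while cachePosSize is derived from
// posIdsDesc.scalarType.byteSize — confirm the position-ID scalar type is always 32-bit.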
for buf in [cachePosBuf0, cachePosBuf1] {
for buf in cachePosBuffers {
let ptr = buf.contents().bindMemory(to: Int32.self, capacity: config.maxContextLength)
for i in 0..<config.maxContextLength {
ptr[i] = Int32(i)
}
}

// Allocate pipeline-depth-matched decode output buffers (sampler writes next token)
var decodeOutBuffers: [MTLBuffer] = []
for _ in 0..<pipelineDepth {
guard let buf = device.makeBuffer(length: MemoryLayout<Int32>.size, options: .storageModeShared) else {
throw InferenceRuntimeError.bufferAllocationFailed(
"decodeOutputBuffer (\(MemoryLayout<Int32>.size) bytes)")
}
decodeOutBuffers.append(buf)
}

// Allocate pipeline-depth-matched decode logits buffers (inference writes logits for decode)
let decodeLogitsSize = config.vocabSize * MemoryLayout<UInt16>.size
var decodeLogBufs: [MTLBuffer] = []
for _ in 0..<pipelineDepth {
guard let buf = device.makeBuffer(length: decodeLogitsSize, options: .storageModeShared) else {
throw InferenceRuntimeError.bufferAllocationFailed("decodeLogitsBuffer (\(decodeLogitsSize) bytes)")
}
decodeLogBufs.append(buf)
}

// Create KV cache using factory — pass original descriptors (with -1 dynamic dims intact)
// so the factory can correctly detect growing vs static support via isDynamicKVCache().
let kvCacheLocal = try KVCacheFactory.make(
Expand Down Expand Up @@ -556,7 +584,9 @@ private struct EngineImpl: ~Copyable {
self.positionIdsBaseDesc = posIdsDesc
self.logitsBaseDesc = logitsDesc
self.inputTokensBuffer = inputTokensBuf
self.cachePositionBuffers = (cachePosBuf0, cachePosBuf1)
self.cachePositionBuffers = cachePosBuffers
self.decodeOutputBuffers = decodeOutBuffers
self.decodeLogitsBuffers = decodeLogBufs
self.kvCache = kvCacheLocal
self.logits = logitsRef
self.cachedSampler = nil
Expand Down Expand Up @@ -603,7 +633,7 @@ private struct EngineImpl: ~Copyable {
///
/// 1. Construct RawView/MutableRawView from MTLBuffers with current shapes
/// 2. Encode to ComputeStream (non-blocking)
/// 3. withMetal3Queue: encode GPU argmax/topK (writes directly to inputTokensBuffer)
/// 3. withMetal3Queue: encode GPU argmax/topK (writes to rotating decodeOutputBuffers)
/// 4. Callback yields token
private mutating func _encodeNextStepGPU(
tokens: some Collection<Int32>,
Expand Down Expand Up @@ -632,7 +662,7 @@ private struct EngineImpl: ~Copyable {
// Prefill: write tokens at their natural position so this step's region is disjoint
// from any prior chunk's region still in-flight on the GPU (encode holds a live
// MTLBuffer reference; no encodeWriteOperands serialization available in Core AI).
// Decode: token is already at offset 0 via GPU-direct argmax write — no CPU write needed.
// Decode: token is in the previous step's decodeOutputBuffer — no CPU write needed.
let tokenByteOffset = processedTokenCount * MemoryLayout<Int32>.size
if !tokens.isEmpty {
let ptr = inputTokensBuffer.contents().bindMemory(
Expand All @@ -642,20 +672,33 @@ private struct EngineImpl: ~Copyable {
}
}

// Select cache position buffer for this step (double-buffered)
let cachePosBuffer = step % 2 == 0 ? cachePositionBuffers.0 : cachePositionBuffers.1
// Select cache position buffer for this step (pipeline-depth-matched rotation)
let cachePosBuffer = cachePositionBuffers[step % pipelineDepth]
let posLength = processedTokenCount + queryLength

// Build Inputs as AsyncValue (from MTLBuffers)
let tokenShape = [1, queryLength]
let tokenStrides = try resolvedStrides(descriptor: inputIdsBaseDesc, shape: tokenShape)
let tokenValue = unsafe InferenceFunction.AsyncValue(
unsafeBuffer: inputTokensBuffer,
byteOffset: tokens.isEmpty ? 0 : tokenByteOffset,
scalarType: .int32,
shape: tokenShape,
strides: tokenStrides
)
let tokenValue: InferenceFunction.AsyncValue
if tokens.isEmpty {
// Decode: read input token from previous step's decode output buffer
tokenValue = unsafe InferenceFunction.AsyncValue(
unsafeBuffer: decodeOutputBuffers[(step + pipelineDepth - 1) % pipelineDepth],
byteOffset: 0,
scalarType: .int32,
shape: tokenShape,
strides: tokenStrides
)
} else {
// Prefill: read from inputTokensBuffer at natural position
tokenValue = unsafe InferenceFunction.AsyncValue(
unsafeBuffer: inputTokensBuffer,
byteOffset: tokenByteOffset,
scalarType: .int32,
shape: tokenShape,
strides: tokenStrides
)
}
let posShape = [1, posLength]
let posStrides = try resolvedStrides(descriptor: positionIdsBaseDesc, shape: posShape)
let posValue = unsafe InferenceFunction.AsyncValue(
Expand Down Expand Up @@ -698,11 +741,12 @@ private struct EngineImpl: ~Copyable {
asyncStates.insert(&valState, for: valueCacheName)

// Build Output as AsyncMutableValue (logits)
let logitsBuffer = logits.metalBuffer
// Decode uses per-step rotating buffer; prefill uses the shared growing buffer.
let logitsOutputBuffer = tokens.isEmpty ? decodeLogitsBuffers[step % pipelineDepth] : logits.metalBuffer
let logitsShape = [1, queryLength, vocabSize]
let logitsStrides = try resolvedStrides(descriptor: logitsBaseDesc, shape: logitsShape)
var logitsOutput = unsafe InferenceFunction.AsyncMutableValue(
unsafeBuffer: logitsBuffer,
unsafeBuffer: logitsOutputBuffer,
byteOffset: 0,
scalarType: .float16,
shape: logitsShape,
Expand Down Expand Up @@ -731,7 +775,8 @@ private struct EngineImpl: ~Copyable {

// GPU sampling via Metal queue
let localGPUSampler = gpuSampler
let outputBuffer = inputTokensBuffer
let outputBuffer = decodeOutputBuffers[step % pipelineDepth]
let samplerLogitsBuffer = tokens.isEmpty ? decodeLogitsBuffers[step % pipelineDepth] : logits.metalBuffer
let logitsOffset = (actualTokenCount - 1) * vocabSize * MemoryLayout<UInt16>.size
let samplerStrategy = gpuSampler is MPSGraphArgmaxSampler ? "GPU-argmax" : "GPU-composite"
let samplerTemperature = cachedSamplerTemperature ?? 0.0
Expand All @@ -757,7 +802,7 @@ private struct EngineImpl: ~Copyable {
if queryLength == 1 {
localGPUSampler.encode(
to: queue,
logitsBuffer: logitsBuffer,
logitsBuffer: samplerLogitsBuffer,
logitsOffset: logitsOffset,
outputBuffer: outputBuffer,
outputOffset: 0,
Expand All @@ -766,7 +811,7 @@ private struct EngineImpl: ~Copyable {
} else {
localGPUSampler.encodeWithSlice(
to: queue,
logitsBuffer: logitsBuffer,
logitsBuffer: samplerLogitsBuffer,
queryLength: actualTokenCount,
outputBuffer: outputBuffer,
outputOffset: 0,
Expand Down Expand Up @@ -969,7 +1014,7 @@ private struct EngineImpl: ~Copyable {
ptr[processedTokenCount + i] = token
}

let cachePosBuffer = step % 2 == 0 ? cachePositionBuffers.0 : cachePositionBuffers.1
let cachePosBuffer = cachePositionBuffers[step % pipelineDepth]
let posLength = processedTokenCount + queryLength

// Build async values and encode
Expand Down Expand Up @@ -1076,7 +1121,7 @@ private struct EngineImpl: ~Copyable {
let ptr = inputTokensBuffer.contents().bindMemory(to: Int32.self, capacity: shape)
for i in 0..<shape { ptr[i] = 1 }

let cachePosBuffer = step % 2 == 0 ? cachePositionBuffers.0 : cachePositionBuffers.1
let cachePosBuffer = cachePositionBuffers[step % pipelineDepth]
let posLength = processedTokenCount + shape

let tShape = [1, shape]
Expand Down Expand Up @@ -1125,18 +1170,18 @@ private struct EngineImpl: ~Copyable {
to: computeStream
)

// Warm up argmax kernel
let logitsBuffer = logits.metalBuffer
let outputBuffer = inputTokensBuffer
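// Warm up argmax kernel using pipeline-matched decode buffers
// NOTE(review): each decodeLogitsBuffer is allocated with vocabSize * MemoryLayout<UInt16>.size bytes,
// but the logitsOffset below is (shape - 1) * vocabSize * MemoryLayout<UInt16>.size, which lies at or past
// the end of that buffer whenever shape > 1 — confirm warmup only samples with shape == 1, or use offset 0 here.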
let warmupLogitsBuffer = decodeLogitsBuffers[step % pipelineDepth]
let warmupOutputBuffer = decodeOutputBuffers[step % pipelineDepth]
let logitsOffset = (shape - 1) * vocabSize * MemoryLayout<UInt16>.size

do {
let queue = pipelineQueue
warmupSampler.encode(
to: queue,
logitsBuffer: logitsBuffer,
logitsBuffer: warmupLogitsBuffer,
logitsOffset: logitsOffset,
outputBuffer: outputBuffer,
outputBuffer: warmupOutputBuffer,
outputOffset: 0,
completion: { _ in }
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Tokenizers
///
/// ## Engine Selection
/// The engine type is determined by `EngineFactory` based on model structure:
/// - **Pipelined**: GPU-accelerated with double buffering (fastest for GPU models)
/// - **Pipelined**: GPU-accelerated with pipeline-depth-matched buffering (fastest for GPU models)
/// - **Sequential**: CPU-based synchronous execution (fallback)
/// - **Static-shape**: Neural Engine optimized for chunked static models
///
Expand Down