Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/exo/worker/engines/mlx/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def build(
self.tokenizer.tool_parser, # type: ignore
)

kv_prefix_cache = KVPrefixCache(self.group)
kv_prefix_cache = KVPrefixCache(self.group, model_id=self.model_id)

device_rank = 0 if self.group is None else self.group.rank()
if os.environ.get("EXO_NO_BATCH"):
Expand Down
531 changes: 527 additions & 4 deletions src/exo/worker/engines/mlx/cache.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions src/exo/worker/engines/mlx/disaggregated/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ def run_prefill_for_request(
if kv_prefix_cache is not None:
try:
cache_snapshots = [snapshot_ssm_states(cache)]
hit_ratio = prefix_hit_length / n_tokens if n_tokens > 0 else 0.0
if matched_index is not None and hit_ratio >= 0.5:
if kv_prefix_cache.should_update_entry(
matched_index, prefix_hit_length, min_prefix_hit_length=0
):
assert matched_index is not None
kv_prefix_cache.update_kv_cache(
matched_index,
prompt_tokens,
Expand Down
12 changes: 3 additions & 9 deletions src/exo/worker/engines/mlx/generator/batch_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
)
from exo.worker.runner.bootstrap import logger

_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5
REMOTE_PREFILL_MIN_TOKENS = 1000


Expand Down Expand Up @@ -505,15 +504,10 @@ def _save_prefix_cache(
return

try:
hit_ratio = (
prefix_hit_length / len(all_prompt_tokens)
if len(all_prompt_tokens) > 0
else 0.0
)
if matched_index is not None and (
prefix_hit_length >= min_prefix_hit_length
and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
if self.kv_prefix_cache.should_update_entry(
matched_index, prefix_hit_length, min_prefix_hit_length
):
assert matched_index is not None
self.kv_prefix_cache.update_kv_cache(
matched_index,
all_prompt_tokens,
Expand Down
13 changes: 3 additions & 10 deletions src/exo/worker/engines/mlx/generator/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@

generation_stream = mx.new_stream(mx.default_device())

_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5


@contextlib.contextmanager
def patch_embed_tokens(
Expand Down Expand Up @@ -679,15 +677,10 @@ def mlx_generate(
prefill_tps = kv_prefix_cache.prefill_tps[matched_index]

if kv_prefix_cache is not None:
hit_ratio = (
prefix_hit_length / len(all_prompt_tokens)
if len(all_prompt_tokens) > 0
else 0.0
)
if matched_index is not None and (
prefix_hit_length >= min_prefix_hit_length
and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
if kv_prefix_cache.should_update_entry(
matched_index, prefix_hit_length, min_prefix_hit_length
):
assert matched_index is not None
kv_prefix_cache.update_kv_cache(
matched_index,
all_prompt_tokens,
Expand Down
19 changes: 18 additions & 1 deletion src/exo/worker/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import time
from dataclasses import dataclass
from enum import Enum
from typing import BinaryIO
from typing import TYPE_CHECKING, BinaryIO, cast

if TYPE_CHECKING:
from exo.worker.engines.mlx.cache import KVPrefixCache

from anyio import ClosedResourceError, EndOfStream

Expand Down Expand Up @@ -310,8 +313,17 @@ def handle_first_task(self, task: Task):
f"Received {task.__class__.__name__} outside of state machine in {self.current_status=}"
)

def _flush_kv_cache_to_disk(self) -> None:
"""Force-flush the hot KV slot to disk; no-op for engines without one."""
kv = cast(
"KVPrefixCache | None", getattr(self.generator, "kv_prefix_cache", None)
)
if kv is not None:
kv.flush_to_disk(force=True)

def shutdown(self, task: Task):
logger.info("runner shutting down")
self._flush_kv_cache_to_disk()
self.update_status(RunnerShuttingDown())
self.acknowledge_task(task)
self.generator.close()
Expand Down Expand Up @@ -380,6 +392,11 @@ def handle_generation_tasks(self, starting_task: GenerationTask):
f"Received {item.__class__.__name__} outside of state machine in {self.current_status=}"
)

# Generation queue drained — persist the hot KV slot while idle so a
# later crash doesn't lose it (otherwise it is only flushed on
# conversation switch).
self._flush_kv_cache_to_disk()

self.update_status(RunnerReady(prefill_server_port=self._prefill_server_port))
logger.info("runner ready")

Expand Down
Loading