Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 60 additions & 49 deletions src/exo/master/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
TestCommand,
TextGeneration,
)
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, ModelId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
CustomModelCardAdded,
CustomModelCardDeleted,
Expand Down Expand Up @@ -136,6 +136,11 @@ def __init__(
self.state = State()
self._tg: TaskGroup = TaskGroup()
self.command_task_mapping: dict[CommandId, TaskId] = {}
# Tasks assigned this scheduling loop but not yet reflected in
# self.state (TaskCreated still round-tripping through the event
# router). Counted toward in-flight so a burst of concurrent requests
# doesn't all read a stale zero and stampede onto one instance.
self._pending_assignments: dict[TaskId, InstanceId] = {}
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
Expand Down Expand Up @@ -168,6 +173,45 @@ async def shutdown(self):
logger.info("Stopping Master")
self._tg.cancel_tasks()

def _in_flight_counts(
self, model_id: ModelId, exclude: frozenset[InstanceId] = frozenset()
) -> dict[InstanceId, int]:
"""In-flight task count per instance for a model.

Sums tasks already in self.state plus optimistic assignments made this
loop but not yet applied (their TaskCreated event is still round-tripping
through the event router). Without the optimistic part, concurrent
requests in one burst all read the same stale zero and the deterministic
tie-break sends them to the first instance.
"""
in_flight = {TaskStatus.Pending, TaskStatus.Running}
# Drop optimistic assignments now visible in state (they get counted via
# state.tasks below) or whose instance is gone, so they stop biasing.
for task_id in list(self._pending_assignments):
instance_id = self._pending_assignments[task_id]
if task_id in self.state.tasks or instance_id not in self.state.instances:
del self._pending_assignments[task_id]

counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id == model_id
and instance.instance_id not in exclude
):
state_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
and task.task_status in in_flight
)
pending_count = sum(
1
for assigned_id in self._pending_assignments.values()
if assigned_id == instance.instance_id
)
counts[instance.instance_id] = state_count + pending_count
return counts

async def _command_processor(self) -> None:
with self.command_receiver as commands:
async for forwarder_command in commands:
Expand All @@ -188,24 +232,10 @@ async def _command_processor(self) -> None:
for link in self.state.instance_links.values():
prefill_only.difference_update(link.decode_instances)

for instance in self.state.instances.values():
# NON-prefill-only instances matching the model ID
if (
instance.shard_assignments.model_id
== command.task_params.model
and instance.instance_id not in prefill_only
):
# count in-flight tasks of that instance
in_flight = {TaskStatus.Pending, TaskStatus.Running}
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
and task.task_status in in_flight
)
instance_task_counts[instance.instance_id] = (
task_count
)
instance_task_counts = self._in_flight_counts(
command.task_params.model,
exclude=frozenset(prefill_only),
)

# there are no NON-prefill-only instances matching this model ID
if not instance_task_counts:
Expand All @@ -222,6 +252,7 @@ async def _command_processor(self) -> None:

decode_instance_id = available_instance_ids[0]
task_id = TaskId()
self._pending_assignments[task_id] = decode_instance_id
params = command.task_params.model_copy(
update={
"prefill_endpoint": _prefill_endpoint_for(
Expand All @@ -243,21 +274,9 @@ async def _command_processor(self) -> None:
)
self.command_task_mapping[command.command_id] = task_id
case ImageGeneration():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.task_params.model
):
in_flight = {TaskStatus.Pending, TaskStatus.Running}
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
and task.task_status in in_flight
)
instance_task_counts[instance.instance_id] = (
task_count
)
instance_task_counts = self._in_flight_counts(
command.task_params.model
)

if not instance_task_counts:
raise ValueError(
Expand All @@ -273,6 +292,7 @@ async def _command_processor(self) -> None:

task_id = TaskId()
selected_instance_id = available_instance_ids[0]
self._pending_assignments[task_id] = selected_instance_id
generated_events.append(
TaskCreated(
task_id=task_id,
Expand All @@ -299,21 +319,9 @@ async def _command_processor(self) -> None:
)
self._expected_ranks[task_id] = ranks
case ImageEdits():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.task_params.model
):
in_flight = {TaskStatus.Pending, TaskStatus.Running}
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
and task.task_status in in_flight
)
instance_task_counts[instance.instance_id] = (
task_count
)
instance_task_counts = self._in_flight_counts(
command.task_params.model
)

if not instance_task_counts:
raise ValueError(
Expand All @@ -329,6 +337,7 @@ async def _command_processor(self) -> None:

task_id = TaskId()
selected_instance_id = available_instance_ids[0]
self._pending_assignments[task_id] = selected_instance_id
generated_events.append(
TaskCreated(
task_id=task_id,
Expand Down Expand Up @@ -406,6 +415,7 @@ async def _command_processor(self) -> None:
command.cancelled_command_id
)
) is not None:
self._pending_assignments.pop(task_id, None)
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
Expand All @@ -422,6 +432,7 @@ async def _command_processor(self) -> None:
command.finished_command_id, None
)
) is not None:
self._pending_assignments.pop(task_id, None)
generated_events.append(TaskDeleted(task_id=task_id))
else:
logger.warning(
Expand Down