diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 485ede30f7..6eabed1d3f 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -36,7 +36,7 @@ TestCommand, TextGeneration, ) -from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId +from exo.shared.types.common import CommandId, ModelId, NodeId, SessionId, SystemId from exo.shared.types.events import ( CustomModelCardAdded, CustomModelCardDeleted, @@ -136,6 +136,11 @@ def __init__( self.state = State() self._tg: TaskGroup = TaskGroup() self.command_task_mapping: dict[CommandId, TaskId] = {} + # Tasks assigned this scheduling loop but not yet reflected in + # self.state (TaskCreated still round-tripping through the event + # router). Counted toward in-flight so a burst of concurrent requests + # doesn't all read a stale zero and stampede onto one instance. + self._pending_assignments: dict[TaskId, InstanceId] = {} self.command_receiver = command_receiver self.local_event_receiver = local_event_receiver self.global_event_sender = global_event_sender @@ -168,6 +173,45 @@ async def shutdown(self): logger.info("Stopping Master") self._tg.cancel_tasks() + def _in_flight_counts( + self, model_id: ModelId, exclude: frozenset[InstanceId] = frozenset() + ) -> dict[InstanceId, int]: + """In-flight task count per instance for a model. + + Sums tasks already in self.state plus optimistic assignments made this + loop but not yet applied (their TaskCreated event is still round-tripping + through the event router). Without the optimistic part, concurrent + requests in one burst all read the same stale zero and the deterministic + tie-break sends them to the first instance. + """ + in_flight = {TaskStatus.Pending, TaskStatus.Running} + # Drop optimistic assignments now visible in state (they get counted via + # state.tasks below) or whose instance is gone, so they stop biasing. + for task_id in list(self._pending_assignments): + instance_id = self._pending_assignments[task_id] + if task_id in self.state.tasks or instance_id not in self.state.instances: + del self._pending_assignments[task_id] + + counts: dict[InstanceId, int] = {} + for instance in self.state.instances.values(): + if ( + instance.shard_assignments.model_id == model_id + and instance.instance_id not in exclude + ): + state_count = sum( + 1 + for task in self.state.tasks.values() + if task.instance_id == instance.instance_id + and task.task_status in in_flight + ) + pending_count = sum( + 1 + for assigned_id in self._pending_assignments.values() + if assigned_id == instance.instance_id + ) + counts[instance.instance_id] = state_count + pending_count + return counts + async def _command_processor(self) -> None: with self.command_receiver as commands: async for forwarder_command in commands: @@ -188,24 +232,10 @@ async def _command_processor(self) -> None: for link in self.state.instance_links.values(): prefill_only.difference_update(link.decode_instances) - for instance in self.state.instances.values(): - # NON-prefill-only instances matching the model ID - if ( - instance.shard_assignments.model_id - == command.task_params.model - and instance.instance_id not in prefill_only - ): - # count in-flight tasks of that instance - in_flight = {TaskStatus.Pending, TaskStatus.Running} - task_count = sum( - 1 - for task in self.state.tasks.values() - if task.instance_id == instance.instance_id - and task.task_status in in_flight - ) - instance_task_counts[instance.instance_id] = ( - task_count - ) + instance_task_counts = self._in_flight_counts( + command.task_params.model, + exclude=frozenset(prefill_only), + ) # there are no NON-prefill-only instances matching this model ID if not instance_task_counts: @@ -222,6 +252,7 @@ async def _command_processor(self) -> None: decode_instance_id = available_instance_ids[0] task_id = TaskId() + self._pending_assignments[task_id] = decode_instance_id params = command.task_params.model_copy( update={ "prefill_endpoint": _prefill_endpoint_for( @@ -243,21 +274,9 @@ async def _command_processor(self) -> None: ) self.command_task_mapping[command.command_id] = task_id case ImageGeneration(): - for instance in self.state.instances.values(): - if ( - instance.shard_assignments.model_id - == command.task_params.model - ): - in_flight = {TaskStatus.Pending, TaskStatus.Running} - task_count = sum( - 1 - for task in self.state.tasks.values() - if task.instance_id == instance.instance_id - and task.task_status in in_flight - ) - instance_task_counts[instance.instance_id] = ( - task_count - ) + instance_task_counts = self._in_flight_counts( + command.task_params.model + ) if not instance_task_counts: raise ValueError( @@ -273,6 +292,7 @@ async def _command_processor(self) -> None: task_id = TaskId() selected_instance_id = available_instance_ids[0] + self._pending_assignments[task_id] = selected_instance_id generated_events.append( TaskCreated( task_id=task_id, @@ -299,21 +319,9 @@ async def _command_processor(self) -> None: ) self._expected_ranks[task_id] = ranks case ImageEdits(): - for instance in self.state.instances.values(): - if ( - instance.shard_assignments.model_id - == command.task_params.model - ): - in_flight = {TaskStatus.Pending, TaskStatus.Running} - task_count = sum( - 1 - for task in self.state.tasks.values() - if task.instance_id == instance.instance_id - and task.task_status in in_flight - ) - instance_task_counts[instance.instance_id] = ( - task_count - ) + instance_task_counts = self._in_flight_counts( + command.task_params.model + ) if not instance_task_counts: raise ValueError( @@ -329,6 +337,7 @@ async def _command_processor(self) -> None: task_id = TaskId() selected_instance_id = available_instance_ids[0] + self._pending_assignments[task_id] = selected_instance_id generated_events.append( TaskCreated( task_id=task_id, @@ -406,6 +415,7 @@ async def _command_processor(self) -> None: command.cancelled_command_id ) ) is not None: + self._pending_assignments.pop(task_id, None) generated_events.append( TaskStatusUpdated( task_status=TaskStatus.Cancelled, @@ -422,6 +432,7 @@ async def _command_processor(self) -> None: command.finished_command_id, None ) ) is not None: + self._pending_assignments.pop(task_id, None) generated_events.append(TaskDeleted(task_id=task_id)) else: logger.warning(