diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 49068366ba8..50f9baf30a9 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -1965,7 +1965,7 @@ def check(self):
         """
         check the legality of config
         """
-        assert self.scheduler_config.max_num_seqs <= 256, (
-            "The parameter `max_num_seqs` is not allowed to exceed 256, "
+        assert self.scheduler_config.max_num_seqs <= 512, (
+            "The parameter `max_num_seqs` is not allowed to exceed 512, "
             f"but now it's {self.scheduler_config.max_num_seqs}."
         )
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index e4ee65b2d5c..ea4c6ad548b 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -222,6 +222,11 @@ class EngineArgs:
     """
     Maximum number of tokens to batch together.
     """
+    chunked_prefill_size: Optional[int] = None
+    """
+    The maximum number of tokens in a prefill batch. Similar to SGLang's chunked_prefill_size.
+    Will be dynamically calculated based on GPU memory if not specified.
+    """
     kv_cache_ratio: float = 0.75
     """
     Ratio of tokens to process in a block.
     """
@@ -444,6 +449,10 @@ class EngineArgs:
     """
     Flag to enable overlapping schedule. Default is False (disabled).
     """
+    enable_priority_scheduling: bool = False
+    """
+    Flag to enable priority scheduling with preemption.
+    """
     graph_optimization_config: Optional[Dict[str, Any]] = None
     """
     Configuration for graph optimization backend execution.
     """
@@ -980,6 +989,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=EngineArgs.max_num_batched_tokens,
         help="Maximum number of tokens to batch together.",
     )
+    parallel_group.add_argument(
+        "--chunked-prefill-size",
+        type=int,
+        default=EngineArgs.chunked_prefill_size,
+        help="The maximum number of tokens in a prefill batch. "
+        "If not specified, will be dynamically calculated based on GPU memory.",
+    )
     parallel_group.add_argument(
         "--gpu-memory-utilization",
         type=float,
@@ -1323,6 +1339,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=EngineArgs.enable_overlap_schedule,
         help="Enable overlapping schedule.",
     )
+    scheduler_group.add_argument(
+        "--enable-priority-scheduling",
+        action="store_true",
+        default=EngineArgs.enable_priority_scheduling,
+        help="Enable priority scheduling with preemption.",
+    )
     return parser
@@ -1439,13 +1461,40 @@ def create_engine_config(self) -> FDConfig:
             if current_platform.is_maca():
                 self.max_num_batched_tokens = self.max_model_len
             else:
-                self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
+                self.max_num_batched_tokens = 16384  # if set to max_model_len, it's easy to be OOM
         else:
             if self.enable_chunked_prefill:
                 self.max_num_batched_tokens = 2048
             else:
                 self.max_num_batched_tokens = self.max_model_len

+        # SGLang-aligned: dynamically calculate chunked_prefill_size based on GPU memory
+        if self.chunked_prefill_size is None and int(envs.ENABLE_V1_KVCACHE_SCHEDULER):
+            if current_platform.is_cuda():
+                try:
+                    # Use paddle (like other FD code) instead of torch
+                    import paddle
+
+                    gpu_mem = paddle.device.cuda.get_device_properties(0).total_memory / (1024**3)  # GB
+                    # SGLang's chunked_prefill_size logic:
+                    if gpu_mem < 35:
+                        self.chunked_prefill_size = 2048
+                    elif gpu_mem < 60:
+                        self.chunked_prefill_size = 4096
+                    elif gpu_mem < 160:
+                        self.chunked_prefill_size = 8192  # H100, H20, H200, 140GB
+                    else:
+                        self.chunked_prefill_size = 16384  # B200, MI300
+                    console_logger.info(f"Auto-detected GPU memory: {gpu_mem:.1f}GB, chunked_prefill_size: {self.chunked_prefill_size}")
+                except Exception as e:
+                    console_logger.warning(f"Failed to detect GPU memory, using default chunked_prefill_size: {e}")
+                    self.chunked_prefill_size = 8192  # default for 140GB
+            elif current_platform.is_maca():
+                self.chunked_prefill_size = self.max_model_len
+            else:
+                self.chunked_prefill_size = 8192  # default for non-CUDA platforms
+
         all_dict = asdict(self)
         all_dict["model_cfg"] = model_cfg
         cache_cfg = CacheConfig(all_dict)
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 8e4cc9d06b8..7e2d7b38640 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -202,18 +202,28 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
         self.bos_client = None
         self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4)
-        self.init_reserve_output_block_num = (
-            envs.FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
-        )  # int
-        self.decay_output_block_num = (
-            envs.FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
-        )  # float
-        self.min_reserve_output_block_num = (
-            envs.FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
-        )  # int
-        self.current_reserve_output_block_num = self.init_reserve_output_block_num
-        self.current_reserve_output_block_num_float = self.init_reserve_output_block_num
-        self.can_relax_prefill_strategy = True
+        # New token ratio mechanism for dynamic decode reservation (inspired by SGLang)
+        # This replaces the fixed per-request block reservation with a ratio-based approach
+        schedule_conservativeness = 1.0  # Can be made configurable via SchedulerConfig later
+        self.init_new_token_ratio = min(
+            envs.FD_INIT_NEW_TOKEN_RATIO * schedule_conservativeness,
+            1.0,
+        )
+        self.min_new_token_ratio = min(
+            self.init_new_token_ratio * envs.FD_MIN_NEW_TOKEN_RATIO_FACTOR,
+            1.0,
+        )
+        self.new_token_ratio_decay = (
+            self.init_new_token_ratio - self.min_new_token_ratio
+        ) / envs.FD_NEW_TOKEN_RATIO_DECAY_STEPS
+        self.current_new_token_ratio = self.init_new_token_ratio
+        self.clip_max_new_tokens_estimation = envs.FD_CLIP_MAX_NEW_TOKENS_ESTIMATION
+
+        llm_logger.info(
+            f"NewTokenRatio initialized: init={self.init_new_token_ratio:.3f}, "
+            f"min={self.min_new_token_ratio:.3f}, decay_per_step={self.new_token_ratio_decay:.6f}, "
+            f"clip_max_tokens={self.clip_max_new_tokens_estimation}"
+        )

     def allocated_slots(self, request: Request):
         return len(request.block_tables) * self.config.cache_config.block_size
@@ -246,8 +256,8 @@ def reschedule_preempt_task(self, request_id, process_func=None):
             request = self.requests[request_id]
             if process_func is not None:
                 process_func(request)
-            llm_logger.debug(f"self.waiting append request:{request.request_id},req.type:{request.status}")
-            self.waiting.appendleft(request)
+            llm_logger.debug(f"self.waiting append request to end:{request.request_id},req.type:{request.status}")
+            self.waiting.append(request)  # Append to end of queue (FIFO order)
             self.to_be_rescheduled_request_id_set.remove(request_id)

     def _info_each_block(self):
@@ -303,64 +313,308 @@ def wait_worker_inflight_requests_finish(self, timeout=60):

     def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
         """
-        If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.
+        If the request cannot be scheduled, preempt running decode requests one by one until it can be scheduled.
+        Only preempt decode requests (num_computed_tokens >= need_prefill_tokens).
+
+        SGLang-aligned strategy:
+        - Sort by (output_len asc, input_len desc) to prioritize retracting short-output, long-input requests
+        - Keep at least 1 request in running
+        - After preemption, update current_new_token_ratio based on remaining requests
         """
-        can_schedule = False
-        while self._can_preempt():
-            if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
-                preempted_req = self.running.pop()
-                if preempted_req.use_extend_tables:
-                    self.running.insert(0, preempted_req)
-                    continue
-                preempted_req.status = RequestStatus.PREEMPTED
-                preempted_req.num_computed_tokens = 0
-                if self.config.scheduler_config.splitwise_role == "decode":
-                    self.tasks_list[preempted_req.idx] = None
-                    self.stop_flags[preempted_req.idx] = True
-                    if preempted_req.request_id in self.requests:
-                        del self.requests[preempted_req.request_id]
-                    if preempted_req.request_id in self.req_dict:
-                        del self.req_dict[preempted_req.request_id]
-                    self._free_blocks(preempted_req)
-                    llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
-                else:
-                    self._free_blocks(preempted_req)
-                    preempted_req.num_cached_blocks = 0
-                    self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
-                    llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
-                preempted_reqs.append(preempted_req)
-                scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
-                llm_logger.debug(
-                    f"preempt {preempted_req.request_id} in idx {preempted_req.idx} with generated ids {preempted_req.output_token_ids}"
-                )
-                llm_logger.debug(self.info())
-                self._info_each_block()
-                if preempted_req == request:
-                    # No more request to preempt.
-                    can_schedule = False
-                    break
-            else:
-                # The request can be scheduled.
-                can_schedule = True
-                break
-        self.current_reserve_output_block_num = self.init_reserve_output_block_num
-        self.current_reserve_output_block_num_float = self.init_reserve_output_block_num
-        self.can_relax_prefill_strategy = False
+        # Collect decode requests (num_computed_tokens >= need_prefill_tokens)
+        decode_requests = [
+            req for req in self.running
+            if req.num_computed_tokens >= req.need_prefill_tokens
+        ]
+
+        # SGLang-aligned sort: prioritize retracting requests with shorter output
+        # If output_len is equal, prioritize retracting requests with longer input
+        decode_requests.sort(
+            key=lambda r: (len(r.output_token_ids), -r.prompt_token_ids_len),
+        )  # iterate from the front: shortest output (then longest input) retracted first
+
+        # If only 1 decode request, cannot preempt (need to keep at least 1)
+        # Return False to let scheduler handle this gracefully
+        if len(decode_requests) <= 1:
+            can_schedule = self.cache_manager.can_allocate_gpu_blocks(num_new_blocks)
+            return can_schedule
+
+        preempted_count = 0
+        remaining_req_count = len(decode_requests) - 1  # Count for KV eviction (decreases after each preempt)
+
+        for preempted_req in decode_requests[:-1]:  # Skip last one to keep at least 1
+            if self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
+                break
+
+            # Remove from running list
+            self.running.remove(preempted_req)
+            preempted_req.status = RequestStatus.PREEMPTED
+            preempted_req.num_computed_tokens = 0
+
+            # Mark as retracted for SGLang alignment
+            preempted_req.is_retracted = True
+
+            if self.config.scheduler_config.splitwise_role == "decode":
+                self.tasks_list[preempted_req.idx] = None
+                self.stop_flags[preempted_req.idx] = True
+                if preempted_req.request_id in self.requests:
+                    del self.requests[preempted_req.request_id]
+                if preempted_req.request_id in self.req_dict:
+                    del self.req_dict[preempted_req.request_id]
+                self._free_blocks(preempted_req)
+                llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
+            else:
+                self._free_blocks(preempted_req)
+                preempted_req.num_cached_blocks = 0
+                self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
+                llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
+
+            preempted_reqs.append(preempted_req)
+            scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
+            preempted_count += 1
+
+            # Evict KV cache from tree (SGLang-aligned: retract_decode_steps * remaining_req_count)
+            self._evict_decode_kv_cache(remaining_req_count)
+            remaining_req_count -= 1
+
+            llm_logger.debug(
+                f"preempt {preempted_req.request_id} in idx {preempted_req.idx} "
+                f"with output_len={len(preempted_req.output_token_ids)}, "
+                f"input_len={preempted_req.prompt_token_ids_len}"
+            )
+
+        if preempted_count > 0:
+            llm_logger.debug(self.info())
+            self._info_each_block()
+
+            # Update new_token_ratio based on remaining requests (SGLang style)
+            self._update_new_token_ratio_after_preemption()
+
+        # Check if we can schedule now
+        can_schedule = self.cache_manager.can_allocate_gpu_blocks(num_new_blocks)
         return can_schedule
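To make the retraction order concrete, here is a toy check (a hypothetical `Req` stand-in, not FD's real `Request` class) of the sort key used by `_trigger_preempt`: shortest output is retracted first, ties broken by longest input.

```python
from dataclasses import dataclass, field

@dataclass
class Req:                     # hypothetical stand-in for fastdeploy's Request
    name: str
    output_token_ids: list = field(default_factory=list)
    prompt_token_ids_len: int = 0

reqs = [
    Req("a", [0] * 50, prompt_token_ids_len=100),
    Req("b", [0] * 5, prompt_token_ids_len=2000),
    Req("c", [0] * 5, prompt_token_ids_len=300),
]
reqs.sort(key=lambda r: (len(r.output_token_ids), -r.prompt_token_ids_len))
print([r.name for r in reqs])
# ['b', 'c', 'a'] -> 'b' is retracted first; 'a' (longest output) is kept last
```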

-    def _get_can_schedule_prefill_threshold_block(self, request, num_chunk_new_block):
-        if self.can_relax_prefill_strategy:
-            can_schedule_block_num_threshold = num_chunk_new_block
-        else:
-            can_schedule_block_num_threshold = (
-                request.need_prefill_tokens + self.config.cache_config.block_size - 1
-            ) // self.config.cache_config.block_size + len(self.running) * self.current_reserve_output_block_num
-        if self.config.speculative_config.method is not None:
-            can_schedule_block_num_threshold = min(
-                can_schedule_block_num_threshold + 1, self.config.cache_config.max_block_num_per_seq
-            )
+    def _evict_decode_kv_cache(self, remaining_req_count: int):
+        """
+        Evict KV cache from tree cache after retracting a decode request.
+
+        SGLang-aligned strategy:
+        - Each retraction triggers eviction of retract_decode_steps * remaining_req_count tokens
+        - This frees up space for new requests immediately
+
+        Args:
+            remaining_req_count: Number of requests that will remain in running after eviction
+        """
+        # SGLang default retract_decode_steps = 20; configurable via FD_RETRACT_DECODE_STEPS
+        retract_decode_steps = envs.FD_RETRACT_DECODE_STEPS
+
+        num_tokens_to_evict = remaining_req_count * retract_decode_steps
+
+        llm_logger.debug(
+            f"Evicting {num_tokens_to_evict} KV cache tokens "
+            f"(retract_decode_steps={retract_decode_steps}, remaining_req_count={remaining_req_count})"
+        )
+
+        # Convert tokens to blocks for FD's cache manager
+        # FD's PrefixCacheManager uses block-based eviction
+        if self.cache_manager is not None:
+            block_size = self.config.cache_config.block_size
+            num_blocks_to_evict = (num_tokens_to_evict + block_size - 1) // block_size
+
+            llm_logger.debug(
+                f"Evicting {num_blocks_to_evict} blocks "
+                f"(={num_tokens_to_evict} tokens) from GPU cache "
+                f"(block_size={block_size})"
+            )
+
+            # Use FD's existing cache eviction API
+            # free_block_ids (sync version) will handle the LRU eviction logic
+            # including GPU -> CPU swap and GPU -> Storage persistence
+            # Use sync version to ensure blocks are freed before checking can_allocate
+            self.cache_manager.free_block_ids(num_blocks_to_evict)
+        else:
+            llm_logger.warning("cache_manager is None, cannot evict KV cache")
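A quick worked example of the eviction sizing, assuming `block_size=64` and the `FD_RETRACT_DECODE_STEPS` default of 20:

```python
# Retracting one request while 5 others keep running evicts 5 * 20 = 100 tokens,
# i.e. ceil(100 / 64) = 2 blocks (block_size=64 is an assumption for illustration).
retract_decode_steps = 20
remaining_req_count = 5
block_size = 64

num_tokens_to_evict = remaining_req_count * retract_decode_steps             # 100
num_blocks_to_evict = (num_tokens_to_evict + block_size - 1) // block_size   # 2
print(num_tokens_to_evict, num_blocks_to_evict)  # 100 2
```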

+    def _update_new_token_ratio_after_preemption(self):
+        """
+        Update current_new_token_ratio based on remaining running decode requests.
+        Mimics SGLang's logic:
+            new_ratio = (total_decoded_tokens + retract_decode_steps * num_decode_reqs) / (total_max_new_tokens + 1)
+
+        Note: Only count decode requests (num_computed_tokens >= need_prefill_tokens),
+        not prefill requests.
+
+        SGLang-aligned behavior:
+        - If no decode requests running, keep the current ratio unchanged
+          (matching SGLang: ratio remains continuous when only prefill requests are running)
+        - Ratio reset only happens when the system is completely idle
+          (no prefill and no decode requests)
+        """
+        # Filter to only decode requests (matching SGLang's self.reqs)
+        decode_reqs = [
+            req for req in self.running
+            if req.num_computed_tokens >= req.need_prefill_tokens
+        ]
+
+        if len(decode_reqs) == 0:
+            # No decode requests running, keep current ratio unchanged
+            # (matching SGLang: ratio remains continuous when only prefill requests are running)
+            llm_logger.debug(
+                f"No decode requests after preemption, keeping current_new_token_ratio={self.current_new_token_ratio:.3f}"
+            )
+            return
+
+        total_decoded_tokens = 0
+        total_max_new_tokens = 0
+
+        for req in decode_reqs:
+            # Count decoded tokens
+            already_decoded = len(req.output_token_ids)
+            total_decoded_tokens += already_decoded
+
+            # Get max_new_tokens for this request
+            if req.sampling_params and req.sampling_params.max_tokens is not None:
+                max_new_tokens = req.sampling_params.max_tokens
+            else:
+                max_new_tokens = self.config.model_config.max_model_len - req.need_prefill_tokens
+            total_max_new_tokens += max_new_tokens
+
+        # SGLang's formula with retract_decode_steps (SGLang default: 20, via FD_RETRACT_DECODE_STEPS)
+        retract_decode_steps = envs.FD_RETRACT_DECODE_STEPS
+        num_decode_reqs = len(decode_reqs)
+
+        new_ratio = (
+            total_decoded_tokens + retract_decode_steps * num_decode_reqs
+        ) / (total_max_new_tokens + 1)
+
+        # Clamp to [min_ratio, init_ratio]
+        new_ratio = max(self.min_new_token_ratio, min(self.init_new_token_ratio, new_ratio))
+
+        llm_logger.debug(
+            f"Update new_token_ratio after preemption: "
+            f"decode_reqs={num_decode_reqs}, decoded={total_decoded_tokens}, "
+            f"max_new={total_max_new_tokens}, ratio={new_ratio:.3f} "
+            f"(was {self.current_new_token_ratio:.3f})"
+        )
+
+        self.current_new_token_ratio = new_ratio
+
+    def reset_new_token_ratio_on_idle(self):
+        """
+        Reset new_token_ratio to init_new_token_ratio when system is completely idle.
+
+        SGLang alignment: Only reset when both running and waiting queues are empty.
+        This mimics SGLang's self_check_during_idle behavior.
+        """
+        if len(self.running) == 0 and len(self.waiting) == 0:
+            if self.current_new_token_ratio != self.init_new_token_ratio:
+                llm_logger.debug(
+                    f"System completely idle, resetting new_token_ratio "
+                    f"from {self.current_new_token_ratio:.3f} to {self.init_new_token_ratio:.3f}"
+                )
+                self.current_new_token_ratio = self.init_new_token_ratio
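Toy numbers for the formula above: three decode requests that have emitted 100, 40, and 10 tokens, each with `max_tokens=512`, and the default `retract_decode_steps=20`:

```python
total_decoded_tokens = 100 + 40 + 10   # 150
total_max_new_tokens = 3 * 512         # 1536
retract_decode_steps, num_decode_reqs = 20, 3

new_ratio = (total_decoded_tokens + retract_decode_steps * num_decode_reqs) / (total_max_new_tokens + 1)
print(f"{new_ratio:.3f}")  # 0.137 -- then clamped into [min_new_token_ratio, init_new_token_ratio]
```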

+    def _calculate_decode_reserved_tokens_by_ratio(self):
+        """
+        Calculate total reserved tokens for all running decode requests based on current_new_token_ratio.
+
+        For each request in decode phase, calculate:
+            remaining_tokens = min(max_new_tokens - already_decoded, clip_estimation)
+            reserved_tokens = remaining_tokens * current_new_token_ratio
+
+        Returns:
+            float: Total number of tokens to reserve for decode requests
+                   (kept fractional here; rounded to blocks by the caller)
+        """
+        total_reserved_tokens = 0
+        num_decode_reqs = 0
+
+        for req in self.running:
+            # Only calculate reservation for requests in decode phase
+            if req.num_computed_tokens < req.need_prefill_tokens:
+                continue  # Still in prefill, skip
+
+            num_decode_reqs += 1
+
+            # Get max_new_tokens for this request
+            if req.sampling_params and req.sampling_params.max_tokens is not None:
+                max_new_tokens = req.sampling_params.max_tokens
+            else:
+                # Fallback: use max_model_len - prompt_len
+                max_new_tokens = self.config.model_config.max_model_len - req.need_prefill_tokens
+
+            # Calculate remaining tokens to generate
+            already_decoded = len(req.output_token_ids)
+            remaining_tokens = max(0, max_new_tokens - already_decoded)
+
+            # Clip to reasonable upper bound to avoid single long request dominating budget
+            remaining_tokens = min(remaining_tokens, self.clip_max_new_tokens_estimation)
+
+            # Calculate reservation based on current ratio (no rounding here, keep precision)
+            reserved_tokens = remaining_tokens * self.current_new_token_ratio
+            total_reserved_tokens += reserved_tokens
+
+        llm_logger.debug(
+            f"Decode reservation: {num_decode_reqs} decode reqs, "
+            f"{total_reserved_tokens:.1f} tokens, ratio={self.current_new_token_ratio:.3f}"
+        )
+
+        return total_reserved_tokens
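A worked example of the reservation, assuming two decode requests with 3000 and 6000 tokens still to generate and the `FD_CLIP_MAX_NEW_TOKENS_ESTIMATION` default of 4096; note how much more prefill the decayed floor admits:

```python
clip = 4096
remaining = [min(3000, clip), min(6000, clip)]   # [3000, 4096]
for ratio in (0.7, 0.098):                       # init ratio vs decayed floor
    print(round(sum(r * ratio for r in remaining), 1))
# 4967.2  <- reserved at ratio 0.7 (conservative, right after a retraction)
# 695.4   <- reserved at the 0.098 floor (relaxed, admits far more prefill)
```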

+    def _get_can_schedule_prefill_threshold_block(self, request, num_chunk_new_block):
+        """
+        Calculate the total tokens needed for scheduling a new prefill request.
+
+        Total tokens includes:
+        1. Tokens needed for current prefill
+        2. Tokens reserved for this request's future decode (max_new_tokens only for last chunk)
+        3. Tokens reserved for all existing decode requests (ratio-based)
+
+        SGLang-aligned behavior:
+        - Only reserve max_new_tokens when this is the LAST chunk of prefill
+        - For non-last chunks, reserve 0 (they will reserve on final chunk)
+
+        Returns:
+            int: Total blocks needed (ceiled once at the end) to safely admit this request
+        """
+        # 1. SGLang-aligned: Use current chunk's token count, not the full prefill
+        # This is the key difference - only reserve for what we're actually processing NOW
+        required_tokens_for_prefill = num_chunk_new_block * self.config.cache_config.block_size
+
+        # 2. SGLang-aligned: Only reserve max_new_tokens for the LAST chunk
+        # Calculate remaining tokens to prefill after this chunk (compare tokens to tokens)
+        remaining_tokens_to_prefill = request.need_prefill_tokens - request.num_computed_tokens
+        is_last_chunk = remaining_tokens_to_prefill <= required_tokens_for_prefill
+
+        max_new_tokens_for_request = 0
+        if is_last_chunk:
+            # This is the last chunk - reserve full max_new_tokens (SGLang behavior)
+            if hasattr(request, 'sampling_params') and request.sampling_params and request.sampling_params.max_tokens:
+                max_new_tokens_for_request = request.sampling_params.max_tokens
+            else:
+                max_new_tokens_for_request = self.config.model_config.max_model_len - request.need_prefill_tokens
+            max_new_tokens_for_request = min(max_new_tokens_for_request, self.clip_max_new_tokens_estimation)
+
+        # 3. Tokens reserved for all existing decode requests (ratio-based)
+        decode_reserved_tokens = self._calculate_decode_reserved_tokens_by_ratio()
+
+        # Sum all tokens and convert to blocks once at the end
+        total_tokens = required_tokens_for_prefill + max_new_tokens_for_request + decode_reserved_tokens
+        can_schedule_block_num_threshold = (
+            total_tokens + self.config.cache_config.block_size - 1
+        ) // self.config.cache_config.block_size
+
+        if self.config.speculative_config.method is not None:
+            can_schedule_block_num_threshold = min(
+                can_schedule_block_num_threshold + 1, self.config.cache_config.max_block_num_per_seq
+            )
+
+        llm_logger.debug(
+            f"Prefill threshold: tokens={total_tokens:.1f} -> blocks={can_schedule_block_num_threshold} "
+            f"(prefill={required_tokens_for_prefill}, future_decode={max_new_tokens_for_request:.1f}, "
+            f"decode_reserved={decode_reserved_tokens:.1f}, is_last_chunk={is_last_chunk})"
+        )
+
+        return can_schedule_block_num_threshold

     def _update_mm_hashes(self, request):
@@ -448,16 +702,19 @@ def revert_chunked_mm_input(self, mm_inputs, matched_token_num):
                 break
         return matched_token_num

-    def _get_num_new_tokens(self, request, token_budget):
-        # TODO: set condition to new _get_num_new_tokens
+    def _get_num_new_tokens(self, request, chunked_prefill_size, token_budget):
+        # SGLang-aligned: use min(chunked_prefill_size, token_budget) as the limit
+        # chunked_prefill_size is the max tokens for a single request (like SGLang's rem_chunk_tokens)
+        # token_budget (max_num_batched_tokens) is the total batch budget (like SGLang's rem_total_tokens)
         num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
-        num_new_tokens = min(num_new_tokens, token_budget)
+        # SGLang logic: _rem_tokens = min(rem_chunk_tokens, rem_total_tokens)
+        num_new_tokens = min(num_new_tokens, chunked_prefill_size, token_budget)
         if (
             current_platform.is_intel_hpu()
-            and request.need_prefill_tokens - request.num_computed_tokens > token_budget
-            and token_budget > self.config.cache_config.block_size
+            and request.need_prefill_tokens - request.num_computed_tokens > min(chunked_prefill_size, token_budget)
+            and min(chunked_prefill_size, token_budget) > self.config.cache_config.block_size
         ):
-            num_new_tokens = token_budget // self.config.cache_config.block_size * self.config.cache_config.block_size
+            num_new_tokens = (
+                min(chunked_prefill_size, token_budget)
+                // self.config.cache_config.block_size
+                * self.config.cache_config.block_size
+            )

         request.with_image = False
         if not self.config.model_config.enable_mm:
@@ -663,6 +920,8 @@ def get_enough_request(request, scheduled_reqs):
         preempted_reqs: list[Request] = []
         error_reqs: list[tuple[str, str]] = []
         token_budget = self.config.scheduler_config.max_num_batched_tokens
+        # SGLang-aligned: chunked_prefill_size is per-request limit, token_budget is batch limit
+        chunked_prefill_size = self.config.scheduler_config.chunked_prefill_size
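Putting the threshold's three components together with toy numbers (a final prefill chunk of 16 new blocks, `block_size=64`, request `max_tokens=2048`, and the 695.4 reserved decode tokens from the earlier sketch):

```python
block_size = 64
required_tokens_for_prefill = 16 * block_size   # 1024 (current chunk only)
max_new_tokens_for_request = min(2048, 4096)    # 2048, reserved only on the last chunk
decode_reserved_tokens = 695.4                  # ratio-based reservation for running decodes

total_tokens = required_tokens_for_prefill + max_new_tokens_for_request + decode_reserved_tokens
threshold_blocks = (total_tokens + block_size - 1) // block_size  # ceiled once at the end
print(total_tokens, threshold_blocks)  # 3767.4 59.0
```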

         # First, schedule the RUNNING requests.
         req_index = 0
@@ -674,6 +933,9 @@ def get_enough_request(request, scheduled_reqs):
                     self.need_block_num_map[request.request_id] = SignalConsumer(need_block_num, 1)
                     self.need_block_num_signal.value[request.idx] = 0

+                # NOTE: The decode scheduling logic should be OUTSIDE of `if need_block_num != 0:`
+                # The need_block_num signal is only for handling worker-side block requests,
+                # but decode scheduling should always happen regardless of this signal
                 if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
                     if (
                         self.config.scheduler_config.splitwise_role == "prefill"
                     ):
                         continue
                     if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                         request.num_computed_tokens = request.num_total_tokens - 1
-                    if (
-                        self.allocated_slots(request) - request.num_total_tokens
-                        <= self.config.cache_config.prealloc_dec_block_slot_num_threshold
-                    ):
-                        # Allocation for next decoding blocks
-                        if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
+
+                    # SGLang-aligned: dynamically calculate how many blocks are needed for next decode
+                    # Similar to SGLang's new_page_count_next_decode:
+                    # If (num_total_tokens - 1) % block_size == 0, need 1 more block
+                    block_size = self.config.cache_config.block_size
+                    num_new_blocks_needed = 1 if (request.num_total_tokens - 1) % block_size == 0 else 0
+
+                    # SGLang-aligned: schedule decode task without threshold check
+                    # The pre-allocation is decoupled from scheduling
+                    if num_new_blocks_needed > 0:
+                        # Need to allocate new blocks, check if we can allocate
+                        if self.cache_manager.can_allocate_gpu_blocks(num_new_blocks_needed):
                             llm_logger.debug(
                                 f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}"
                             )
                             request.block_tables.extend(
                                 self.cache_manager.allocate_gpu_blocks(
-                                    self.config.cache_config.enc_dec_block_num, request.request_id
+                                    num_new_blocks_needed, request.request_id
                                 )
                             )
                             # Prepare decoding task
                             scheduled_reqs.append(self._prepare_decode_task(request))
                         else:
-                            # Not enough blocks to allocate, trigger preemption
-                            can_schedule = self._trigger_preempt(
-                                request, self.config.cache_config.enc_dec_block_num, preempted_reqs, scheduled_reqs
-                            )
-                            if not can_schedule:
-                                break
-                            # Allocation for next decoding blocks
-                            request.block_tables.extend(
-                                self.cache_manager.allocate_gpu_blocks(
-                                    self.config.cache_config.enc_dec_block_num, request.request_id
-                                )
-                            )
-                            # Prepare decoding task
-                            scheduled_reqs.append(self._prepare_decode_task(request))
-                        num_decoding_req_nums += 1
+                            # Not enough blocks, trigger preemption
+                            # SGLang-aligned: first try to evict decode KV cache
+                            self._evict_decode_kv_cache(len(self.running))
+
+                            # Check again after eviction
+                            if self.cache_manager.can_allocate_gpu_blocks(num_new_blocks_needed):
+                                request.block_tables.extend(
+                                    self.cache_manager.allocate_gpu_blocks(
+                                        num_new_blocks_needed, request.request_id
+                                    )
+                                )
+                                scheduled_reqs.append(self._prepare_decode_task(request))
+                            else:
+                                # Cannot allocate even after eviction, use SGLang-aligned behavior
+                                # Try to preempt other requests
+                                can_schedule = self._trigger_preempt(
+                                    request, num_new_blocks_needed, preempted_reqs, scheduled_reqs
+                                )
+                                if not can_schedule:
+                                    # Cannot preempt (e.g., only 1 decode request left),
+                                    # skip this request and continue to avoid system hang
+                                    llm_logger.warning(
+                                        f"Cannot allocate {num_new_blocks_needed} blocks "
+                                        f"for decode request {request.request_id} (idx={request.idx}) "
+                                        f"even after preemption attempt. Request will wait for more resources."
+                                    )
+                                    # Do NOT schedule this request - it will wait in the queue
+                                    # But still consume token_budget to avoid infinite loop
+                                    req_index += 1
+                                    token_budget -= 1
+                                    continue
+
+                                # Allocation for next decoding blocks after preemption
+                                request.block_tables.extend(
+                                    self.cache_manager.allocate_gpu_blocks(
+                                        num_new_blocks_needed, request.request_id
+                                    )
+                                )
+                                # Prepare decoding task
+                                scheduled_reqs.append(self._prepare_decode_task(request))
+
+                    # No new blocks needed (num_new_blocks_needed == 0), but still schedule decode task
+                    # SGLang-aligned: always schedule decode for each iteration
+                    else:
+                        scheduled_reqs.append(self._prepare_decode_task(request))
+
+                    num_decoding_req_nums += 1
                     token_budget -= 1
                     if (
                         request.use_extend_tables
@@ -742,13 +1042,9 @@ def _allocate_decode_and_extend():
                     request.extend_block_tables.extend(
                         self.cache_manager.allocate_gpu_blocks(allocate_block_num, request.request_id)
                     )
-                    scheduled_reqs.append(
-                        ScheduledExtendBlocksTask(
-                            idx=request.idx,
-                            request_id=request.request_id,
-                            extend_block_tables=request.extend_block_tables,
-                        )
-                    )
+                    # EXTEND task is treated as PREFILL with 0 new tokens
+                    # This ensures worker handles it correctly instead of treating as PREEMPTED
+                    scheduled_reqs.append(self._prepare_prefill_task(request, 0))
                     llm_logger.debug(f"extend blocks is {request.extend_block_tables}")

                 if self.cache_manager.can_allocate_gpu_blocks(
@@ -787,7 +1083,7 @@ def _allocate_decode_and_extend():
                 if get_enough_request(request, scheduled_reqs):
                     req_index += 1
                     continue
-                num_new_tokens = self._get_num_new_tokens(request, token_budget)
+                num_new_tokens = self._get_num_new_tokens(request, chunked_prefill_size, token_budget)
                 num_new_block = self.get_new_block_nums(request, num_new_tokens)
                 # Allocate blocks to prefill
                 if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
@@ -797,7 +1093,10 @@ def _allocate_decode_and_extend():
                     # Prepare prefill task
                     scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                 else:  # Not enough blocks to allocate, trigger preemption
-                    can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
+                    if self.config.scheduler_config.enable_priority_scheduling:
+                        can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
+                    else:
+                        can_schedule = False
                     if not can_schedule:
                         break
                     request.block_tables.extend(
@@ -860,7 +1159,7 @@ def _allocate_decode_and_extend():
                 ):
                     continue
                 # Allocate blocks for the tokens that does not hit cache
-                num_new_tokens = self._get_num_new_tokens(request, token_budget)
+                num_new_tokens = self._get_num_new_tokens(request, chunked_prefill_size, token_budget)
                 num_new_block = self.get_new_block_nums(request, num_new_tokens)
                 can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(
                     request, num_new_block
                 )
@@ -913,7 +1212,7 @@ def _allocate_decode_and_extend():
                     break

                 # Allocate blocks for the tokens that does not hit cache
-                num_new_tokens = self._get_num_new_tokens(request, token_budget)
+                num_new_tokens = self._get_num_new_tokens(request, chunked_prefill_size, token_budget)
                 num_new_block = self.get_new_block_nums(request, num_new_tokens)
                 can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(
                     request, num_new_block
                 )
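The block check above reserves a new block exactly when the already-generated tokens end on a block boundary, so the next decoded token would start a fresh block; a toy check with `block_size=64`:

```python
block_size = 64
for num_total_tokens in (64, 65, 128, 129, 130):
    need = 1 if (num_total_tokens - 1) % block_size == 0 else 0
    print(num_total_tokens, need)
# 64 0
# 65 1   <- token 65 is the first of a new block
# 128 0
# 129 1  <- token 129 is the first of a new block
# 130 0
```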
@@ -948,14 +1247,15 @@ def _allocate_decode_and_extend():

         if scheduled_reqs:
             llm_logger.debug(f"schedued_reqs: {scheduled_reqs}")
-            self.current_reserve_output_block_num_float -= self.decay_output_block_num
-            self.current_reserve_output_block_num = max(
-                int(self.current_reserve_output_block_num_float),
-                self.min_reserve_output_block_num,
-                0,
-            )
-            if self.current_reserve_output_block_num == 0:
-                self.can_relax_prefill_strategy = True
+
+            # New mechanism: decay new_token_ratio only when there are decode requests
+            has_decode_reqs = any(getattr(r, "task_type", None) == RequestType.DECODE for r in scheduled_reqs)
+            if has_decode_reqs:
+                self.current_new_token_ratio = max(
+                    self.current_new_token_ratio - self.new_token_ratio_decay,
+                    self.min_new_token_ratio,
+                )
+                llm_logger.debug(f"NewTokenRatio decayed to {self.current_new_token_ratio:.4f}")

         if (
             hasattr(self, "scheduler_metrics_logger")
@@ -1008,6 +1308,10 @@ def _allocate_decode_and_extend():
                 use_cudagraph=use_decode_cudagraph,
             )

+        # SGLang-aligned: reset new_token_ratio when completely idle
+        if not scheduled_reqs:
+            self.reset_new_token_ratio_on_idle()
+
         self.update_metrics()
         return scheduled_reqs, error_reqs
@@ -1383,7 +1687,6 @@ def finish_requests_async(self, request_ids: Union[str, Iterable[str]]):

     def finish_requests(self, request_ids: Union[str, Iterable[str]]):
         llm_logger.info(f"recycle resources for requests: {request_ids}")
-        self.update_metrics(verbose=True)
         try:
             if isinstance(request_ids, str):
                 request_ids = (request_ids,)
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index cdbdeb52255..59a265962d1 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -192,16 +192,12 @@
     "FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))),
     # Whether to use phi FP8 quantization,if 1,use paddle default.
     "FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))),
-    # Reserve output blocks for decoding requests when schedule new prefill requests
-    "FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int(
-        os.getenv("FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "16")
-    ),
-    "FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: float(
-        os.getenv("FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "0.025")
-    ),
-    "FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int(
-        os.getenv("FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "0")
-    ),
+    # Token-ratio-based reservation for decode requests when scheduling new prefills
+    "FD_INIT_NEW_TOKEN_RATIO": lambda: float(os.getenv("FD_INIT_NEW_TOKEN_RATIO", "0.7")),
+    "FD_MIN_NEW_TOKEN_RATIO_FACTOR": lambda: float(os.getenv("FD_MIN_NEW_TOKEN_RATIO_FACTOR", "0.14")),
+    "FD_NEW_TOKEN_RATIO_DECAY_STEPS": lambda: int(os.getenv("FD_NEW_TOKEN_RATIO_DECAY_STEPS", "600")),
+    "FD_RETRACT_DECODE_STEPS": lambda: int(os.getenv("FD_RETRACT_DECODE_STEPS", "20")),
+    "FD_CLIP_MAX_NEW_TOKENS_ESTIMATION": lambda: int(os.getenv("FD_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")),
     # Timeout for worker process health check in seconds
     "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")),
     # File path for file storage backend
diff --git a/fastdeploy/scheduler/config.py b/fastdeploy/scheduler/config.py
index 7e56eec676b..8e2eb6f8922 100644
--- a/fastdeploy/scheduler/config.py
+++ b/fastdeploy/scheduler/config.py
@@ -270,9 +270,11 @@ def __init__(self, args):
         self.name = "local"  # "local" for LocalScheduler or "global" for GlobalScheduler
         self.max_num_batched_tokens = 2048  # base token_num for text inputs
         self.max_extra_num_batched_tokens = 16384  # extra token_num for multimodal inputs
+        self.chunked_prefill_size = 8192  # default; EngineArgs may override it based on detected GPU memory
         self.max_num_seqs = 34
         self.splitwise_role = "mixed"
         self.enable_overlap_schedule = False
+        self.enable_priority_scheduling = False
         self.config = None
         for key, value in args.items():
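Since the new knobs are plain environment variables read through the `os.getenv` lambdas above, deployments can tune the schedule without code changes; a hypothetical override (values are illustrative only):

```python
import os

# Must be set in the serving process's environment; fastdeploy.envs reads
# these via os.getenv when each entry is accessed.
os.environ["FD_INIT_NEW_TOKEN_RATIO"] = "0.85"       # start more conservative
os.environ["FD_NEW_TOKEN_RATIO_DECAY_STEPS"] = "300" # relax toward the floor twice as fast
os.environ["FD_RETRACT_DECODE_STEPS"] = "40"         # evict more KV per retraction
```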