
feat(scheduler): implement SGLang-style dynamic resource management and preemption #6606

Open
Foriv wants to merge 3 commits into PaddlePaddle:develop from Foriv:feature/dynamic-resource-management

Conversation


@Foriv Foriv commented Mar 2, 2026

Feature Overview

Implements an SGLang-style dynamic resource management mechanism, including:

  • Dynamic token reservation strategy (new_tokens_ratio)
  • Smart preemption ordering and KV Cache eviction
  • A priority-scheduling parameter
  • Ratio continuity management

Main Changes

1. Dynamic Token Reservation (new_tokens_ratio)

Core logic

  • Prefill admission: reserve blocks for the full max_new_tokens
  • Decode dynamic reservation: reserve max_new_tokens * new_tokens_ratio
  • Gradual decay: the ratio decays linearly after every decode forward, releasing idle resources
  • Idle reset: the ratio is reset only when fully idle (no prefill and no decode)

New environment variables

FD_INIT_NEW_TOKEN_RATIO = 0.7          # initial reservation ratio
FD_MIN_NEW_TOKEN_RATIO_FACTOR = 0.14   # minimum reservation ratio factor
FD_NEW_TOKEN_RATIO_DECAY = 0.003       # decay rate
FD_RETRACT_DECODE_STEPS = 20           # retract steps on preemption

Workflow

  1. Admission: compute the required resources
     required_blocks = prefill_blocks(full) + Σ(decode_blocks * ratio)

  2. Scheduling: decode requests reserve blocks scaled by ratio
  3. Decay: after each forward, update
     ratio = max(ratio - decay, min_ratio)

  4. Reset: when running == 0 && waiting == 0, reset the ratio to its initial value
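The four stages above can be condensed into a small self-contained sketch. The class and method names below are illustrative stand-ins, not the PR's actual ResourceManagerV1 API; the numeric defaults mirror the environment variables listed earlier, and the assumption that the minimum ratio is init_ratio * factor is mine:

```python
class NewTokenRatio:
    """Illustrative sketch of the ratio lifecycle (not the real scheduler)."""

    def __init__(self, init_ratio=0.7, min_ratio_factor=0.14, decay=0.003):
        self.init_ratio = init_ratio
        # Assumption: the factor scales the initial ratio to get the floor.
        self.min_ratio = init_ratio * min_ratio_factor
        self.decay = decay
        self.ratio = init_ratio

    def after_decode_forward(self):
        # Stage 3: linear decay after every decode forward, floored at min_ratio.
        self.ratio = max(self.ratio - self.decay, self.min_ratio)

    def reset_on_idle(self, num_running, num_waiting):
        # Stage 4: reset only when fully idle (no prefill and no decode).
        if num_running == 0 and num_waiting == 0:
            self.ratio = self.init_ratio


r = NewTokenRatio()
for _ in range(100):
    r.after_decode_forward()
print(round(r.ratio, 3))  # 0.4: 0.7 minus 100 decay steps of 0.003
```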

2. Preemption Strategy Aligned with SGLang

Preemption ordering

Preempt requests with short outputs and long inputs first:

sorted_requests = sorted(
    self.running,
    key=lambda r: (
        r.num_computed_tokens - r.need_prefill_tokens,  # output length (shorter first)
        -r.need_prefill_tokens                          # input length (longer first)
    ),
    reverse=True
)

KV Cache eviction

A new _evict_decode_kv_cache method evicts KV Cache following the SGLang policy:

# Each round evicts retract_decode_steps * remaining_req_count tokens
num_tokens_to_evict = retract_decode_steps * len(remaining_decode_reqs)
num_blocks_to_free = num_tokens_to_evict // block_size
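To make the eviction arithmetic concrete, here is a worked example; the block size of 64 is a hypothetical value chosen for illustration, while the other two numbers match the defaults above:

```python
retract_decode_steps = 20    # FD_RETRACT_DECODE_STEPS default
remaining_decode_reqs = 3    # decode requests still running after preemption
block_size = 64              # hypothetical KV-cache block size

num_tokens_to_evict = retract_decode_steps * remaining_decode_reqs
num_blocks_to_free = num_tokens_to_evict // block_size

print(num_tokens_to_evict)  # 60
print(num_blocks_to_free)   # 0
```

Note that floor division rounds down to whole blocks, so with these numbers an eviction round smaller than one block frees no memory at all.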

Safeguards

  • At least 1 decode request is never preempted
  • Preempted requests are appended to the tail of the waiting queue, preserving FIFO fairness

Ratio update

_update_new_token_ratio_after_preemption only counts decode requests:

decode_reqs = [r for r in self.running if r.num_computed_tokens >= r.need_prefill_tokens]
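The decode filter above can be exercised with minimal stand-in request objects; SimpleNamespace replaces the real Request class purely for illustration:

```python
from types import SimpleNamespace

running = [
    SimpleNamespace(num_computed_tokens=120, need_prefill_tokens=100),  # decode phase
    SimpleNamespace(num_computed_tokens=40, need_prefill_tokens=100),   # still prefilling
]

# Same predicate as in _update_new_token_ratio_after_preemption:
decode_reqs = [r for r in running if r.num_computed_tokens >= r.need_prefill_tokens]
print(len(decode_reqs))  # 1: only the request past its prefill length counts
```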

3. Priority Scheduling Parameter

A new --enable-priority-scheduling flag lets prefill requests preempt decode:

python -m fastdeploy.entrypoints.api_server \
    --enable-priority-scheduling \
    ...

  • With enable_priority_scheduling=True, a prefill request short on resources can trigger preemption
  • Default: False (aligned with SGLang)

Parameter propagation chain

EngineArgs.enable_priority_scheduling (args_utils.py)
    ↓
create_scheduler_config() (args_utils.py)
    ↓
SchedulerConfig (scheduler/config.py)
    ↓
ResourceManagerV1 (resource_manager_v1.py)

4. Ratio Continuity Management

A new reset_new_token_ratio_on_idle method refines the ratio reset logic:

def reset_new_token_ratio_on_idle(self):
    """Reset the ratio only when fully idle."""
    if len(self.running) == 0 and len(self.waiting) == 0:
        self.new_tokens_ratio = self.init_new_token_ratio

  • When it runs: called from schedule() whenever scheduled_reqs is empty
  • Effect: the ratio stays continuous, avoiding state jumps

Files Changed

fastdeploy/
├── envs.py                                 # new environment variable definitions
├── engine/
│   ├── args_utils.py                       # new launch argument definition
│   └── sched/
│       └── resource_manager_v1.py          # core scheduling logic
├── scheduler/
│   └── config.py                           # new SchedulerConfig parameter
└── docs/
    └── features/
        └── new_tokens_ratio_scheduler.md   # new design document

Comparison with SGLang

Feature               SGLang                        FastDeploy (this PR)
Dynamic reservation   new_token_ratio               fully aligned
Preemption ordering   short output first            fully aligned
KV eviction           retract_decode_steps          fully aligned
Priority scheduling   enable_priority_scheduling    fully aligned
Ratio continuity      reset when idle               fully aligned
FIFO fairness         tail insertion                fully aligned

Core Code Logic

Reservation calculation

def _calculate_reserved_blocks(self, request):
    if request.is_prefill:
        # Prefill: reserve blocks for the full max_new_tokens
        return request.max_new_tokens // self.block_size
    else:
        # Decode: ratio-scaled reservation
        return int(request.max_new_tokens * self.new_tokens_ratio) // self.block_size
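A standalone rerun of the reservation rule with made-up numbers shows the difference between the two branches (block size and token counts are purely illustrative):

```python
block_size = 64
max_new_tokens = 1024
new_tokens_ratio = 0.7

# Prefill branch: full reservation
prefill_blocks = max_new_tokens // block_size
# Decode branch: ratio-scaled reservation
decode_blocks = int(max_new_tokens * new_tokens_ratio) // block_size

print(prefill_blocks)  # 16
print(decode_blocks)   # 11, roughly 70% of the prefill reservation
```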

Preemption ordering

def _trigger_preempt(self, need_blocks):
    # Only decode requests are preemption candidates
    decode_reqs = [r for r in self.running if r.is_decode]

    # Ascending by (output_len, -input_len): shortest output, longest input first
    sorted_reqs = sorted(
        decode_reqs,
        key=lambda r: (r.output_len, -r.input_len),
    )

    # Keep at least 1 decode request running
    for req in sorted_reqs[:-1]:
        self._evict_decode_kv_cache(req)
        self.waiting.append(req)  # tail insertion preserves FIFO

Ratio decay

def _decay_new_token_ratio(self):
    if self.has_decode_requests:
        self.new_tokens_ratio = max(
            self.new_tokens_ratio - self.decay_rate,
            self.min_new_token_ratio
        )
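With the defaults from the environment variables above, and assuming the minimum ratio is init * factor = 0.7 * 0.14 = 0.098 (how such a factor is usually applied; the PR does not spell this out), the decay bottoms out after roughly 200 decode forwards:

```python
init_ratio = 0.7
min_ratio = init_ratio * 0.14   # 0.098; assumes the factor scales the initial ratio
decay = 0.003                   # FD_NEW_TOKEN_RATIO_DECAY from the description

ratio = init_ratio
steps = 0
while ratio > min_ratio:
    # Same clamped linear decay as _decay_new_token_ratio
    ratio = max(ratio - decay, min_ratio)
    steps += 1

print(steps)  # 201 forwards until the ratio reaches its floor
```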

Foriv added 3 commits March 2, 2026 10:15
- Fix deadlock when 1 decode request remains
- Use sync free_block_ids for block release
- Fix preemption failure loop handling (continue vs break)
- Align block pre-allocation with SGLang
- Ensure decode tasks always scheduled
- Fix origin_input_ids AttributeError
@paddle-bot

paddle-bot bot commented Mar 2, 2026

Thanks for your contribution!

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


Foriv seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

@paddle-bot paddle-bot bot added the contributor External developers label Mar 2, 2026
@Jiang-Jia-Jun Jiang-Jia-Jun requested a review from Copilot March 2, 2026 13:50
Contributor

Copilot AI left a comment


Pull request overview

This PR introduces SGLang-aligned dynamic resource management and preemption into FastDeploy's V1 scheduling/resource-management path, centered on ratio-based (new_token_ratio) dynamic reservation for decode, preemption with KV cache eviction, and an optional priority-scheduling switch.

Changes:

  • Adds/replaces new_token_ratio-based decode reservation, decay, and idle-reset logic, with the corresponding environment variables.
  • Reworks preemption to "decode requests only + SGLang-style ordering + KV cache eviction + ratio update after preemption", and appends preempted requests to the queue tail (FIFO).
  • Adds the --enable-priority-scheduling flag and relaxes some defaults (e.g. the max_num_seqs cap and the default max_num_batched_tokens).

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 13 comments.

Summary per file:

  • fastdeploy/scheduler/config.py: adds the enable_priority_scheduling default to SchedulerConfig as a scheduler-side switch.
  • fastdeploy/envs.py: removes the old fixed block-reservation env var and adds the new_token_ratio-related env vars.
  • fastdeploy/engine/sched/resource_manager_v1.py: implements new_token_ratio reservation/decay/reset, SGLang-style preemption with KV eviction, and decode/prefill adjustments in the main scheduling loop.
  • fastdeploy/engine/args_utils.py: adds the --enable-priority-scheduling CLI flag and changes the default max_num_batched_tokens.
  • fastdeploy/config.py: relaxes the max_num_seqs upper-bound check (256 -> 512).

Comment on lines 1451 to 1453

 else:
-    self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
+    self.max_num_batched_tokens = 16384  # if set to max_model_len, it's easy to be OOM

Copilot AI Mar 2, 2026


The PR title is currently feat(scheduler): ..., which does not follow the repository's required [CLASS]Title format (e.g. [Feature] ... / [BugFix] ...). Please adjust the title to the required format to ease changelog generation and automated classification.

Comment on lines +332 to +334

decode_requests.sort(
    key=lambda r: (len(r.output_token_ids), -r.prompt_token_ids_len),
    reverse=True,  # pop from end: shorter output first

Copilot AI Mar 2, 2026


The sort in _trigger_preempt() contradicts the comment and the intended behavior: decode_requests.sort(..., reverse=True) puts requests with longer outputs first, but preemption then iterates from the head via for preempted_req in decode_requests[:-1], so long-output requests are preempted first instead of short-output ones. Sort ascending by output_len (or keep reverse=True but pop from the tail / iterate in reverse) so that short outputs and long inputs really are preempted first.

Suggested change

-decode_requests.sort(
-    key=lambda r: (len(r.output_token_ids), -r.prompt_token_ids_len),
-    reverse=True,  # pop from end: shorter output first
+# Sort ascending by (output_len, -input_len) so that decode_requests[0]
+# is the highest-priority candidate for preemption.
+decode_requests.sort(
+    key=lambda r: (len(r.output_token_ids), -r.prompt_token_ids_len),

"FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int(
    os.getenv("FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "0")
),
# Reserve tokens when schedule

Copilot AI Mar 2, 2026


The comment # Reserve tokens when schedule contains extra spaces and is ungrammatical. Prefer clearer English (e.g. "Reserve tokens when scheduling new requests") and drop the extra spaces for maintainability.

Suggested change

-# Reserve tokens when schedule
+# Reserve tokens when scheduling new requests

self.max_num_seqs = 34
self.splitwise_role = "mixed"
self.enable_overlap_schedule = False
self.enable_priority_scheduling = False

Copilot AI Mar 2, 2026


The line enable_priority_scheduling = False ends with trailing whitespace, which trips formatting/linter (pre-commit) checks. Remove the trailing whitespace.

Suggested change

-self.enable_priority_scheduling = False 
+self.enable_priority_scheduling = False

     self.max_num_batched_tokens = self.max_model_len
 else:
-    self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
+    self.max_num_batched_tokens = 16384  # if set to max_model_len, it's easy to be OOM

Copilot AI Mar 2, 2026


Raising the default max_num_batched_tokens from 8192 to 16384 significantly increases the per-batch token cap and can raise peak memory usage or trigger OOM (the inline comment itself warns about this). Meanwhile, fastdeploy/config.py still defaults to 8192, so the defaults derived "from EngineArgs" and "from FDConfig.postprocess" diverge. Suggestions: 1) document the rationale and evaluation for the change; 2) consider controlling this via a config option or env var; 3) at minimum keep the two default sources consistent.

Suggested change

-self.max_num_batched_tokens = 16384  # if set to max_model_len, it's easy to be OOM
+# Use a conservative default to avoid OOM and keep consistent with FDConfig.
+# Do not set to max_model_len by default because it will significantly increase peak memory usage.
+self.max_num_batched_tokens = 8192

Comment on lines +216 to +218
self.new_token_ratio_decay = (
self.init_new_token_ratio - self.min_new_token_ratio
) / envs.FD_NEW_TOKEN_RATIO_DECAY_STEPS

Copilot AI Mar 2, 2026


self.new_token_ratio_decay = (...) / envs.FD_NEW_TOKEN_RATIO_DECAY_STEPS raises a division-by-zero (or yields a negative decay) when the environment variable is set to 0 or a negative value. Add a lower-bound check on FD_NEW_TOKEN_RATIO_DECAY_STEPS (e.g. a minimum of 1), or set the decay to 0 with a warning when steps <= 0.

Suggested change

-self.new_token_ratio_decay = (
-    self.init_new_token_ratio - self.min_new_token_ratio
-) / envs.FD_NEW_TOKEN_RATIO_DECAY_STEPS
+decay_steps = envs.FD_NEW_TOKEN_RATIO_DECAY_STEPS
+if decay_steps is None or decay_steps <= 0:
+    # Avoid division by zero or negative values; fall back to no decay
+    self.new_token_ratio_decay = 0.0
+    llm_logger.warning(
+        "FD_NEW_TOKEN_RATIO_DECAY_STEPS is not a positive value "
+        f"({decay_steps}); setting new_token_ratio_decay to 0.0."
+    )
+else:
+    self.new_token_ratio_decay = (
+        self.init_new_token_ratio - self.min_new_token_ratio
+    ) / decay_steps

Comment on lines +409 to +410
# Default retract_decode_steps = 20 (SGLang default)
retract_decode_steps = getattr(self, 'retract_decode_steps', 20)

Copilot AI Mar 2, 2026


FD_RETRACT_DECODE_STEPS was added to envs, but this code uses getattr(self, 'retract_decode_steps', 20) while __init__ never assigns self.retract_decode_steps, so the environment variable has no effect (the value is always 20). Set self.retract_decode_steps = envs.FD_RETRACT_DECODE_STEPS in __init__ and use that field directly here.

Suggested change

-# Default retract_decode_steps = 20 (SGLang default)
-retract_decode_steps = getattr(self, 'retract_decode_steps', 20)
+# Default retract_decode_steps comes from envs (aligned with SGLang default 20)
+self.retract_decode_steps = getattr(
+    self,
+    'retract_decode_steps',
+    envs.FD_RETRACT_DECODE_STEPS,
+)
+retract_decode_steps = self.retract_decode_steps

Comment on lines +589 to +594
# Sum all tokens and convert to blocks once at the end
total_tokens = required_tokens_for_prefill + max_new_tokens_for_request + decode_reserved_tokens
can_schedule_block_num_threshold = (
total_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size


Copilot AI Mar 2, 2026


_calculate_decode_reserved_tokens_by_ratio() returns a float (reserved_tokens scaled by the ratio), so total_tokens is also a float. Using // to compute blocks then yields a float result (e.g. 3.0), while can_allocate_gpu_blocks() expects an int; floor division can also undercount the required blocks. After summing, compute math.ceil(total_tokens / block_size) to get an int block count, or explicitly round/ceil decode_reserved_tokens to an int first.
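The undercount described in this review comment is easy to reproduce; the numbers below are made up, but they show both problems at once (a float result and a lost block):

```python
import math

block_size = 64
total_tokens = 64.5  # float, because the decode reservation is ratio-scaled

# The integer idiom (n + block_size - 1) // block_size misbehaves on floats:
rounded_up = (total_tokens + block_size - 1) // block_size
print(rounded_up)  # 1.0 -- a float, and one block short of what 64.5 tokens need

# math.ceil on true division gives the int the allocator expects:
print(math.ceil(total_tokens / block_size))  # 2
```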

Comment on lines +196 to +199
"FD_INIT_NEW_TOKEN_RATIO": lambda: float(os.getenv("FD_INIT_NEW_TOKEN_RATIO", "0.7")),
"FD_MIN_NEW_TOKEN_RATIO_FACTOR": lambda: float(os.getenv("FD_MIN_NEW_TOKEN_RATIO_FACTOR", "0.14")),
"FD_NEW_TOKEN_RATIO_DECAY_STEPS": lambda: int(os.getenv("FD_NEW_TOKEN_RATIO_DECAY_STEPS", "600")),
"FD_RETRACT_DECODE_STEPS": lambda: int(os.getenv("FD_RETRACT_DECODE_STEPS", "20")),

Copilot AI Mar 2, 2026


The environment variables here do not match the PR description: the description documents FD_NEW_TOKEN_RATIO_DECAY (a per-step decay rate), but the code introduces FD_NEW_TOKEN_RATIO_DECAY_STEPS (a total step count) and derives the decay as (init - min) / steps in ResourceManagerV1. Align the PR description with the actual implementation (variable name and semantics) so users configuring per the description are not surprised when it has no effect.

 """
-assert self.scheduler_config.max_num_seqs <= 256, (
+assert self.scheduler_config.max_num_seqs <= 512, (
     "The parameter `max_num_seqs` is not allowed to exceed 256, "

Copilot AI Mar 2, 2026


The max_num_seqs cap was raised to 512, but the assertion message still says "not allowed to exceed 256". Update the message (256 -> 512), otherwise users debugging configuration issues will be misled.

Suggested change

-"The parameter `max_num_seqs` is not allowed to exceed 256, "
+"The parameter `max_num_seqs` is not allowed to exceed 512, "

