From 2aabf8a35f012930b3bd817da14e42814d71f56b Mon Sep 17 00:00:00 2001 From: Regrad Date: Tue, 2 Jun 2026 19:33:38 +0300 Subject: [PATCH 1/7] server: improve checkpoint reuse heuristics for recurrent/hybrid models --- common/common.cpp | 4 +++- common/common.h | 4 +++- tools/server/server-context.cpp | 18 +++++++++++++++--- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index b01772e1cbfe..784b978089ed 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2054,10 +2054,12 @@ void common_prompt_checkpoint::clear() { void common_prompt_checkpoint::update_pos( int64_t n_tokens, llama_pos pos_min, - llama_pos pos_max) { + llama_pos pos_max, + llama_pos pos_end) { this->n_tokens = n_tokens; this->pos_min = pos_min; this->pos_max = pos_max; + this->pos_end = pos_end; } void common_prompt_checkpoint::update_tgt( diff --git a/common/common.h b/common/common.h index 0b284cbb36c7..f7a0aad12360 100644 --- a/common/common.h +++ b/common/common.h @@ -1060,6 +1060,7 @@ struct common_prompt_checkpoint { llama_pos pos_min; llama_pos pos_max; + llama_pos pos_end; std::vector data_tgt; std::vector data_dft; @@ -1072,7 +1073,8 @@ struct common_prompt_checkpoint { void update_pos( int64_t n_tokens, llama_pos pos_min, - llama_pos pos_max); + llama_pos pos_max, + llama_pos pos_end); void update_tgt( llama_context * ctx, diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index bcae39a10966..438592b626fb 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2134,7 +2134,8 @@ struct server_context_impl { // [TAG_CHECKPOINTS_FIX_POS_MIN] // TODO: here we incorrectly deterimne that the saved checkpoint data covers the [pos_min, pos_max] range // this is not true for SWA models: https://github.com/ggml-org/llama.cpp/pull/24411#issuecomment-4677983225 - cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max); + const auto n_tokens_start = 
slot.prompt.n_tokens() - n_tokens_cur; + cur.update_pos(n_tokens_start, pos_min, pos_max, slot.prompt.tokens.pos_next(n_tokens_start)); cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); @@ -2610,7 +2611,8 @@ struct server_context_impl { slot.spec_ckpt.update_pos( slot.prompt.n_tokens(), llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id), - llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id)); + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id), + slot.prompt.tokens.pos_next()); if (use_ckpt_dft) { slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); @@ -2882,6 +2884,10 @@ struct server_context_impl { } llama_pos pos_next = slot.prompt.tokens.pos_next(n_past); + const llama_pos prompt_end = slot.task->tokens.pos_next(); + const bool is_recurrent_or_hybrid = + llama_model_is_recurrent(model_tgt) || + llama_model_is_hybrid(model_tgt); // ref: https://github.com/ggml-org/llama.cpp/pull/24110 const bool has_new_tokens = (n_past < slot.task->n_tokens()); @@ -2945,10 +2951,16 @@ struct server_context_impl { slot.prompt.checkpoints.rbegin(), slot.prompt.checkpoints.rend(), [&, func_name = __func__](const auto & cur) { + if (cur.pos_end > prompt_end) { + return false; + } // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12, func_name, (slot).id, ((slot).task ? 
(slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold); // workaround for [TAG_CHECKPOINTS_FIX_POS_MIN] + if (is_recurrent_or_hybrid) { + return cur.pos_max < pos_next || cur.pos_min == 0; + } if (cur.pos_max > pos_next) { return false; } @@ -2981,7 +2993,7 @@ struct server_context_impl { // erase any checkpoints with pos_max > pos_next for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { const auto & cur = *it; - if (cur.pos_max > pos_next) { + if (cur.pos_end > prompt_end || cur.pos_max > pos_next) { SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024); it = slot.prompt.checkpoints.erase(it); } else { From 74eaaaba5489a71c664482963adf95f0bd5b0554 Mon Sep 17 00:00:00 2001 From: Regrad Date: Mon, 15 Jun 2026 14:20:55 +0300 Subject: [PATCH 2/7] server: retain prompt checkpoint coverage --- common/common.cpp | 1 + tools/server/server-context.cpp | 62 ++++++++++++++++++++++++--------- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 784b978089ed..c03530d10c88 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2046,6 +2046,7 @@ void common_prompt_checkpoint::clear() { pos_min = 0; pos_max = 0; + pos_end = 0; data_tgt.clear(); data_dft.clear(); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 438592b626fb..8c318f5c574c 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2120,13 +2120,34 @@ struct server_context_impl { // n_tokens_cur: the number of tokens added to the batch for the current slot void create_checkpoint(server_slot & slot, const int64_t n_tokens_cur, llama_pos pos_min, llama_pos pos_max) { while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for 
the new checkpoint, if needed - const auto & cur = slot.prompt.checkpoints.front(); + // Preserve the oldest anchor and the most recent checkpoint. Removing the checkpoint + // from the densest interior interval keeps coverage across the full prompt history. + size_t erase_idx = 0; + + if (slot.prompt.checkpoints.size() > 2) { + erase_idx = 1; + int64_t min_merged_span = + slot.prompt.checkpoints[2].n_tokens - + slot.prompt.checkpoints[0].n_tokens; + + for (size_t i = 2; i + 1 < slot.prompt.checkpoints.size(); ++i) { + const int64_t merged_span = + slot.prompt.checkpoints[i + 1].n_tokens - + slot.prompt.checkpoints[i - 1].n_tokens; + + if (merged_span < min_merged_span) { + erase_idx = i; + min_merged_span = merged_span; + } + } + } + + const auto & cur = slot.prompt.checkpoints[erase_idx]; - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024); + SLT_WRN(slot, "erasing redundant context checkpoint (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, cur.pos_end, cur.n_tokens, (float) cur.size() / 1024 / 1024); - slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin() + erase_idx); } auto & cur = slot.prompt.checkpoints.emplace_back(); @@ -2141,9 +2162,9 @@ struct server_context_impl { cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); SLT_INF(slot, - "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, - cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024); + cur.pos_max, cur.pos_end, 
cur.n_tokens, (float) cur.size() / 1024 / 1024); } void process_single_task(server_task && task) { @@ -2884,7 +2905,7 @@ struct server_context_impl { } llama_pos pos_next = slot.prompt.tokens.pos_next(n_past); - const llama_pos prompt_end = slot.task->tokens.pos_next(); + const llama_pos prefix_end = pos_next; const bool is_recurrent_or_hybrid = llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt); @@ -2951,12 +2972,16 @@ struct server_context_impl { slot.prompt.checkpoints.rbegin(), slot.prompt.checkpoints.rend(), [&, func_name = __func__](const auto & cur) { - if (cur.pos_end > prompt_end) { + // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] + LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d], pos_end = %d against prefix_end = %d, threshold = %d...\n", 12, + func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, cur.pos_end, prefix_end, pos_min_thold); + + // A checkpoint is reusable only if the entire saved state belongs to the + // common prompt prefix. Memory ranges alone are not sufficient for recurrent + // and hybrid models because they do not necessarily cover the full prefix. + if (cur.pos_end > prefix_end) { return false; } - // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] - LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12, - func_name, (slot).id, ((slot).task ? 
(slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold); // workaround for [TAG_CHECKPOINTS_FIX_POS_MIN] if (is_recurrent_or_hybrid) { return cur.pos_max < pos_next || cur.pos_min == 0; @@ -2977,7 +3002,7 @@ struct server_context_impl { pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->pos_end, it->n_tokens, n_past, (float) it->size() / 1024 / 1024); } if (do_reset) { @@ -2990,11 +3015,16 @@ struct server_context_impl { } { - // erase any checkpoints with pos_max > pos_next + // erase checkpoints that extend beyond the matching prefix. Dense models also + // invalidate checkpoints beyond the restored memory range, but recurrent/hybrid + // checkpoints can legitimately have a larger pos_max while still belonging to + // the same prompt prefix. 
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { const auto & cur = *it; - if (cur.pos_end > prompt_end || cur.pos_max > pos_next) { - SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024); + const bool invalid_prefix = cur.pos_end > prefix_end; + const bool invalid_memory = !is_recurrent_or_hybrid && cur.pos_max > pos_next; + if (invalid_prefix || invalid_memory) { + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", n_swa = %d, prefix_end = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.pos_end, cur.n_tokens, n_swa, prefix_end, pos_next, (float) cur.size() / 1024 / 1024); it = slot.prompt.checkpoints.erase(it); } else { ++it; From b082bd0eb7a557bf748971c6e98bc20c7f249428 Mon Sep 17 00:00:00 2001 From: Regrad Date: Mon, 15 Jun 2026 15:07:31 +0300 Subject: [PATCH 3/7] server: adapt checkpoint retention to list storage --- tools/server/server-context.cpp | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 8c318f5c574c..cc88dc41add3 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2122,32 +2122,36 @@ struct server_context_impl { while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { // Preserve the oldest anchor and the most recent checkpoint. Removing the checkpoint // from the densest interior interval keeps coverage across the full prompt history. 
- size_t erase_idx = 0; + auto erase_it = slot.prompt.checkpoints.begin(); if (slot.prompt.checkpoints.size() > 2) { - erase_idx = 1; - int64_t min_merged_span = - slot.prompt.checkpoints[2].n_tokens - - slot.prompt.checkpoints[0].n_tokens; + auto prev_it = slot.prompt.checkpoints.begin(); + auto cur_it = std::next(prev_it); + auto next_it = std::next(cur_it); - for (size_t i = 2; i + 1 < slot.prompt.checkpoints.size(); ++i) { - const int64_t merged_span = - slot.prompt.checkpoints[i + 1].n_tokens - - slot.prompt.checkpoints[i - 1].n_tokens; + erase_it = cur_it; + int64_t min_merged_span = next_it->n_tokens - prev_it->n_tokens; + + while (std::next(next_it) != slot.prompt.checkpoints.end()) { + ++prev_it; + ++cur_it; + ++next_it; + + const int64_t merged_span = next_it->n_tokens - prev_it->n_tokens; if (merged_span < min_merged_span) { - erase_idx = i; + erase_it = cur_it; min_merged_span = merged_span; } } } - const auto & cur = slot.prompt.checkpoints[erase_idx]; + const auto & cur = *erase_it; SLT_WRN(slot, "erasing redundant context checkpoint (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.pos_end, cur.n_tokens, (float) cur.size() / 1024 / 1024); - slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin() + erase_idx); + slot.prompt.checkpoints.erase(erase_it); } auto & cur = slot.prompt.checkpoints.emplace_back(); From 8f23fd66b2efcb4cf4ef83f5665ce74a66257ed6 Mon Sep 17 00:00:00 2001 From: Regrad Date: Mon, 15 Jun 2026 22:25:16 +0300 Subject: [PATCH 4/7] server: checkpoint the shared prompt prefix --- tools/server/server-context.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index cc88dc41add3..8eee1bd9bfca 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -102,6 +102,7 @@ struct server_slot { int32_t n_prompt_tokens_cache = 0; int32_t n_prompt_tokens_processed = 0; 
+ int32_t n_prompt_tokens_prefix = -1; size_t last_nl_pos = 0; @@ -206,6 +207,7 @@ struct server_slot { SLT_DBG(*this, "%s", "\n"); n_prompt_tokens_cache = 0; + n_prompt_tokens_prefix = -1; last_nl_pos = 0; generated_text = ""; @@ -2910,6 +2912,7 @@ struct server_context_impl { llama_pos pos_next = slot.prompt.tokens.pos_next(n_past); const llama_pos prefix_end = pos_next; + slot.n_prompt_tokens_prefix = n_past > 0 ? n_past : -1; const bool is_recurrent_or_hybrid = llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt); @@ -3182,6 +3185,15 @@ struct server_context_impl { break; } + // Anchor a checkpoint at the common-prefix boundary. This avoids + // re-processing the gap to the nearest prompt-end checkpoint on + // subsequent requests that reuse the same prefix. + if (do_checkpoint && + slot.n_prompt_tokens_prefix > 0 && + slot.prompt.n_tokens() == slot.n_prompt_tokens_prefix) { + break; + } + // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. // create checkpoints that many tokens before the end of the prompt: // - 4 + n_ubatch From 67798d148004a23ab5fcb9f2c47f26428155228f Mon Sep 17 00:00:00 2001 From: Regrad Date: Mon, 15 Jun 2026 22:31:50 +0300 Subject: [PATCH 5/7] server: limit prefix checkpoints to recurrent models --- tools/server/server-context.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 8eee1bd9bfca..8a5671c3d0e9 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2912,10 +2912,11 @@ struct server_context_impl { llama_pos pos_next = slot.prompt.tokens.pos_next(n_past); const llama_pos prefix_end = pos_next; - slot.n_prompt_tokens_prefix = n_past > 0 ? n_past : -1; const bool is_recurrent_or_hybrid = llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt); + slot.n_prompt_tokens_prefix = + is_recurrent_or_hybrid && n_past > 0 ? 
n_past : -1; // ref: https://github.com/ggml-org/llama.cpp/pull/24110 const bool has_new_tokens = (n_past < slot.task->n_tokens()); From 001ee99438043c9e07ad7632c8532bd35dabe409 Mon Sep 17 00:00:00 2001 From: Regrad Date: Tue, 16 Jun 2026 00:34:34 +0300 Subject: [PATCH 6/7] server: checkpoint before media chunks --- tools/server/server-context.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 8a5671c3d0e9..a69837eea1e9 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3127,6 +3127,20 @@ struct server_context_impl { break; } + // Preserve the reusable prefix before entering a media chunk. Checkpoints + // created after MTMD processing are not safe to restore. + if (do_checkpoint && !has_mtmd) { + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); + const bool is_spaced = + slot.prompt.checkpoints.empty() || + cur_token_idx > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step; + + if (pos_min >= 0 && is_spaced) { + create_checkpoint(slot, 0, pos_min, pos_max); + } + } + // process the image size_t n_tokens_out = 0; int32_t res = slot.process_mtmd_chunk(cur_token_idx, n_tokens_out); From 9b11fe7e8b311f980795b3765912d60ebcfddc95 Mon Sep 17 00:00:00 2001 From: Regrad Date: Tue, 16 Jun 2026 10:58:49 +0300 Subject: [PATCH 7/7] server: anchor reusable prefix checkpoints --- .../speculative-simple/speculative-simple.cpp | 3 +- tools/server/server-context.cpp | 66 ++++++++++++------- 2 files changed, 44 insertions(+), 25 deletions(-) diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index d87ba48beb14..dab33606385e 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -172,7 
+172,8 @@ int main(int argc, char ** argv) { ckpt.update_pos( prompt_tgt.size(), llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), seq_id), - llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id)); + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id), + prompt_tgt.size()); if (use_ckpt_dft) { ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index a69837eea1e9..12896f8c4276 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -21,6 +21,7 @@ #include #include #include +#include <limits> #include #include @@ -2122,30 +2123,35 @@ struct server_context_impl { // n_tokens_cur: the number of tokens added to the batch for the current slot void create_checkpoint(server_slot & slot, const int64_t n_tokens_cur, llama_pos pos_min, llama_pos pos_max) { while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // Preserve the oldest anchor and the most recent checkpoint. Removing the checkpoint + // Preserve early prefix anchors and the most recent checkpoints. Removing a checkpoint // from the densest interior interval keeps coverage across the full prompt history. 
auto erase_it = slot.prompt.checkpoints.begin(); - if (slot.prompt.checkpoints.size() > 2) { - auto prev_it = slot.prompt.checkpoints.begin(); + if (slot.prompt.checkpoints.size() > 4) { + const size_t n_keep_front = 2; + const size_t n_keep_back = 2; + + auto prev_it = std::next(slot.prompt.checkpoints.begin(), n_keep_front - 1); auto cur_it = std::next(prev_it); auto next_it = std::next(cur_it); + auto last_candidate = std::prev(slot.prompt.checkpoints.end(), n_keep_back); erase_it = cur_it; - int64_t min_merged_span = next_it->n_tokens - prev_it->n_tokens; - - while (std::next(next_it) != slot.prompt.checkpoints.end()) { - ++prev_it; - ++cur_it; - ++next_it; - + int64_t min_merged_span = std::numeric_limits<int64_t>::max(); + while (cur_it != last_candidate) { const int64_t merged_span = next_it->n_tokens - prev_it->n_tokens; if (merged_span < min_merged_span) { erase_it = cur_it; min_merged_span = merged_span; } + + ++prev_it; + ++cur_it; + ++next_it; } + } else if (slot.prompt.checkpoints.size() > 2) { + erase_it = std::next(slot.prompt.checkpoints.begin()); } const auto & cur = *erase_it; @@ -2915,8 +2921,13 @@ struct server_context_impl { const bool is_recurrent_or_hybrid = llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt); + const bool needs_context_checkpoints = + params_base.n_ctx_checkpoints > 0 && + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS || + n_swa > 0); slot.n_prompt_tokens_prefix = - is_recurrent_or_hybrid && n_past > 0 ? n_past : -1; + needs_context_checkpoints && n_past > 0 ? n_past : -1; // ref: https://github.com/ggml-org/llama.cpp/pull/24110 const bool has_new_tokens = (n_past < slot.task->n_tokens()); @@ -3115,6 +3126,25 @@ struct server_context_impl { ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS || n_swa > 0); + // Anchor a checkpoint at the already processed common-prefix boundary. 
This + // is done before adding more tokens because checkpoints capture the current + // decoded memory state, not the batch that is about to be decoded. + if (do_checkpoint && + slot.n_prompt_tokens_prefix > 0 && + slot.prompt.n_tokens() == slot.n_prompt_tokens_prefix) { + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); + const bool is_spaced = + slot.prompt.checkpoints.empty() || + slot.prompt.n_tokens() > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step; + + if (pos_min >= 0 && is_spaced) { + create_checkpoint(slot, 0, pos_min, pos_max); + } + + slot.n_prompt_tokens_prefix = -1; + } + bool has_mtmd = false; // check if we should process the image @@ -3200,15 +3230,6 @@ struct server_context_impl { break; } - // Anchor a checkpoint at the common-prefix boundary. This avoids - // re-processing the gap to the nearest prompt-end checkpoint on - // subsequent requests that reuse the same prefix. - if (do_checkpoint && - slot.n_prompt_tokens_prefix > 0 && - slot.prompt.n_tokens() == slot.n_prompt_tokens_prefix) { - break; - } - // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. 
// create checkpoints that many tokens before the end of the prompt: // - 4 + n_ubatch @@ -3235,6 +3256,7 @@ struct server_context_impl { const auto n_tokens_cur = batch.n_tokens - n_tokens_prev; const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch; + const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; // entire prompt has been processed if (slot.prompt.n_tokens() == slot.task->n_tokens()) { @@ -3259,10 +3281,6 @@ struct server_context_impl { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); - // checkpoints are created before the current batch is decoded, so - // their token position is the batch start rather than the prompt end - const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; - { const bool is_on_user = n_before_user_known &&