diff --git a/common/common.cpp b/common/common.cpp index b01772e1cbfe..c03530d10c88 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2046,6 +2046,7 @@ void common_prompt_checkpoint::clear() { pos_min = 0; pos_max = 0; + pos_end = 0; data_tgt.clear(); data_dft.clear(); @@ -2054,10 +2055,12 @@ void common_prompt_checkpoint::clear() { void common_prompt_checkpoint::update_pos( int64_t n_tokens, llama_pos pos_min, - llama_pos pos_max) { + llama_pos pos_max, + llama_pos pos_end) { this->n_tokens = n_tokens; this->pos_min = pos_min; this->pos_max = pos_max; + this->pos_end = pos_end; } void common_prompt_checkpoint::update_tgt( diff --git a/common/common.h b/common/common.h index 0b284cbb36c7..f7a0aad12360 100644 --- a/common/common.h +++ b/common/common.h @@ -1060,6 +1060,7 @@ struct common_prompt_checkpoint { llama_pos pos_min; llama_pos pos_max; + llama_pos pos_end; std::vector data_tgt; std::vector data_dft; @@ -1072,7 +1073,8 @@ struct common_prompt_checkpoint { void update_pos( int64_t n_tokens, llama_pos pos_min, - llama_pos pos_max); + llama_pos pos_max, + llama_pos pos_end); void update_tgt( llama_context * ctx, diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index d87ba48beb14..dab33606385e 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -172,7 +172,8 @@ int main(int argc, char ** argv) { ckpt.update_pos( prompt_tgt.size(), llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), seq_id), - llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id)); + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id), + prompt_tgt.size()); if (use_ckpt_dft) { ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index bcae39a10966..12896f8c4276 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -102,6 +103,7 @@ struct server_slot { int32_t n_prompt_tokens_cache = 0; int32_t n_prompt_tokens_processed = 0; + int32_t n_prompt_tokens_prefix = -1; size_t last_nl_pos = 0; @@ -206,6 +208,7 @@ struct server_slot { SLT_DBG(*this, "%s", "\n"); n_prompt_tokens_cache = 0; + n_prompt_tokens_prefix = -1; last_nl_pos = 0; generated_text = ""; @@ -2120,13 +2123,43 @@ struct server_context_impl { // n_tokens_cur: the number of tokens added to the batch for the current slot void create_checkpoint(server_slot & slot, const int64_t n_tokens_cur, llama_pos pos_min, llama_pos pos_max) { while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for the new checkpoint, if needed - const auto & cur = slot.prompt.checkpoints.front(); + // Preserve early prefix anchors and the most recent checkpoints. Removing a checkpoint + // from the densest interior interval keeps coverage across the full prompt history. + auto erase_it = slot.prompt.checkpoints.begin(); + + if (slot.prompt.checkpoints.size() > 4) { + const size_t n_keep_front = 2; + const size_t n_keep_back = 2; + + auto prev_it = std::next(slot.prompt.checkpoints.begin(), n_keep_front - 1); + auto cur_it = std::next(prev_it); + auto next_it = std::next(cur_it); + auto last_candidate = std::prev(slot.prompt.checkpoints.end(), n_keep_back); + + erase_it = cur_it; + int64_t min_merged_span = std::numeric_limits::max(); + while (cur_it != last_candidate) { + const int64_t merged_span = next_it->n_tokens - prev_it->n_tokens; + + if (merged_span < min_merged_span) { + erase_it = cur_it; + min_merged_span = merged_span; + } + + ++prev_it; + ++cur_it; + ++next_it; + } + } else if (slot.prompt.checkpoints.size() > 2) { + erase_it = std::next(slot.prompt.checkpoints.begin()); + } + + const auto & cur = *erase_it; - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024); + SLT_WRN(slot, "erasing redundant context checkpoint (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, cur.pos_end, cur.n_tokens, (float) cur.size() / 1024 / 1024); - slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); + slot.prompt.checkpoints.erase(erase_it); } auto & cur = slot.prompt.checkpoints.emplace_back(); @@ -2134,15 +2167,16 @@ struct server_context_impl { // [TAG_CHECKPOINTS_FIX_POS_MIN] // TODO: here we incorrectly deterimne that the saved checkpoint data covers the [pos_min, pos_max] range // this is not true for SWA models: https://github.com/ggml-org/llama.cpp/pull/24411#issuecomment-4677983225 - cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max); + const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; + cur.update_pos(n_tokens_start, pos_min, pos_max, slot.prompt.tokens.pos_next(n_tokens_start)); cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); SLT_INF(slot, - "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, - cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024); + cur.pos_max, cur.pos_end, cur.n_tokens, (float) cur.size() / 1024 / 1024); } void process_single_task(server_task && task) { @@ -2610,7 +2644,8 @@ struct server_context_impl { slot.spec_ckpt.update_pos( slot.prompt.n_tokens(), llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id), - llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id)); + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id), + slot.prompt.tokens.pos_next()); if (use_ckpt_dft) { slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); @@ -2882,6 +2917,17 @@ struct server_context_impl { } llama_pos pos_next = slot.prompt.tokens.pos_next(n_past); + const llama_pos prefix_end = pos_next; + const bool is_recurrent_or_hybrid = + llama_model_is_recurrent(model_tgt) || + llama_model_is_hybrid(model_tgt); + const bool needs_context_checkpoints = + params_base.n_ctx_checkpoints > 0 && + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS || + n_swa > 0); + slot.n_prompt_tokens_prefix = + needs_context_checkpoints && n_past > 0 ? n_past : -1; // ref: https://github.com/ggml-org/llama.cpp/pull/24110 const bool has_new_tokens = (n_past < slot.task->n_tokens()); @@ -2946,9 +2992,19 @@ struct server_context_impl { slot.prompt.checkpoints.rend(), [&, func_name = __func__](const auto & cur) { // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] - LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12, - func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold); + LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d], pos_end = %d against prefix_end = %d, threshold = %d...\n", 12, + func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, cur.pos_end, prefix_end, pos_min_thold); + + // A checkpoint is reusable only if the entire saved state belongs to the + // common prompt prefix. Memory ranges alone are not sufficient for recurrent + // and hybrid models because they do not necessarily cover the full prefix. + if (cur.pos_end > prefix_end) { + return false; + } // workaround for [TAG_CHECKPOINTS_FIX_POS_MIN] + if (is_recurrent_or_hybrid) { + return cur.pos_max < pos_next || cur.pos_min == 0; + } if (cur.pos_max > pos_next) { return false; } @@ -2965,7 +3021,7 @@ struct server_context_impl { pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->pos_end, it->n_tokens, n_past, (float) it->size() / 1024 / 1024); } if (do_reset) { @@ -2978,11 +3034,16 @@ struct server_context_impl { } { - // erase any checkpoints with pos_max > pos_next + // erase checkpoints that extend beyond the matching prefix. Dense models also + // invalidate checkpoints beyond the restored memory range, but recurrent/hybrid + // checkpoints can legitimately have a larger pos_max while still belonging to + // the same prompt prefix. for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { const auto & cur = *it; - if (cur.pos_max > pos_next) { - SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024); + const bool invalid_prefix = cur.pos_end > prefix_end; + const bool invalid_memory = !is_recurrent_or_hybrid && cur.pos_max > pos_next; + if (invalid_prefix || invalid_memory) { + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", n_swa = %d, prefix_end = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.pos_end, cur.n_tokens, n_swa, prefix_end, pos_next, (float) cur.size() / 1024 / 1024); it = slot.prompt.checkpoints.erase(it); } else { ++it; @@ -3065,6 +3126,25 @@ struct server_context_impl { ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS || n_swa > 0); + // Anchor a checkpoint at the already processed common-prefix boundary. This + // is done before adding more tokens because checkpoints capture the current + // decoded memory state, not the batch that is about to be decoded. + if (do_checkpoint && + slot.n_prompt_tokens_prefix > 0 && + slot.prompt.n_tokens() == slot.n_prompt_tokens_prefix) { + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); + const bool is_spaced = + slot.prompt.checkpoints.empty() || + slot.prompt.n_tokens() > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step; + + if (pos_min >= 0 && is_spaced) { + create_checkpoint(slot, 0, pos_min, pos_max); + } + + slot.n_prompt_tokens_prefix = -1; + } + bool has_mtmd = false; // check if we should process the image @@ -3077,6 +3157,20 @@ struct server_context_impl { break; } + // Preserve the reusable prefix before entering a media chunk. Checkpoints + // created after MTMD processing are not safe to restore. + if (do_checkpoint && !has_mtmd) { + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); + const bool is_spaced = + slot.prompt.checkpoints.empty() || + cur_token_idx > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step; + + if (pos_min >= 0 && is_spaced) { + create_checkpoint(slot, 0, pos_min, pos_max); + } + } + // process the image size_t n_tokens_out = 0; int32_t res = slot.process_mtmd_chunk(cur_token_idx, n_tokens_out); @@ -3162,6 +3256,7 @@ struct server_context_impl { const auto n_tokens_cur = batch.n_tokens - n_tokens_prev; const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch; + const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; // entire prompt has been processed if (slot.prompt.n_tokens() == slot.task->n_tokens()) { @@ -3186,10 +3281,6 @@ struct server_context_impl { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); - // checkpoints are created before the current batch is decoded, so - // their token position is the batch start rather than the prompt end - const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; - { const bool is_on_user = n_before_user_known &&