Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2046,6 +2046,7 @@ void common_prompt_checkpoint::clear() {

pos_min = 0;
pos_max = 0;
pos_end = 0;

data_tgt.clear();
data_dft.clear();
Expand All @@ -2054,10 +2055,12 @@ void common_prompt_checkpoint::clear() {
void common_prompt_checkpoint::update_pos(
int64_t n_tokens,
llama_pos pos_min,
llama_pos pos_max) {
llama_pos pos_max,
llama_pos pos_end) {
this->n_tokens = n_tokens;
this->pos_min = pos_min;
this->pos_max = pos_max;
this->pos_end = pos_end;
}

void common_prompt_checkpoint::update_tgt(
Expand Down
4 changes: 3 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,7 @@ struct common_prompt_checkpoint {

llama_pos pos_min;
llama_pos pos_max;
llama_pos pos_end;

std::vector<uint8_t> data_tgt;
std::vector<uint8_t> data_dft;
Expand All @@ -1072,7 +1073,8 @@ struct common_prompt_checkpoint {
void update_pos(
int64_t n_tokens,
llama_pos pos_min,
llama_pos pos_max);
llama_pos pos_max,
llama_pos pos_end);

void update_tgt(
llama_context * ctx,
Expand Down
3 changes: 2 additions & 1 deletion examples/speculative-simple/speculative-simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ int main(int argc, char ** argv) {
ckpt.update_pos(
prompt_tgt.size(),
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), seq_id),
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id));
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id),
prompt_tgt.size());

if (use_ckpt_dft) {
ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
Expand Down
129 changes: 110 additions & 19 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <exception>
#include <memory>
#include <filesystem>
#include <limits>
#include <utility>
#include <fstream>

Expand Down Expand Up @@ -102,6 +103,7 @@ struct server_slot {

int32_t n_prompt_tokens_cache = 0;
int32_t n_prompt_tokens_processed = 0;
int32_t n_prompt_tokens_prefix = -1;

size_t last_nl_pos = 0;

Expand Down Expand Up @@ -206,6 +208,7 @@ struct server_slot {
SLT_DBG(*this, "%s", "\n");

n_prompt_tokens_cache = 0;
n_prompt_tokens_prefix = -1;

last_nl_pos = 0;
generated_text = "";
Expand Down Expand Up @@ -2120,29 +2123,60 @@ struct server_context_impl {
// n_tokens_cur: the number of tokens added to the batch for the current slot
void create_checkpoint(server_slot & slot, const int64_t n_tokens_cur, llama_pos pos_min, llama_pos pos_max) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();
// Preserve early prefix anchors and the most recent checkpoints. Removing a checkpoint
// from the densest interior interval keeps coverage across the full prompt history.
auto erase_it = slot.prompt.checkpoints.begin();

if (slot.prompt.checkpoints.size() > 4) {
const size_t n_keep_front = 2;
const size_t n_keep_back = 2;

auto prev_it = std::next(slot.prompt.checkpoints.begin(), n_keep_front - 1);
auto cur_it = std::next(prev_it);
auto next_it = std::next(cur_it);
auto last_candidate = std::prev(slot.prompt.checkpoints.end(), n_keep_back);

erase_it = cur_it;
int64_t min_merged_span = std::numeric_limits<int64_t>::max();
while (cur_it != last_candidate) {
const int64_t merged_span = next_it->n_tokens - prev_it->n_tokens;

if (merged_span < min_merged_span) {
erase_it = cur_it;
min_merged_span = merged_span;
}

++prev_it;
++cur_it;
++next_it;
}
} else if (slot.prompt.checkpoints.size() > 2) {
erase_it = std::next(slot.prompt.checkpoints.begin());
}

const auto & cur = *erase_it;

SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
SLT_WRN(slot, "erasing redundant context checkpoint (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.pos_end, cur.n_tokens, (float) cur.size() / 1024 / 1024);

slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
slot.prompt.checkpoints.erase(erase_it);
}

auto & cur = slot.prompt.checkpoints.emplace_back();

// [TAG_CHECKPOINTS_FIX_POS_MIN]
// TODO: here we incorrectly deterimne that the saved checkpoint data covers the [pos_min, pos_max] range
// this is not true for SWA models: https://github.com/ggml-org/llama.cpp/pull/24411#issuecomment-4677983225
cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
cur.update_pos(n_tokens_start, pos_min, pos_max, slot.prompt.tokens.pos_next(n_tokens_start));

cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

SLT_INF(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
cur.pos_max, cur.pos_end, cur.n_tokens, (float) cur.size() / 1024 / 1024);
}

void process_single_task(server_task && task) {
Expand Down Expand Up @@ -2610,7 +2644,8 @@ struct server_context_impl {
slot.spec_ckpt.update_pos(
slot.prompt.n_tokens(),
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id),
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id));
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id),
slot.prompt.tokens.pos_next());

if (use_ckpt_dft) {
slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
Expand Down Expand Up @@ -2882,6 +2917,17 @@ struct server_context_impl {
}

llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
const llama_pos prefix_end = pos_next;
const bool is_recurrent_or_hybrid =
llama_model_is_recurrent(model_tgt) ||
llama_model_is_hybrid(model_tgt);
const bool needs_context_checkpoints =
params_base.n_ctx_checkpoints > 0 &&
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS ||
n_swa > 0);
slot.n_prompt_tokens_prefix =
needs_context_checkpoints && n_past > 0 ? n_past : -1;

// ref: https://github.com/ggml-org/llama.cpp/pull/24110
const bool has_new_tokens = (n_past < slot.task->n_tokens());
Expand Down Expand Up @@ -2946,9 +2992,19 @@ struct server_context_impl {
slot.prompt.checkpoints.rend(),
[&, func_name = __func__](const auto & cur) {
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d], pos_end = %d against prefix_end = %d, threshold = %d...\n", 12,
func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, cur.pos_end, prefix_end, pos_min_thold);

// A checkpoint is reusable only if the entire saved state belongs to the
// common prompt prefix. Memory ranges alone are not sufficient for recurrent
// and hybrid models because they do not necessarily cover the full prefix.
if (cur.pos_end > prefix_end) {
return false;
}
// workaround for [TAG_CHECKPOINTS_FIX_POS_MIN]
if (is_recurrent_or_hybrid) {
return cur.pos_max < pos_next || cur.pos_min == 0;
}
if (cur.pos_max > pos_next) {
return false;
}
Expand All @@ -2965,7 +3021,7 @@ struct server_context_impl {

pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->pos_end, it->n_tokens, n_past, (float) it->size() / 1024 / 1024);
}

if (do_reset) {
Expand All @@ -2978,11 +3034,16 @@ struct server_context_impl {
}

{
// erase any checkpoints with pos_max > pos_next
// erase checkpoints that extend beyond the matching prefix. Dense models also
// invalidate checkpoints beyond the restored memory range, but recurrent/hybrid
// checkpoints can legitimately have a larger pos_max while still belonging to
// the same prompt prefix.
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_max > pos_next) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
const bool invalid_prefix = cur.pos_end > prefix_end;
const bool invalid_memory = !is_recurrent_or_hybrid && cur.pos_max > pos_next;
if (invalid_prefix || invalid_memory) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, pos_end = %d, n_tokens = %" PRId64 ", n_swa = %d, prefix_end = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.pos_end, cur.n_tokens, n_swa, prefix_end, pos_next, (float) cur.size() / 1024 / 1024);
it = slot.prompt.checkpoints.erase(it);
} else {
++it;
Expand Down Expand Up @@ -3065,6 +3126,25 @@ struct server_context_impl {
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS ||
n_swa > 0);

// Anchor a checkpoint at the already processed common-prefix boundary. This
// is done before adding more tokens because checkpoints capture the current
// decoded memory state, not the batch that is about to be decoded.
if (do_checkpoint &&
slot.n_prompt_tokens_prefix > 0 &&
slot.prompt.n_tokens() == slot.n_prompt_tokens_prefix) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
const bool is_spaced =
slot.prompt.checkpoints.empty() ||
slot.prompt.n_tokens() > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step;

if (pos_min >= 0 && is_spaced) {
create_checkpoint(slot, 0, pos_min, pos_max);
}

slot.n_prompt_tokens_prefix = -1;
}

bool has_mtmd = false;

// check if we should process the image
Expand All @@ -3077,6 +3157,20 @@ struct server_context_impl {
break;
}

// Preserve the reusable prefix before entering a media chunk. Checkpoints
// created after MTMD processing are not safe to restore.
if (do_checkpoint && !has_mtmd) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
const bool is_spaced =
slot.prompt.checkpoints.empty() ||
cur_token_idx > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step;

if (pos_min >= 0 && is_spaced) {
create_checkpoint(slot, 0, pos_min, pos_max);
}
}

// process the image
size_t n_tokens_out = 0;
int32_t res = slot.process_mtmd_chunk(cur_token_idx, n_tokens_out);
Expand Down Expand Up @@ -3162,6 +3256,7 @@ struct server_context_impl {
const auto n_tokens_cur = batch.n_tokens - n_tokens_prev;

const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;

// entire prompt has been processed
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
Expand All @@ -3186,10 +3281,6 @@ struct server_context_impl {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);

// checkpoints are created before the current batch is decoded, so
// their token position is the batch start rather than the prompt end
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;

{
const bool is_on_user =
n_before_user_known &&
Expand Down