Skip to content

llama: add llama_batch_ext#24669

Draft
ngxson wants to merge 2 commits into
ggml-org:masterfrom
ngxson:xsn/llama_batch_ext
Draft

llama: add llama_batch_ext#24669
ngxson wants to merge 2 commits into
ggml-org:masterfrom
ngxson:xsn/llama_batch_ext

Conversation

@ngxson

@ngxson ngxson commented Jun 15, 2026

Copy link
Copy Markdown
Collaborator

Overview

Supersede #11875

Status: early WIP, for discussion only

Additional information

Demo usage:

auto * batch = llama_batch_ext_init(ctx);
int32_t last_idx = 0;
for (auto token_id : input_tokens) {
  llama_batch_token t{
    token_id,
    nullptr, // embd
    nullptr, // embd_nextn
    nullptr, // pos
    0 // seq_id
  };
  last_idx = llama_batch_ext_add_token(batch, t);
}

llama_batch_ext_set_output(batch, last_idx, true);
llama_process(ctx, LLAMA_PROCESS_TYPE_DECODE, batch); // process the prompt

while (true) {
  float * logits = llama_batch_ext_get_logits(batch, last_idx);

  // Sample the next token from the logits
  // optionally check for stop condition
  llama_token next_token_id = sample_next_token(logits);

  // Process the sampled token
  llama_batch_token t{
    next_token_id,
    nullptr, // embd
    nullptr, // embd_nextn
    nullptr, // pos
    0 // seq_id
  };
  int32_t idx = llama_batch_ext_add_token(batch, t);
  llama_batch_ext_set_output(batch, idx, true);
  int32_t result = llama_process(ctx, LLAMA_PROCESS_TYPE_DECODE, batch); // process the next token
  if (result != 0) {
    break; // stop if there is an error or end of sequence
  }
}

Requirements

@ngxson ngxson requested a review from ggerganov as a code owner June 15, 2026 21:15
@ngxson ngxson marked this pull request as draft June 15, 2026 21:15

@ggerganov ggerganov left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To keep it simpler, we can do the output-related logic in a next PR. I.e. in the first PR, we just introduce the llama_batch_ext and use it to pass the inputs, but we leave the llama_context to handle the output buffers and embedding extractions as it is. Then in the next PR, we will move all the output logic to the batch.

Comment thread include/llama.h
Comment on lines +996 to +1001
// Set output = true for the last added token in the batch
// Returns the batch index (>= 0)
LLAMA_API bool llama_batch_ext_set_output(
struct llama_batch_ext * batch,
int32_t idx,
bool output_last);

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment is incorrect here - it's not the last added token, but the idx token.

Comment thread include/llama.h
float * embd_nextn; // used by nextn layers
llama_pos * pos; // if nullptr, the position will be automatically assigned
// for M-RoPE models, embedding tokens must have multiple positions per token; text token only requires one single position per token
llama_seq_id seq_id;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still want to support multiple sequence ids per token.

@ngxson ngxson Jun 16, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm tbh I don't quite like passing pointer-to-pointer llama_seq_id **, as it makes the caller do more works. so I'm wondering if we should redesign it to get rid of struct llama_batch_token. my idea now is to have 2 categories of calls:

_add call that returns batch index:

  • llama_batch_ext_add_token --> add by token ID
  • llama_batch_ext_add_embd --> add by embeddings

then an array of _set that adds more info to the returned batch index:

  • llama_batch_ext_set_embd_nextn(int32_t idx, float * embd)
  • llama_batch_ext_set_seq_id(int32_t idx, llama_seq_id * seq_id, size_t n_seq) --> can set to multiple sequences
  • llama_batch_ext_set_pos
  • llama_batch_ext_set_output

also, do you think _set_output should be a boolean, or it should be a bit field, for example LLAMA_OUTPUT_NEXTN | LLAMA_OUTPUT_EMBD ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so I'm wondering if we should redesign it to get rid of struct llama_batch_token

Yes, API based on the entry index seems quite generic and clean.

llama_batch_ext_add_token --> add by token ID
llama_batch_ext_add_embd --> add by embeddings

I think you can simplify by having single llama_batch_ext_add instead of differentiating token/embd.

also, do you think _set_output should be a boolean, or it should be a bit field, for example LLAMA_OUTPUT_NEXTN | LLAMA_OUTPUT_EMBD ?

I think we probably need per-entry llama_batch_ext_set_output(batch, idx, value);. And then the contents of the outputs likely not have to be per-entry, but for the entire batch:

llama_batch_ext_output_embd      (batch, value); 
llama_batch_ext_output_embd_nextn(batch, value, masked);
llama_batch_ext_output_layer_inp (batch, value); 
...

But for now these can remain controlled by the llama_context for now because the llama_context currently will own the output buffers, not the batch.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can simplify by having single llama_batch_ext_add instead of differentiating token/embd.

actually there will be 3 choices for that:

  1. llama_batch_ext_add() that simply returns an idx, and need a separated _set_token(id) or _set_embd(float * embd)
  2. llama_batch_ext_add(id), if id == LLAMA_TOKEN_NULL then _set_embd(float *) is required
  3. llama_batch_ext_add(id, float * embd) where either one of two can be set

however, since each input entry in the batch requires at least token ID or token embd to be consider "valid", I think having explicit _add_token and _add_embd will be a better design overall. WDYT ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can have all three:

int32_t llama_batch_ext_add      (llama_batch_ext * batch);
int32_t llama_batch_ext_add_token(llama_batch_ext * batch, llama_token id);
int32_t llama_batch_ext_add_embd (llama_batch_ext * batch, llama_embd embd);

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants