(superseded) llama : add llama_batch_ext #11875

Closed
ngxson wants to merge 61 commits into ggml-org:master from ngxson:xsn/private_batch_api

Conversation

@ngxson

@ngxson ngxson commented Feb 14, 2025

Collaborator

Ref comment: #11292 (comment)

Closes #10381

Migration patterns:

llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
// becomes:
llama_batch_ext_ptr batch = llama_batch_ext_ptr(llama_batch_ext_init(n_kv_max, 1));


common_batch_add(batch, tokens[i], pos, { 0 }, false);
// becomes:
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), tokens[i], pos, &seq_id, 1, false);


llama_decode(lctx, llama_batch_get_one(tokens.data(), std::min(tokens.size(), (size_t) params.n_batch)));
// becomes:
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
llama_decode_ext(lctx, batch.get());


llama_decode(ctx, batch);
// becomes:
llama_decode_ext(ctx, batch.get());
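
The `llama_batch_ext_ptr` in the patterns above is used like an owning smart pointer (`.get()` hands out the raw pointer). A minimal sketch of that idea, assuming it behaves like a `std::unique_ptr` with a custom deleter. `fake_batch`, `fake_batch_init`, `fake_batch_free` and `fake_batch_ptr` are hypothetical stand-ins for illustration, not the library's API:

```cpp
#include <cassert>
#include <cstdint>
#include <memory>

// Hypothetical stand-in for the opaque batch type.
struct fake_batch {
    int32_t n_tokens_alloc;
    int32_t n_seq_max;
};

// Counts live batches, only so the ownership behaviour can be observed.
static int g_live_batches = 0;

fake_batch * fake_batch_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
    ++g_live_batches;
    return new fake_batch{n_tokens_alloc, n_seq_max};
}

void fake_batch_free(fake_batch * batch) {
    assert(batch != nullptr);
    --g_live_batches;
    delete batch;
}

// Deleter that calls the C-style free function when the owner goes out of scope.
struct fake_batch_deleter {
    void operator()(fake_batch * batch) const { fake_batch_free(batch); }
};

using fake_batch_ptr = std::unique_ptr<fake_batch, fake_batch_deleter>;
```

With such a wrapper, user code never calls the free function directly; the batch is released when the pointer leaves scope.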

Current status:

  • This PR currently contains a first proposal for a public API that hides llama_batch from the public API (to be discussed)
  • Only llama-server works for now
  • TODO: the members of llama_batch can be migrated to cpp types

@ngxson

ngxson commented Feb 14, 2025

Collaborator Author

@ggerganov Would you mind taking a look at this initial proposal? Thank you!

Comment thread src/llama-batch.h Outdated
Comment thread include/llama.h Outdated
Comment on lines +266 to +272
struct llama_batch_ext_token_info {
llama_token token;
llama_pos pos;
int32_t n_seq_id;
llama_seq_id * seq_id;
int8_t logits;
};

Member

This might not be very future-proof. Mixed-modality batches would have tokens, embeddings and tensors mixed together in the same batch. So calling llama_batch_ext_get_token_info(batch, i); is not always well-defined because it might not be a token at position i.

Maybe we can postpone this "token_info" API. I think all usages in the examples that require to read back info from the batch can be implemented in the example code without relying on the API. This way we can focus only on implementing the API for creating batches and adding data to them. Later on, when we have a better idea of the implementation, we can add a helper API to get info back from the batches.

Collaborator Author

Yes I agree. Furthermore, this API requires doing a copy, so it won't be the best for performance. It's better to remove this API for now.

I think all usages in the examples that require to read back info from the batch can be implemented in the example code without relying on the API.

This kind of logic is currently used inside llama-server; I'm not sure whether it appears in any other examples. I think I can make a thin wrapper for llama_batch_ext inside the example code. Feel free to tell me if you have a better idea.

Collaborator Author

This API was removed in 1d6ba97; a new server_batch wrapper was added to manage the placement of token logits in the batch.
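
A user-side wrapper of this kind could look roughly like the sketch below. It only illustrates the idea of recording what was added so the information can be read back without a library API; `token_record` and `tracked_batch` are hypothetical names and this is not the actual server_batch code. A real wrapper would also forward each `add()` to the library batch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical user-side record of one token added to a batch.
struct token_record {
    int32_t token;
    int32_t pos;
    int32_t seq_id;
    bool    output; // true if logits/embeddings were requested for this token
};

// Hypothetical thin wrapper that remembers every token it was given.
struct tracked_batch {
    std::vector<token_record> records;

    // Records the token and returns its index in the batch.
    int32_t add(int32_t token, int32_t pos, int32_t seq_id, bool output) {
        assert(pos >= 0);
        records.push_back(token_record{token, pos, seq_id, output});
        return static_cast<int32_t>(records.size()) - 1;
    }

    // Batch indices of all tokens that were marked for output.
    std::vector<int32_t> output_indices() const {
        std::vector<int32_t> out;
        for (std::size_t i = 0; i < records.size(); ++i) {
            if (records[i].output) {
                out.push_back(static_cast<int32_t>(i));
            }
        }
        return out;
    }
};
```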

Comment thread common/common.cpp Outdated
Comment thread include/llama.h Outdated
@ngxson

ngxson commented Mar 1, 2025

Collaborator Author

OK, so I've been able to apply this to various examples (not all of them). It would be nice if you could have a quick look, @ggerganov, before I migrate the rest.

One thing to note, the loop check over tokens in batch (discussed in #11875 (comment)) is used by both server.cpp and embeddings.cpp, so my solution was to create a thin wrapper called common_batch. Looks a bit messy for now, so I'm wondering if in the future we can have a llama_get_embeddings_ext or something that can make this easier.

@ggerganov

Member

The common_batch is ok for now.

Looks a bit messy for now, so I'm wondering if in the future we can have a llama_get_embeddings_ext or something that can make this easier.

It seems we rather need something to query the batch, no? How do you imagine llama_get_embeddings_ext to work?

I was thinking something like:

struct llama_batch_ext_part;

llama_batch_ext_part * part = llama_batch_ext_get_part(batch, i);
if (llama_batch_ext_part_is_token(part)) {
    llama_token id = llama_batch_ext_part_get_id(part);
    ... get token id, sequence id, etc. ...
}

But since I'm not 100% sure about all the details related to multi-modal batches yet, I think it is better to postpone this API and handle the batch information in the user code for now.

@ngxson

ngxson commented Mar 3, 2025

Collaborator Author

How do you imagine llama_get_embeddings_ext to work?

I don't have a clear idea yet, but I'm thinking as a developer using libllama in their program: whenever I add a token to the batch, in the case of a text token, I need to know:

  • The token (token ID in case of text)
  • The pos
  • The seq_id

So when I retrieve back the logits/embeddings, I would imagine that the get_embeddings function will have one of these 2 signatures:

  • get_embeddings(seq_id) ==> we already had llama_get_embeddings_seq
  • get_embeddings(seq_id, pos) ==> we currently need to read back the tokens from batch

It seems we rather need something to query the batch, no?

Yes we can, and this will be quite similar to my point above. I'm thinking about these 2 options:

  • Having something like llama_batch_ext_query(seq_id, pos) that returns the output_id of the token. This can then be used with llama_get_embeddings_ith(output_id)
  • Or, explicitly has llama_batch_ext_set_output(...) that returns the output_id. That means the logits param will be removed from llama_batch_ext_add_text
  • (Edit) Or, another option, llama_batch_ext_add_text can return the output_id if logits is set to true
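
The first and third options can be sketched together. This is an illustration only: it assumes output ids are handed out sequentially in the order tokens are marked for output, which the discussion does not specify, and `output_index`, `add_text` and `query` here are hypothetical helpers rather than library functions.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <utility>

// Hypothetical bookkeeping for the "output_id" idea.
struct output_index {
    std::map<std::pair<int32_t, int32_t>, int32_t> ids; // (seq_id, pos) -> output id
    int32_t n_outputs = 0;

    // Third option: adding a token returns its output id, or -1 if no output was requested.
    int32_t add_text(int32_t seq_id, int32_t pos, bool output) {
        assert(pos >= 0);
        if (!output) {
            return -1;
        }
        ids[std::make_pair(seq_id, pos)] = n_outputs;
        return n_outputs++;
    }

    // First option: look up the output id of the token at (seq_id, pos), or -1 if there is none.
    int32_t query(int32_t seq_id, int32_t pos) const {
        auto it = ids.find(std::make_pair(seq_id, pos));
        return it == ids.end() ? -1 : it->second;
    }
};
```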

@ggerganov

Member

Yes we can, and this will be quite similar to my point above. I'm thinking about these 2 options:

Having something like llama_batch_ext_query(seq_id, pos) that returns the output_id of the token. This can then be used with llama_get_embeddings_ith(output_id)
Or, explicitly has llama_batch_ext_set_output(...) that returns the output_id. That means the logits param will be removed from llama_batch_ext_add_text
(Edit) Or, another option, llama_batch_ext_add_text can return the output_id if logits is set to true

Hm, yes. The llama_batch_ext_set_output() idea sounds good.

Btw, this makes me wonder if we should actually move the output buffers for logits and the embeddings to be owned by the llama_batch_ext (currently these buffers are owned by the llama_context and are shared by all batches)?

@ggerganov

Member

Now that #12181 has been merged, it should be a good time to get this merged too.

@ngxson

ngxson commented Mar 13, 2025

Collaborator Author

Yes thanks for the heads up, I'll focus on finishing this today & tomorrow

@ngxson

ngxson commented Mar 13, 2025

Collaborator Author

Btw, this makes me wonder if we should actually move the output buffers for logits and the embeddings to be owned by the llama_batch_ext (currently these buffers are owned by the llama_context and are shared by all batches)?

If the output logits and embeddings are staying as float * or std::vector<float>, then yes, I think it will be better to move them to llama_batch_ext (this can be done in a follow-up PR).

Comment thread src/llama-batch.cpp Outdated
- struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
-     return llama_batch_ext_init_impl(n_tokens_alloc, 0, n_seq_max);
+ struct llama_batch_ext * llama_batch_ext_init(struct llama_context * ctx) {
+     return llama_batch_ext_init_impl(llama_n_batch(ctx), 0, llama_n_seq_max(ctx));

Member

Yes, for now this is a good solution. Later, we will be resizing dynamically and not need the llama_n_batch().

@ggerganov

Member

But before implementing this, I just want to check with you if my direction still looks ok.

I think this looks good.

@ngxson

ngxson commented Mar 25, 2025

Collaborator Author

I implemented the fix and tested qwen2vl-cli; it seems to work with text. I can't test with an image for now since that is also broken on master.

The command is: llama-qwen2vl-cli -m ../models/Qwen2-VL-7B-Instruct-Q4_K_M.gguf --mmproj ../models/mmproj-Qwen2-VL-7B-Instruct-f16.gguf --image ../models/bliss.png -p "what do you see?"


A bit confused about this, I found this nice illustration on qwen2vl model page on HF, which shows that qwen only use 3 pos per token. It's also confirmed by the config.json file. However, I'm not sure why in llama.cpp we use up to 4 pos.

[image: illustration from the Qwen2VL model page on HF]

@ggerganov

Member

A bit confused about this, I found this nice illustration on qwen2vl model page on HF, which shows that qwen only use 3 pos per token. It's also confirmed by the config.json file. However, I'm not sure why in llama.cpp we use up to 4 pos.

cc @HimariO

Comment thread examples/server/server.cpp Outdated
@@ -1963,7 +1963,7 @@ struct server_context {
const int32_t n_batch = llama_n_batch(ctx);

// only a single seq_id per token is needed

Member

This comment is obsolete.

Collaborator Author

yeah I think I removed it in one of the commits above, we don't need n_batch anymore so I removed this whole code block

@HimariO

HimariO commented Mar 27, 2025

Contributor

A bit confused about this, I found this nice illustration on qwen2vl model page on HF, which shows that qwen only use 3 pos per token. It's also confirmed by the config.json file. However, I'm not sure why in llama.cpp we use up to 4 pos.

The fourth position ID is mainly for future-proofing, in case a newer model that takes 3D/depth input (like SpatialLM) is added. Currently, both Qwen2VL and Qwen2.5VL only use 3 position IDs per token.

@ggerganov

Member

@ngxson We discussed these changes with @slaren, and he raised a good point: the batch API does not need to explicitly pass token positions. These can be inferred from the KV cache.

Since not having to pass explicit token positions would simplify the batch API, it's a good idea to take it into account when redesigning it. So I will try to do some KV cache refactoring to support this. When I'm ready, I will come back to this PR and update it accordingly.
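
As a rough illustration of what "inferring positions" could mean for text tokens: the library would keep, per sequence, the next free position and assign it whenever a token is added. The sketch below assumes exactly that and nothing more; `pos_tracker` is a hypothetical name and this is not how libllama implements it.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>

// Hypothetical per-sequence position tracker.
struct pos_tracker {
    std::unordered_map<int32_t, int32_t> next_pos; // seq_id -> next free position

    // Returns the position assigned to a new token of sequence `seq_id`.
    // A sequence that has not been seen before starts at position 0.
    int32_t assign(int32_t seq_id) {
        assert(seq_id >= 0);
        return next_pos[seq_id]++;
    }
};
```

With this bookkeeping inside the library, the caller only has to say which sequence a token belongs to.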

@ngxson

ngxson commented Apr 1, 2025

Collaborator Author

Thanks for the heads-up! Yes, for text batches it would be nice if the position could be inferred from the KV cache.

Please also note that, for multimodal batches, we may also need N-dimensional positions. For example, in the case of Qwen2VL we have a normal position plus an additional 2D position for each image token. I think what we can do is let the user provide the 2D coordinate position, while the "normal" position is still inferred from the KV cache.

So, for example, in this PR the API for Qwen will accept (2*n_tokens) positions.

@ggerganov

Member

I think the Qwen2VL image positions could be inferred from the token position:

for (int y = 0; y < ph; y++)
{
    for (int x = 0; x < pw; x++)
    {
        int i = y * pw + x;
        mrope_pos[i] = *st_pos_id;
        mrope_pos[i + img_tokens] = *st_pos_id + y;
        mrope_pos[i + img_tokens * 2] = *st_pos_id + x;
        mrope_pos[i + img_tokens * 3] = 0;
    }
}
*st_pos_id += std::max(pw, ph);

So ideally, the user would not have to pass those as well.
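
For reference, the loop above can be wrapped into a self-contained function to see what it produces. This sketch assumes `img_tokens == pw * ph` and that the four position sections are stored back to back in one array, as the indexing in the snippet suggests; the function name `mrope_positions` is made up for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Computes the 4 position sections for a pw x ph grid of image tokens.
// Layout: [base | base + y | base + x | 0], each section img_tokens long.
// `st_pos_id` is advanced by max(pw, ph), as in the snippet above.
std::vector<int32_t> mrope_positions(int pw, int ph, int32_t & st_pos_id) {
    assert(pw > 0 && ph > 0);
    const int img_tokens = pw * ph; // assumption: one token per grid cell
    std::vector<int32_t> mrope_pos(static_cast<std::size_t>(img_tokens) * 4);
    for (int y = 0; y < ph; y++) {
        for (int x = 0; x < pw; x++) {
            const int i = y * pw + x;
            mrope_pos[i]                  = st_pos_id;
            mrope_pos[i + img_tokens]     = st_pos_id + y;
            mrope_pos[i + img_tokens * 2] = st_pos_id + x;
            mrope_pos[i + img_tokens * 3] = 0;
        }
    }
    st_pos_id += std::max(pw, ph);
    return mrope_pos;
}
```

Note that the only inputs are the starting position and the grid size `pw`, `ph`, which is why the image size still has to be known somewhere.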

@ngxson

ngxson commented Apr 1, 2025

Collaborator Author

We still need to know the image size pw and ph in order to infer the 2D position. Ofc my goal for the future is to hide all of this inside the library. This is the idea in this comment, where I suggest having a struct llama_mm_embd that contains the image embeddings tensor and the image size pw, ph.

In the far future (or not that far?), struct llama_mm_embd can also be used to store everything that the text decoder wants to retrieve from the multimodal encoder: that could be a series of images in the case of video, or audio tokens in the case of audio input.

@ggerganov

Member

Yes, the multi-dim positions are complicating things. Not sure what is the best solution.

The problem is that many vision models nowadays also need to know the size of image when passing the embeddings from encoder to decoder.

Maybe with Gemma demonstrating that this is not necessary, new models won't need these complications and we don't have to add support at all. Are there other models other than Qwen2VL that need 2D positions?

When passing these embeddings to decoder, we need to "wrap" each slice between some tokens, then also "wrap" each row of 3 slices between another token:

This seems something that the user code can implement logic (similar how to bos/eos tokens are added). Which models use this pattern?

Anyway, the KV cache refactoring can be done before making a decision about how to handle the images, so we can re-discuss this after that.

@ngxson

ngxson commented Apr 1, 2025

Collaborator Author

Just to clarify before writing my response, there are 2 reasons why the image size is needed:

  • For 2D positional embedding like Qwen, where each token has 3 positions (temporal, X, Y)
  • For the slices layout, currently used by a lot of models like MiniCPM-V, SmolVLM, llava 1.6. This technique uses normal positions and a normal causal mask, but adds special tokens to identify rows/cols of slices

Maybe with Gemma demonstrating that this is not necessary, new models won't need these complications and we don't have to add support at all. Are there other models other than Qwen2VL that need 2D positions?

Hmm yeah, that could be right. Because M-RoPE was invented by Qwen, I don't see anyone else using it for now. Not sure if other models will adopt it in the future.

But please note that gemma 3 does not use slices. This makes working with gemma 3 vision easy, but the current problem with gemma 3 is that the image size is fixed. For bigger images, it needs to rely on a technique called "pan and zoom", which is essentially a prompting technique that allows the model to "ask" the runtime to zoom the image, then rerun the generation. This is obviously very inefficient.

Models like SmolVLM, MiniCPM-V, Qwen (and maybe many others) are already using the slicing technique that I mentioned earlier, so we definitely need to support this in the API.

This seems something that the user code can implement logic (similar how to bos/eos tokens are added). Which models use this pattern?

In fact, it's better to think of this as the "chat template" for images. While it can be implemented in user code, I think it's better to make it transparent from the user's POV: this part is model-specific and even harder than the normal chat template for text, so it is not something the user can easily debug.

In gemma3-cli, you can see that the <start_of_image> token is added from user code. But my goal for the vision API is to hide this behind an API.

@ngxson

ngxson commented Apr 4, 2025

Collaborator Author

@ggerganov I've been working on audio input and output recently and I think the API proposed by this PR pretty much corresponds to what I need (ofc except for the position, which would be nicer to hide from user code).

Having this PR merged could save me some effort and, more importantly, unblock my research on the multimodal API, so I'm wondering if there is anything I could do on my side to accelerate this a bit more. Thank you!


Also now that the position is hidden from user code, I think we should also somehow modify the llama_kv_* API to make sure that user don't accidentally leave a "hole" in the context. For example, if KV cache has 10 tokens and they delete tokens from position [2, 5), now there will be no API to fill in this "hole" unless we also remove all tokens [5, inf)
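
The "hole" scenario can be made concrete with a small model of one sequence's occupied positions. This is only a sketch: it assumes automatic positions continue from the largest occupied position, and `pos_set`, `remove_range`, `next_auto_pos` and `has_hole` are hypothetical helpers, not part of the llama_kv_* API.

```cpp
#include <cassert>
#include <cstdint>
#include <set>

// Positions currently occupied by one sequence (assumed non-negative).
using pos_set = std::set<int32_t>;

// Removes all positions in the half-open range [p0, p1).
void remove_range(pos_set & cells, int32_t p0, int32_t p1) {
    assert(p0 <= p1);
    cells.erase(cells.lower_bound(p0), cells.lower_bound(p1));
}

// Position an automatically placed token would get: one past the largest occupied position.
int32_t next_auto_pos(const pos_set & cells) {
    return cells.empty() ? 0 : *cells.rbegin() + 1;
}

// True if the occupied positions are not exactly 0, 1, ..., max.
bool has_hole(const pos_set & cells) {
    return !cells.empty() && *cells.rbegin() + 1 != static_cast<int32_t>(cells.size());
}
```

Starting from positions 0..9, removing [2, 5) leaves a hole while the next automatic position is still 10, so nothing can fill positions 2..4. Removing [5, 10) as well makes the sequence contiguous again, with the next automatic position at 2.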

@ggerganov

Member

My plan for the next steps was to refactor the llama_kv_cache_unified into 2 separate implementations - unified and recurrent (started some initial work in #12695). This will simplify the logic in llama_kv_cache and allow to design a better token position / sequence tracking in libllama which can then be used instead of manually passing positions with the batches. When this is ready, we can update this PR to not pass the positions explicitly.

Also now that the position is hidden from user code, I think we should also somehow modify the llama_kv_* API to make sure that user don't accidentally leave a "hole" in the context. For example, if KV cache has 10 tokens and they delete tokens from position [2, 5), now there will be no API to fill in this "hole" unless we also remove all tokens [5, inf)

The llama_kv_self_seq_ API would still require the user to keep track of the token positions / sequence lengths. But not sure if this can be improved.

@ngxson

ngxson commented Jun 15, 2026

Collaborator Author

@ggerganov I'm re-designing the API to take into account the MTP embeddings and llama_process(), coming up with this. WDYT?

I added llama_process_type to be either ENCODE/DECODE or something else in the future, but probably I misunderstood your intention.

    //
    // Extended batch API
    //

    struct llama_batch_ext;

    LLAMA_API struct llama_batch_ext * llama_batch_ext_init (struct llama_context * ctx);
    LLAMA_API void                     llama_batch_ext_free (struct llama_batch_ext * batch);
    LLAMA_API void                     llama_batch_ext_clear(struct llama_batch_ext * batch);

    struct llama_batch_token {
        llama_token     id;         // if id != LLAMA_TOKEN_NULL, embd must be nullptr
        float         * embd;       // if embd != nullptr, token id must be LLAMA_TOKEN_NULL
        float         * embd_nextn; // used by nextn layers
        llama_pos     * pos;        // if nullptr, the position will be automatically assigned
                                    // for M-RoPE models, embedding tokens must have multiple positions per token; text token only requires one single position per token
        llama_seq_id    seq_id;
        bool            output;
    };

    // Add an input token to the batch
    // Returns the batch index (>= 0)
    // On error:
    //    -1: batch is full
    //    -2: token is invalid (id == LLAMA_TOKEN_NULL or invalid embd)
    //    -3: invalid sequence id
    LLAMA_API int32_t llama_batch_ext_add_token(
            struct llama_batch_ext * batch,
                 llama_batch_token   token);

    // Set output = true for the last added token in the batch
    LLAMA_API int32_t llama_batch_ext_set_output_last(
            struct llama_batch_ext * batch);

    // Return values are the same as llama_decode()
    LLAMA_API int32_t llama_process(
                struct llama_context * ctx,
             enum llama_process_type   type,
              struct llama_batch_ext * batch);

Basic usage:

llama_batch_ext * batch = llama_batch_ext_init(ctx);
for (auto token_id : input_tokens) {
  llama_batch_token t{
    token_id,
    nullptr, // embd
    nullptr, // embd_nextn
    nullptr, // pos
    0, // seq_id
    false // output
  };
  llama_batch_ext_add_token(batch, t); // add the token to the batch
}
llama_batch_ext_set_output_last(batch);
llama_process(ctx, LLAMA_PROCESS_TYPE_DECODE, batch);
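
To make the documented error codes easier to follow, here is a self-contained mock of the add-token contract. Everything in it is hypothetical (`mock_token`, `mock_batch`, `mock_add_token`, `mock_set_output_last`, `MOCK_TOKEN_NULL`). It reads the struct comments as "exactly one of id / embd must be set", and both the order of the checks and the return value of the set-output helper are assumptions, since the proposal does not specify them.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int32_t MOCK_TOKEN_NULL = -1; // stand-in for LLAMA_TOKEN_NULL

// Reduced, hypothetical version of the proposed token struct.
struct mock_token {
    int32_t       id;
    const float * embd;
    int32_t       seq_id;
    bool          output;
};

struct mock_batch {
    int32_t capacity;  // maximum number of tokens
    int32_t n_seq_max; // valid sequence ids are [0, n_seq_max)
    std::vector<mock_token> tokens;
};

// Returns the batch index (>= 0), or -1 (full), -2 (invalid token), -3 (invalid seq id).
int32_t mock_add_token(mock_batch & batch, mock_token token) {
    if (static_cast<int32_t>(batch.tokens.size()) >= batch.capacity) {
        return -1;
    }
    const bool has_id   = token.id != MOCK_TOKEN_NULL;
    const bool has_embd = token.embd != nullptr;
    if (has_id == has_embd) {
        return -2; // neither or both were provided
    }
    if (token.seq_id < 0 || token.seq_id >= batch.n_seq_max) {
        return -3;
    }
    batch.tokens.push_back(token);
    return static_cast<int32_t>(batch.tokens.size()) - 1;
}

// Marks the last added token for output; returns its index, or -1 if the batch is empty (assumed).
int32_t mock_set_output_last(mock_batch & batch) {
    if (batch.tokens.empty()) {
        return -1;
    }
    batch.tokens.back().output = true;
    return static_cast<int32_t>(batch.tokens.size()) - 1;
}
```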

@ggerganov ggerganov self-assigned this Jun 15, 2026
@ngxson ngxson changed the title llama : add llama_batch_ext (superseded) llama : add llama_batch_ext Jun 15, 2026
@ngxson

ngxson commented Jun 15, 2026

Collaborator Author

I moved this PR to #24669

Labels

android Issues specific to Android examples python python script changes server

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Refactor: Allow adding both tokens and embeddings to llama_batch

3 participants