Skip to content

Commit 8b27762

Browse files
Balazs Racz authored and copybara-github committed
Factors out a new cc_library :query from :gemma-lib.
Moves query-related structs/classes to gemma/query.h. This refactors PerQuery, AllQueries, and QBatch into a dedicated header file, gemma/query.h, and updates BUILD dependencies accordingly. PiperOrigin-RevId: 842676520
1 parent 73c3627 commit 8b27762

File tree

3 files changed

+202
-156
lines changed

3 files changed

+202
-156
lines changed

BUILD.bazel

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -141,6 +141,7 @@ cc_test(
141141
":kv_cache",
142142
":mat",
143143
":matmul",
144+
":query",
144145
":threading_context",
145146
":weights",
146147
"@googletest//:gtest_main", # buildcleaner: keep
@@ -444,9 +445,9 @@ cc_test(
444445
":gemma_lib",
445446
":mat",
446447
":ops",
448+
":query",
447449
":test_util",
448450
":threading_context",
449-
":zones",
450451
"@googletest//:gtest_main", # buildcleaner: keep
451452
"//compression:test_util",
452453
"//compression:types",
@@ -536,6 +537,17 @@ cc_test(
536537
],
537538
)
538539

540+
cc_library(
541+
name = "query",
542+
hdrs = ["gemma/query.h"],
543+
deps = [
544+
":basics",
545+
":gemma_args",
546+
":kv_cache",
547+
"@highway//:hwy",
548+
],
549+
)
550+
539551
cc_library(
540552
name = "gemma_args",
541553
hdrs = ["gemma/gemma_args.h"],
@@ -586,7 +598,7 @@ cc_library(
586598
":matmul_env",
587599
":model_store",
588600
":ops",
589-
":test_util",
601+
":query",
590602
":threading",
591603
":threading_context",
592604
":weights",
@@ -620,6 +632,7 @@ cc_test(
620632
":kv_cache",
621633
":mat",
622634
":matmul_env",
635+
":query",
623636
":test_util",
624637
":threading_context",
625638
":weights",

gemma/gemma.h

Lines changed: 1 addition & 154 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
#include "gemma/gemma_args.h"
2828
#include "gemma/kv_cache.h"
2929
#include "gemma/model_store.h"
30+
#include "gemma/query.h"
3031
#include "gemma/weights.h"
3132
#include "io/blob_store.h"
3233
#include "io/io.h" // Path
@@ -39,160 +40,6 @@
3940

4041
namespace gcpp {
4142

42-
struct PerQuery {
43-
PromptTokens prompt;
44-
45-
// Position in the KV cache: initially zero for the first turn, or when
46-
// multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
47-
size_t mutable_pos;
48-
// Allows computing the last prefill token as `mutable_pos - initial_pos`,
49-
// which might differ from `prompt.size() - 1` for prefix-LM.
50-
size_t initial_pos;
51-
// Zero for causal attention, or the end of the prefix for prefix-LM style
52-
// attention in Paligemma.
53-
size_t prefix_end;
54-
55-
KVCachePtr kv_cache;
56-
57-
// Previous token generated for this query, or the last prompt token. Will be
58-
// fed into the next Transformer() call.
59-
int prev_token = 0;
60-
};
61-
62-
// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
63-
struct AllQueries {
64-
AllQueries() = default;
65-
66-
// For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
67-
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
68-
const hwy::Span<KVCachePtr>& kv_caches) {
69-
per_query_.reserve(kv_caches.size());
70-
for (size_t i = 0; i < kv_caches.size(); ++i) {
71-
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
72-
per_query_.push_back(PerQuery{
73-
.prompt = prompt,
74-
.mutable_pos = pos,
75-
.initial_pos = pos,
76-
.prefix_end = prefix_end,
77-
.kv_cache = kv_caches[i],
78-
});
79-
}
80-
}
81-
82-
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
83-
const hwy::Span<KVCache>& kv_caches)
84-
: AllQueries(prompt, pos, prefix_end,
85-
hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches))) {}
86-
87-
// Batch of queries with initial position set to zero. Causal attention
88-
// is requested via empty or all-zero `prefix_end`.
89-
AllQueries(
90-
const hwy::Span<const PromptTokens>& prompts,
91-
const hwy::Span<KVCachePtr>& kv_caches,
92-
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
93-
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
94-
per_query_.reserve(prompts.size());
95-
for (size_t i = 0; i < prompts.size(); ++i) {
96-
HWY_ASSERT(kv_caches.size() == 0 ||
97-
kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
98-
per_query_.push_back(PerQuery{
99-
.prompt = prompts[i],
100-
.mutable_pos = 0,
101-
.initial_pos = 0,
102-
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
103-
.kv_cache = kv_caches.size() == 0 ? KVCachePtr() : kv_caches[i],
104-
});
105-
}
106-
}
107-
108-
AllQueries(
109-
const hwy::Span<const PromptTokens>& prompts,
110-
const hwy::Span<KVCache>& kv_caches,
111-
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>())
112-
: AllQueries(prompts, hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches)),
113-
prefix_end) {}
114-
115-
void Reserve(size_t size) { per_query_.reserve(size); }
116-
void Append(const PerQuery& query) { per_query_.push_back(query); }
117-
118-
size_t NumQueries() const { return per_query_.size(); }
119-
120-
PerQuery& operator[](size_t query_idx) {
121-
HWY_DASSERT(query_idx < NumQueries());
122-
return per_query_[query_idx];
123-
}
124-
const PerQuery& operator[](size_t query_idx) const {
125-
HWY_DASSERT(query_idx < NumQueries());
126-
return per_query_[query_idx];
127-
}
128-
129-
private:
130-
std::vector<PerQuery> per_query_;
131-
};
132-
133-
// View into AllQueries: either a batch of queries, or a single query for use
134-
// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
135-
// reference to AllQueries.
136-
class QBatch {
137-
public:
138-
QBatch(size_t start, size_t max_size, AllQueries& queries)
139-
: start_(start),
140-
max_size_(max_size),
141-
queries_(queries),
142-
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
143-
HWY_ASSERT(max_size_ <= kMaxBatchSize);
144-
HWY_DASSERT(size_ != 0);
145-
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
146-
for (int i = 0; i < size_; ++i) {
147-
query_idx_.push_back(start_ + i);
148-
}
149-
}
150-
151-
// Returns a single-query view starting at `qi` relative to this batch.
152-
QBatch Single(size_t qi) const { return QBatch(QueryIdx(qi), 1, queries_); }
153-
154-
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
155-
size_t Size() const { return size_; }
156-
157-
// Returns index for use with `AllQueries` and `BatchStreamToken`.
158-
size_t QueryIdx(size_t qi) const {
159-
HWY_DASSERT(qi < size_);
160-
return query_idx_[qi];
161-
}
162-
163-
// Accessor functions to bridge the previous SoA and current AoS layout.
164-
const PromptTokens& Prompt(size_t qi) const {
165-
return queries_[QueryIdx(qi)].prompt;
166-
}
167-
size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
168-
size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
169-
size_t InitialPos(size_t qi) const {
170-
return queries_[QueryIdx(qi)].initial_pos;
171-
}
172-
size_t PrefixEnd(size_t qi) const {
173-
return queries_[QueryIdx(qi)].prefix_end;
174-
}
175-
KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
176-
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
177-
178-
// let query_idx_[to] point to the from in the queries_; this is only used if
179-
// the slot in the QBatch is less than the number of queries.
180-
void Insert(size_t from, size_t to) {
181-
if (from == to) return;
182-
HWY_ASSERT(!queries_[from].kv_cache.IsEmpty());
183-
HWY_ASSERT(queries_[to].kv_cache.IsEmpty());
184-
// Conceptually, insert from.query to location to.
185-
query_idx_[to] = from;
186-
}
187-
188-
protected:
189-
size_t start_;
190-
size_t max_size_;
191-
AllQueries& queries_;
192-
std::vector<size_t> query_idx_;
193-
size_t size_;
194-
};
195-
19643
// Used for continuous batching.
19744
class ContinuousQBatch : public QBatch {
19845
public:

0 commit comments

Comments
 (0)