|
27 | 27 | #include "gemma/gemma_args.h" |
28 | 28 | #include "gemma/kv_cache.h" |
29 | 29 | #include "gemma/model_store.h" |
| 30 | +#include "gemma/query.h" |
30 | 31 | #include "gemma/weights.h" |
31 | 32 | #include "io/blob_store.h" |
32 | 33 | #include "io/io.h" // Path |
|
39 | 40 |
|
40 | 41 | namespace gcpp { |
41 | 42 |
|
42 | | -struct PerQuery { |
43 | | - PromptTokens prompt; |
44 | | - |
45 | | - // Position in the KV cache: initially zero for the first turn, or when |
46 | | - // multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`. |
47 | | - size_t mutable_pos; |
48 | | - // Allows computing the last prefill token as `mutable_pos - initial_pos`, |
49 | | - // which might differ from `prompt.size() - 1` for prefix-LM. |
50 | | - size_t initial_pos; |
51 | | - // Zero for causal attention, or the end of the prefix for prefix-LM style |
52 | | - // attention in Paligemma. |
53 | | - size_t prefix_end; |
54 | | - |
55 | | - KVCachePtr kv_cache; |
56 | | - |
57 | | - // Previous token generated for this query, or the last prompt token. Will be |
58 | | - // fed into the next Transformer() call. |
59 | | - int prev_token = 0; |
60 | | -}; |
61 | | - |
62 | | -// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`. |
63 | | -struct AllQueries { |
64 | | - AllQueries() = default; |
65 | | - |
66 | | - // For `GenerateSingleT`: same prompt/pos, replicated for each KV cache. |
67 | | - AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end, |
68 | | - const hwy::Span<KVCachePtr>& kv_caches) { |
69 | | - per_query_.reserve(kv_caches.size()); |
70 | | - for (size_t i = 0; i < kv_caches.size(); ++i) { |
71 | | - HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen()); |
72 | | - per_query_.push_back(PerQuery{ |
73 | | - .prompt = prompt, |
74 | | - .mutable_pos = pos, |
75 | | - .initial_pos = pos, |
76 | | - .prefix_end = prefix_end, |
77 | | - .kv_cache = kv_caches[i], |
78 | | - }); |
79 | | - } |
80 | | - } |
81 | | - |
82 | | - AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end, |
83 | | - const hwy::Span<KVCache>& kv_caches) |
84 | | - : AllQueries(prompt, pos, prefix_end, |
85 | | - hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches))) {} |
86 | | - |
87 | | - // Batch of queries with initial position set to zero. Causal attention |
88 | | - // is requested via empty or all-zero `prefix_end`. |
89 | | - AllQueries( |
90 | | - const hwy::Span<const PromptTokens>& prompts, |
91 | | - const hwy::Span<KVCachePtr>& kv_caches, |
92 | | - const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) { |
93 | | - HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0); |
94 | | - per_query_.reserve(prompts.size()); |
95 | | - for (size_t i = 0; i < prompts.size(); ++i) { |
96 | | - HWY_ASSERT(kv_caches.size() == 0 || |
97 | | - kv_caches[i].SeqLen() == kv_caches[0].SeqLen()); |
98 | | - per_query_.push_back(PerQuery{ |
99 | | - .prompt = prompts[i], |
100 | | - .mutable_pos = 0, |
101 | | - .initial_pos = 0, |
102 | | - .prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i], |
103 | | - .kv_cache = kv_caches.size() == 0 ? KVCachePtr() : kv_caches[i], |
104 | | - }); |
105 | | - } |
106 | | - } |
107 | | - |
108 | | - AllQueries( |
109 | | - const hwy::Span<const PromptTokens>& prompts, |
110 | | - const hwy::Span<KVCache>& kv_caches, |
111 | | - const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) |
112 | | - : AllQueries(prompts, hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches)), |
113 | | - prefix_end) {} |
114 | | - |
115 | | - void Reserve(size_t size) { per_query_.reserve(size); } |
116 | | - void Append(const PerQuery& query) { per_query_.push_back(query); } |
117 | | - |
118 | | - size_t NumQueries() const { return per_query_.size(); } |
119 | | - |
120 | | - PerQuery& operator[](size_t query_idx) { |
121 | | - HWY_DASSERT(query_idx < NumQueries()); |
122 | | - return per_query_[query_idx]; |
123 | | - } |
124 | | - const PerQuery& operator[](size_t query_idx) const { |
125 | | - HWY_DASSERT(query_idx < NumQueries()); |
126 | | - return per_query_[query_idx]; |
127 | | - } |
128 | | - |
129 | | - private: |
130 | | - std::vector<PerQuery> per_query_; |
131 | | -}; |
132 | | - |
133 | | -// View into AllQueries: either a batch of queries, or a single query for use |
134 | | -// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a |
135 | | -// reference to AllQueries. |
136 | | -class QBatch { |
137 | | - public: |
138 | | - QBatch(size_t start, size_t max_size, AllQueries& queries) |
139 | | - : start_(start), |
140 | | - max_size_(max_size), |
141 | | - queries_(queries), |
142 | | - size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) { |
143 | | - HWY_ASSERT(max_size_ <= kMaxBatchSize); |
144 | | - HWY_DASSERT(size_ != 0); |
145 | | - HWY_DASSERT(start_ + size_ <= queries_.NumQueries()); |
146 | | - for (int i = 0; i < size_; ++i) { |
147 | | - query_idx_.push_back(start_ + i); |
148 | | - } |
149 | | - } |
150 | | - |
151 | | - // Returns a single-query view starting at `qi` relative to this batch. |
152 | | - QBatch Single(size_t qi) const { return QBatch(QueryIdx(qi), 1, queries_); } |
153 | | - |
154 | | - // How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`. |
155 | | - size_t Size() const { return size_; } |
156 | | - |
157 | | - // Returns index for use with `AllQueries` and `BatchStreamToken`. |
158 | | - size_t QueryIdx(size_t qi) const { |
159 | | - HWY_DASSERT(qi < size_); |
160 | | - return query_idx_[qi]; |
161 | | - } |
162 | | - |
163 | | - // Accessor functions to bridge the previous SoA and current AoS layout. |
164 | | - const PromptTokens& Prompt(size_t qi) const { |
165 | | - return queries_[QueryIdx(qi)].prompt; |
166 | | - } |
167 | | - size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; } |
168 | | - size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; } |
169 | | - size_t InitialPos(size_t qi) const { |
170 | | - return queries_[QueryIdx(qi)].initial_pos; |
171 | | - } |
172 | | - size_t PrefixEnd(size_t qi) const { |
173 | | - return queries_[QueryIdx(qi)].prefix_end; |
174 | | - } |
175 | | - KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; } |
176 | | - int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; } |
177 | | - |
178 | | - // let query_idx_[to] point to the from in the queries_; this is only used if |
179 | | - // the slot in the QBatch is less than the number of queries. |
180 | | - void Insert(size_t from, size_t to) { |
181 | | - if (from == to) return; |
182 | | - HWY_ASSERT(!queries_[from].kv_cache.IsEmpty()); |
183 | | - HWY_ASSERT(queries_[to].kv_cache.IsEmpty()); |
184 | | - // Conceptually, insert from.query to location to. |
185 | | - query_idx_[to] = from; |
186 | | - } |
187 | | - |
188 | | - protected: |
189 | | - size_t start_; |
190 | | - size_t max_size_; |
191 | | - AllQueries& queries_; |
192 | | - std::vector<size_t> query_idx_; |
193 | | - size_t size_; |
194 | | -}; |
195 | | - |
196 | 43 | // Used for continuous batching. |
197 | 44 | class ContinuousQBatch : public QBatch { |
198 | 45 | public: |
|
0 commit comments