 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Abstractions;
@@ -16,7 +15,12 @@ public sealed class BatchedExecutor
     : IDisposable
 {
     private int _nextSequenceId;
-    private readonly List<IBatch> _batchQueue = [];
+    private readonly List<IBatch> _batchQueue = [];
+    private int _batchQueueHead;
+    private int _batchedTokenCount;
+    private bool _batchedTokenCountDirty = true;
+    // Skip compacting the queue until this many processed batches accumulate at the front.
+    private const int CleanupThreshold = 16;

     /// <summary>
     /// Set to 1 using interlocked exchange while inference is running
@@ -42,12 +46,27 @@ public sealed class BatchedExecutor
     /// <summary>
     /// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
     /// </summary>
-    public int BatchedTokenCount => _batchQueue.Sum(a => a.ItemCount);
+    public int BatchedTokenCount
+    {
+        get
+        {
+            if (_batchedTokenCountDirty)
+            {
+                var total = 0;
+                for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
+                    total += _batchQueue[i].ItemCount;
+                _batchedTokenCount = total;
+                _batchedTokenCountDirty = false;
+            }
+
+            return _batchedTokenCount;
+        }
+    }

     /// <summary>
     /// Number of batches in the queue, waiting for <see cref="Infer"/> to be called
     /// </summary>
-    public int BatchQueueCount => _batchQueue.Count;
+    public int BatchQueueCount => _batchQueue.Count - _batchQueueHead;

     /// <summary>
     /// Check if this executor has been disposed.
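
The rewritten `BatchedTokenCount` drops the per-read LINQ `Sum` (and with it the `using System.Linq;` import removed at the top) in favour of a total that is recomputed only when the queue has changed since the last read. A minimal standalone sketch of the dirty-flag caching pattern, with hypothetical names that are not part of this commit:

```csharp
using System.Collections.Generic;

// Sketch: cache an aggregate over a collection and recompute it only after
// a mutation has set the dirty flag. Not thread-safe, matching the
// single-threaded use of the batch queue in this class.
class CachedSum
{
    private readonly List<int> _items = [];
    private int _cachedTotal;
    private bool _dirty = true;

    public void Add(int value)
    {
        _items.Add(value);
        _dirty = true; // invalidate on every mutation
    }

    public int Total
    {
        get
        {
            if (_dirty)
            {
                var total = 0;
                foreach (var item in _items)
                    total += item;
                _cachedTotal = total;
                _dirty = false;
            }
            return _cachedTotal; // repeated reads are now O(1)
        }
    }
}
```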
@@ -147,12 +166,13 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
         // again after the issue has been fixed (e.g. some KV cache space has been freed) to retry this operation.
         if (status != DecodeResult.Ok)
         {
-            _batchQueue.Insert(0, next);
+            RequeueFront(next);
             return status;
         }

         // Everything was ok, advance the epoch
         Epoch++;
+        CleanupQueue();

         return status;
     }
@@ -166,13 +186,45 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)

         IBatch? GetNextBatch()
         {
-            if (_batchQueue.Count == 0)
+            if (_batchQueueHead >= _batchQueue.Count)
+            {
+                _batchQueue.Clear();
+                _batchQueueHead = 0;
                 return null;
-
-            var nextBatch = _batchQueue[0];
-            _batchQueue.RemoveAt(0);
+            }
+
+            var nextBatch = _batchQueue[_batchQueueHead];
+            _batchQueueHead++;
+            _batchedTokenCountDirty = true;
             return nextBatch;
         }
+
+        void RequeueFront(IBatch batch)
+        {
+            Debug.Assert(_batchQueueHead > 0, "Cannot requeue batch when queue head is at zero.");
+            _batchQueue[--_batchQueueHead] = batch;
+            _batchedTokenCountDirty = true;
+        }
+
+        // Remove batches that have already been consumed so the head index does not grow without bound.
+        void CleanupQueue()
+        {
+            if (_batchQueueHead == 0)
+                return;
+
+            if (_batchQueueHead >= _batchQueue.Count)
+            {
+                _batchQueue.Clear();
+                _batchQueueHead = 0;
+                return;
+            }
+
+            if (_batchQueueHead > CleanupThreshold && _batchQueueHead > _batchQueue.Count / 2)
+            {
+                _batchQueue.RemoveRange(0, _batchQueueHead);
+                _batchQueueHead = 0;
+            }
+        }
     }

     /// <inheritdoc />
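
Taken together, `GetNextBatch`, `RequeueFront`, and `CleanupQueue` turn the plain `List<IBatch>` into an amortised O(1) queue: the old `RemoveAt(0)` shifted every remaining element on each dequeue, whereas advancing `_batchQueueHead` is constant time, and the backing list is only compacted once more than `CleanupThreshold` consumed entries make up over half of it. A generic sketch of the same head-index technique, with hypothetical names and under the same single-threaded assumption:

```csharp
using System.Collections.Generic;

// Sketch: a list-backed queue that dequeues by advancing a head index
// and compacts the backing list lazily instead of on every dequeue.
class HeadIndexQueue<T>
{
    private const int CleanupThreshold = 16; // mirrors the constant in the diff
    private readonly List<T> _items = [];
    private int _head;

    public int Count => _items.Count - _head;

    public void Enqueue(T item) => _items.Add(item);

    public bool TryDequeue(out T? item)
    {
        if (_head >= _items.Count)
        {
            _items.Clear(); // fully drained: reset the backing list cheaply
            _head = 0;
            item = default;
            return false;
        }

        item = _items[_head++];

        // Compact only when the dead prefix dominates the list, so the
        // O(n) RemoveRange cost is amortised across many dequeues.
        if (_head > CleanupThreshold && _head > _items.Count / 2)
        {
            _items.RemoveRange(0, _head);
            _head = 0;
        }
        return true;
    }
}
```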
@@ -202,7 +254,7 @@ internal LLamaSeqId GetNextSequenceId()
             throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Request batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");

         // Find a batch with space for at least minCapacity tokens
-        for (var i = 0; i < _batchQueue.Count; i++)
+        for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
         {
             var item = _batchQueue[i];
             if (item is not TokenBatch { Batch: var batch })
@@ -213,13 +265,17 @@ internal LLamaSeqId GetNextSequenceId()
                 continue;

             if (batch.TokenCount < Context.BatchSize)
-                return (batch, Epoch + (uint)(i + 1) * 2);
+            {
+                _batchedTokenCountDirty = true;
+                return (batch, Epoch + (uint)(i - _batchQueueHead + 1) * 2);
+            }
         }

         // Add a new batch to the end of the queue
         var end = new LLamaBatch();
         _batchQueue.Add(new TokenBatch(end));
-        return (end, Epoch + (uint)_batchQueue.Count * 2);
+        _batchedTokenCountDirty = true;
+        return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2);
     }

     /// <summary>
/// <summary>
@@ -234,7 +290,7 @@ internal LLamaSeqId GetNextSequenceId()
             throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Request batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");

         // Find a batch with space for at least minCapacity embeddings
-        for (var i = 0; i < _batchQueue.Count; i++)
+        for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
         {
             var item = _batchQueue[i];
             if (item is not EmbeddingBatch { Batch: var batch })
@@ -245,13 +301,17 @@ internal LLamaSeqId GetNextSequenceId()
                 continue;

             if (batch.EmbeddingsCount < Context.BatchSize)
-                return (batch, Epoch + (uint)(i + 1) * 2);
+            {
+                _batchedTokenCountDirty = true;
+                return (batch, Epoch + (uint)(i - _batchQueueHead + 1) * 2);
+            }
         }

         // Add a new batch to the end of the queue
         var end = new LLamaBatchEmbeddings(Context.EmbeddingSize);
         _batchQueue.Add(new EmbeddingBatch(end));
-        return (end, Epoch + (uint)_batchQueue.Count * 2);
+        _batchedTokenCountDirty = true;
+        return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2);
     }

     #region batches
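
Note the matching change to the epoch arithmetic in both allocation paths: because consumed batches may still occupy the front of the backing list, a batch's position in the queue is now `i - _batchQueueHead` rather than the raw list index `i`. For example, with `_batchQueueHead == 3` and a suitable batch at list index `4`, the new code returns `Epoch + (4 - 3 + 1) * 2 = Epoch + 4`, whereas the old expression `(i + 1) * 2` would have counted the three already-consumed batches as well and yielded `Epoch + 10`.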
@@ -286,4 +346,4 @@ public Task<DecodeResult> DecodeAsync(LLamaContext ctx, CancellationToken token)
         }
     }
     #endregion
-}
+}