
Commit 1de246c

Merge pull request #1262 from SignalRT/Optimization
Optimization - A queue with fixed storage size backed by a circular buffer
2 parents de00c15 + eac1f7a commit 1de246c

8 files changed: +230 additions, -61 deletions
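
The FixedSizeQueue implementation itself (in the LLama.Common namespace, per the benchmark's using directive below) is not reproduced in this excerpt. As context for the diffs that follow, a minimal sketch of the technique the PR title describes, a fixed-capacity FIFO backed by a circular buffer, might look like this. It is illustrative only, not the actual LLamaSharp implementation, and the class name CircularFixedSizeQueue is invented here:

using System;
using System.Collections;
using System.Collections.Generic;

// Sketch: fixed-capacity FIFO over a circular buffer. Enqueue is O(1) and
// allocation-free after construction; once full, each new item overwrites
// the oldest one instead of shifting or reallocating storage.
public sealed class CircularFixedSizeQueue<T> : IEnumerable<T>
{
    private readonly T[] _buffer;
    private int _start; // index of the oldest item

    public int Count { get; private set; }

    public CircularFixedSizeQueue(int capacity)
    {
        if (capacity <= 0)
            throw new ArgumentOutOfRangeException(nameof(capacity));
        _buffer = new T[capacity];
    }

    public void Enqueue(T item)
    {
        // The slot one past the newest item; when full this is the oldest slot.
        _buffer[(_start + Count) % _buffer.Length] = item;

        if (Count < _buffer.Length)
            Count++;
        else
            _start = (_start + 1) % _buffer.Length; // drop the oldest
    }

    public IEnumerator<T> GetEnumerator()
    {
        // Yield from oldest to newest, following the wrap-around.
        for (var i = 0; i < Count; i++)
            yield return _buffer[(_start + i) % _buffer.Length];
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}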
LLama.Benchmark/Collections/FixedSizeQueueBenchmark.cs

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+using System.Linq;
+using BenchmarkDotNet.Attributes;
+using BenchmarkDotNet.Engines;
+using BenchmarkDotNet.Jobs;
+using LLama.Common;
+
+namespace LLama.Benchmark.Collections;
+
+[SimpleJob(RunStrategy.Throughput, RuntimeMoniker.Net80)]
+[MemoryDiagnoser]
+[BenchmarkCategory("Collections", "FixedSizeQueue")]
+public class FixedSizeQueueBenchmark
+{
+    [Params(32, 512, 4096)]
+    public int Capacity { get; set; }
+
+    private int[] _values = Array.Empty<int>();
+
+    [GlobalSetup]
+    public void Setup()
+    {
+        _values = Enumerable.Range(0, Capacity * 4).ToArray();
+    }
+
+    [Benchmark]
+    public int EnqueueWrap()
+    {
+        var queue = new FixedSizeQueue<int>(Capacity);
+        foreach (var value in _values)
+            queue.Enqueue(value);
+        return queue.Count;
+    }
+
+    [Benchmark]
+    public int IterateTailSum()
+    {
+        var queue = new FixedSizeQueue<int>(Capacity);
+        foreach (var value in _values)
+            queue.Enqueue(value);
+
+        var sum = 0;
+        foreach (var value in queue)
+            sum += value;
+        return sum;
+    }
+}
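
Both benchmarks enqueue Capacity * 4 values, so after the first Capacity items the remaining three quarters of the enqueues exercise the wrap-around/overwrite path; EnqueueWrap isolates that path, while IterateTailSum additionally measures enumeration over the wrapped buffer. BenchmarkDotNet discovers the [Benchmark] methods and runs each against every [Params] capacity. A typical console entry point (hypothetical here; the commit does not show the benchmark project's runner) would be:

using BenchmarkDotNet.Running;

// Hypothetical runner: executes EnqueueWrap and IterateTailSum for
// Capacity = 32, 512 and 4096, with memory diagnostics enabled.
BenchmarkRunner.Run<FixedSizeQueueBenchmark>();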

LLama/AntipromptProcessor.cs

Lines changed: 8 additions & 6 deletions
@@ -11,7 +11,7 @@ public sealed class AntipromptProcessor
         private int _longestAntiprompt;
         private readonly List<string> _antiprompts = new();
 
-        private string? _string;
+        private string _buffer = string.Empty;
 
 
         /// <summary>
@@ -46,6 +46,8 @@ public void SetAntiprompts(IEnumerable<string> antiprompts)
             _longestAntiprompt = 0;
             foreach (var antiprompt in _antiprompts)
                 _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length);
+
+            _buffer = string.Empty;
         }
 
         /// <summary>
@@ -55,21 +57,21 @@ public void SetAntiprompts(IEnumerable<string> antiprompts)
         /// <returns>true if the text buffer ends with any antiprompt</returns>
         public bool Add(string text)
         {
-            _string += text;
+            _buffer += text;
 
             // When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length).
             // This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode
             // even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances!
             var maxLength = Math.Max(32, _longestAntiprompt * 4);
             var trimLength = Math.Max(16, _longestAntiprompt * 2);
-            if (_string.Length > maxLength)
-                _string = _string.Substring(_string.Length - trimLength);
+            if (_buffer.Length > maxLength)
+                _buffer = _buffer.Substring(_buffer.Length - trimLength);
 
             foreach (var antiprompt in _antiprompts)
-                if (_string.EndsWith(antiprompt, StringComparison.CurrentCulture))
+                if (_buffer.EndsWith(antiprompt, StringComparison.CurrentCulture))
                     return true;
 
             return false;
         }
     }
-}
+}
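
Besides renaming _string to _buffer, this change resets the buffer in SetAntiprompts, so leftover text from a previous generation can no longer produce a spurious antiprompt match. A minimal usage sketch; SetAntiprompts and Add are the members shown in the diff, while the parameterless constructor is an assumption:

using System;
using LLama;

// Sketch: stream text chunks into the processor; Add returns true once
// the accumulated buffer ends with any registered antiprompt.
var processor = new AntipromptProcessor();
processor.SetAntiprompts(new[] { "User:" });

foreach (var chunk in new[] { "Hello", " world.\nUs", "er:" })
{
    if (processor.Add(chunk))
    {
        Console.WriteLine("Antiprompt hit - stop generating.");
        break;
    }
}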

LLama/Batched/BatchedExecutor.cs

Lines changed: 76 additions & 16 deletions
@@ -1,7 +1,6 @@
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Abstractions;
@@ -16,7 +15,12 @@ public sealed class BatchedExecutor
     : IDisposable
 {
     private int _nextSequenceId;
-    private readonly List<IBatch> _batchQueue = [ ];
+    private readonly List<IBatch> _batchQueue = [];
+    private int _batchQueueHead;
+    private int _batchedTokenCount;
+    private bool _batchedTokenCountDirty = true;
+    // Skip compacting the queue until this many processed batches accumulate at the front.
+    private const int CleanupThreshold = 16;
 
     /// <summary>
     /// Set to 1 using interlocked exchange while inference is running
@@ -42,12 +46,27 @@ public sealed class BatchedExecutor
     /// <summary>
     /// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
     /// </summary>
-    public int BatchedTokenCount => _batchQueue.Sum(a => a.ItemCount);
+    public int BatchedTokenCount
+    {
+        get
+        {
+            if (_batchedTokenCountDirty)
+            {
+                var total = 0;
+                for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
+                    total += _batchQueue[i].ItemCount;
+                _batchedTokenCount = total;
+                _batchedTokenCountDirty = false;
+            }
+
+            return _batchedTokenCount;
+        }
+    }
 
     /// <summary>
     /// Number of batches in the queue, waiting for <see cref="Infer"/> to be called
     /// </summary>
-    public int BatchQueueCount => _batchQueue.Count;
+    public int BatchQueueCount => _batchQueue.Count - _batchQueueHead;
 
     /// <summary>
     /// Check if this executor has been disposed.
@@ -147,12 +166,13 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
         // again after the issue has been fixed (e.g. some KV cache space has been freed) to retry this operation.
         if (status != DecodeResult.Ok)
         {
-            _batchQueue.Insert(0, next);
+            RequeueFront(next);
             return status;
         }
 
         // Everything was ok, advance the epoch
         Epoch++;
+        CleanupQueue();
 
         return status;
     }
@@ -166,13 +186,45 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
 
         IBatch? GetNextBatch()
         {
-            if (_batchQueue.Count == 0)
+            if (_batchQueueHead >= _batchQueue.Count)
+            {
+                _batchQueue.Clear();
+                _batchQueueHead = 0;
                 return null;
-
-            var nextBatch = _batchQueue[0];
-            _batchQueue.RemoveAt(0);
+            }
+
+            var nextBatch = _batchQueue[_batchQueueHead];
+            _batchQueueHead++;
+            _batchedTokenCountDirty = true;
             return nextBatch;
        }
+
+        void RequeueFront(IBatch batch)
+        {
+            Debug.Assert(_batchQueueHead > 0, "Cannot requeue batch when queue head is at zero.");
+            _batchQueue[--_batchQueueHead] = batch;
+            _batchedTokenCountDirty = true;
+        }
+
+        // Remove batches that have already been consumed so the head index does not grow without bound.
+        void CleanupQueue()
+        {
+            if (_batchQueueHead == 0)
+                return;
+
+            if (_batchQueueHead >= _batchQueue.Count)
+            {
+                _batchQueue.Clear();
+                _batchQueueHead = 0;
+                return;
+            }
+
+            if (_batchQueueHead > CleanupThreshold && _batchQueueHead > _batchQueue.Count / 2)
+            {
+                _batchQueue.RemoveRange(0, _batchQueueHead);
+                _batchQueueHead = 0;
+            }
+        }
    }
 
     /// <inheritdoc />
@@ -202,7 +254,7 @@ internal LLamaSeqId GetNextSequenceId()
            throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Request batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");
 
         // Find a batch with space for at least minCapacity tokens
-        for (var i = 0; i < _batchQueue.Count; i++)
+        for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
         {
             var item = _batchQueue[i];
             if (item is not TokenBatch { Batch: var batch })
@@ -213,13 +265,17 @@ internal LLamaSeqId GetNextSequenceId()
                 continue;
 
             if (batch.TokenCount < Context.BatchSize)
-                return (batch, Epoch + (uint)(i + 1) * 2);
+            {
+                _batchedTokenCountDirty = true;
+                return (batch, Epoch + (uint)(i - _batchQueueHead + 1) * 2);
+            }
         }
 
         // Add a new batch to the end of the queue
         var end = new LLamaBatch();
         _batchQueue.Add(new TokenBatch(end));
-        return (end, Epoch + (uint)_batchQueue.Count * 2);
+        _batchedTokenCountDirty = true;
+        return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2);
     }
 
     /// <summary>
@@ -234,7 +290,7 @@ internal LLamaSeqId GetNextSequenceId()
            throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Request batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");
 
         // Find a batch with space for at least minCapacity embeddings
-        for (var i = 0; i < _batchQueue.Count; i++)
+        for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
         {
             var item = _batchQueue[i];
             if (item is not EmbeddingBatch { Batch: var batch })
@@ -245,13 +301,17 @@ internal LLamaSeqId GetNextSequenceId()
                 continue;
 
             if (batch.EmbeddingsCount < Context.BatchSize)
-                return (batch, Epoch + (uint)(i + 1) * 2);
+            {
+                _batchedTokenCountDirty = true;
+                return (batch, Epoch + (uint)(i - _batchQueueHead + 1) * 2);
+            }
         }
 
         // Add a new batch to the end of the queue
         var end = new LLamaBatchEmbeddings(Context.EmbeddingSize);
         _batchQueue.Add(new EmbeddingBatch(end));
-        return (end, Epoch + (uint)_batchQueue.Count * 2);
+        _batchedTokenCountDirty = true;
+        return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2);
     }
 
     #region batches
@@ -286,4 +346,4 @@ public Task<DecodeResult> DecodeAsync(LLamaContext ctx, CancellationToken token)
         }
     }
     #endregion
-}
+}
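
Taken together, the BatchedExecutor changes remove two O(n) costs: _batchQueue.RemoveAt(0) shifted every remaining batch on each dequeue, and the LINQ Sum walked the whole queue on every read of BatchedTokenCount. Instead, a head index marks the consumed prefix, compaction is deferred until more than CleanupThreshold consumed slots accumulate and outnumber the live entries, and a dirty flag recomputes the token total only after the queue actually changes. A standalone sketch of the same pattern, with illustrative names (IndexedWorkQueue is not part of LLamaSharp):

using System;
using System.Collections.Generic;

// Sketch: head-index queue with a dirty-flagged cached aggregate.
public sealed class IndexedWorkQueue<T>
{
    private readonly List<T> _items = new();
    private readonly Func<T, int> _weight;
    private int _head;          // index of the next item to dequeue
    private int _cachedTotal;
    private bool _dirty = true;
    private const int CompactThreshold = 16;

    public IndexedWorkQueue(Func<T, int> weight) => _weight = weight;

    public int Count => _items.Count - _head;

    // O(n) only after a mutation; O(1) on repeated reads.
    public int TotalWeight
    {
        get
        {
            if (_dirty)
            {
                var total = 0;
                for (var i = _head; i < _items.Count; i++)
                    total += _weight(_items[i]);
                _cachedTotal = total;
                _dirty = false;
            }
            return _cachedTotal;
        }
    }

    public void Enqueue(T item)
    {
        _items.Add(item);
        _dirty = true;
    }

    public bool TryDequeue(out T item)
    {
        if (_head >= _items.Count)
        {
            item = default!;
            return false;
        }

        // Advance the head instead of List.RemoveAt(0): no element shifting.
        item = _items[_head++];
        _dirty = true;

        // Compact only once enough consumed slots pile up at the front, so a
        // single RemoveRange amortises the shifting cost over many dequeues.
        if (_head > CompactThreshold && _head > _items.Count / 2)
        {
            _items.RemoveRange(0, _head);
            _head = 0;
        }
        return true;
    }
}

RemoveRange(0, _head) still copies the surviving elements, but it runs at most once per CompactThreshold-or-more dequeues, so the per-dequeue cost is amortised O(1) while the backing list never grows without bound.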
