-
Notifications
You must be signed in to change notification settings - Fork 682
Expand file tree
/
Copy pathDistributedCacheEventStreamStore.cs
More file actions
505 lines (426 loc) · 23.5 KB
/
DistributedCacheEventStreamStore.cs
File metadata and controls
505 lines (426 loc) · 23.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using ModelContextProtocol.Protocol;
using System.Net.ServerSentEvents;
using System.Runtime.CompilerServices;
using System.Text.Json;
namespace ModelContextProtocol.Server;
/// <summary>
/// An <see cref="ISseEventStreamStore"/> implementation backed by <see cref="IDistributedCache"/>.
/// </summary>
/// <remarks>
/// <para>
/// This implementation stores SSE events in a distributed cache, enabling resumability across
/// multiple server instances. Event IDs are encoded with session, stream, and sequence information
/// to allow efficient retrieval of events after a given point.
/// </para>
/// <para>
/// The writer maintains in-memory state for sequence number generation, as there is guaranteed
/// to be only one writer per stream. Readers may be created from separate processes.
/// </para>
/// <para>
/// Three kinds of cache entries are written: one entry per event, one metadata entry per stream
/// (mode, completion flag, and last sequence number), and one index entry per session that lists
/// the session's streams so they can be found again by <see cref="DeleteStreamsForSessionAsync"/>.
/// </para>
/// </remarks>
public sealed partial class DistributedCacheEventStreamStore : ISseEventStreamStore
{
    private readonly IDistributedCache _cache;
    private readonly DistributedCacheEventStreamStoreOptions _options;
    private readonly ILogger _logger;

    /// <summary>
    /// Initializes a new instance of the <see cref="DistributedCacheEventStreamStore"/> class.
    /// </summary>
    /// <param name="cache">The distributed cache to use for storage.</param>
    /// <param name="options">
    /// Optional configuration options for the store. When <see langword="null"/>, a default-constructed
    /// options instance is used.
    /// </param>
    /// <param name="logger">
    /// Optional logger for diagnostic output. When <see langword="null"/>, a no-op logger is used.
    /// </param>
    public DistributedCacheEventStreamStore(IDistributedCache cache, DistributedCacheEventStreamStoreOptions? options = null, ILogger<DistributedCacheEventStreamStore>? logger = null)
    {
        Throw.IfNull(cache);

        _cache = cache;
        _options = options ?? new();
        _logger = logger ?? NullLogger<DistributedCacheEventStreamStore>.Instance;
    }

    /// <inheritdoc />
    /// <remarks>
    /// No cache entries are written here. The stream's metadata and its session-index entry are first
    /// written when the returned writer writes an event, changes mode, or is disposed.
    /// </remarks>
    public ValueTask<ISseEventStreamWriter> CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default)
    {
        Throw.IfNull(options);

        LogStreamCreated(options.SessionId, options.StreamId, options.Mode);

        var writer = new DistributedCacheEventStreamWriter(_cache, options.SessionId, options.StreamId, options.Mode, _options, _logger);
        return new ValueTask<ISseEventStreamWriter>(writer);
    }

    /// <inheritdoc />
    /// <remarks>
    /// The session index is the only record of which streams belong to a session. For each listed
    /// stream, every event key from sequence 1 through the stream's recorded last sequence is removed,
    /// followed by the stream's metadata and finally the index itself. Each removal is a separate,
    /// sequential cache round trip.
    /// </remarks>
    public async ValueTask DeleteStreamsForSessionAsync(string sessionId, CancellationToken cancellationToken = default)
    {
        Throw.IfNull(sessionId);

        var indexKey = CacheKeys.SessionIndex(sessionId);
        var indexBytes = await _cache.GetAsync(indexKey, cancellationToken).ConfigureAwait(false);
        if (indexBytes is null)
        {
            LogSessionIndexNotFound(sessionId);
            return;
        }

        var index = JsonSerializer.Deserialize(indexBytes, DistributedCacheEventStreamStoreJsonUtilities.SessionIndexJsonTypeInfo);
        if (index?.Streams is null)
        {
            LogSessionIndexDeserializationFailed(sessionId);
            return;
        }

        foreach (var stream in index.Streams)
        {
            // Event keys are derived from (session, stream, sequence). The writer numbers events from 1,
            // so walking 1..LastSequence addresses every event recorded for this stream.
            for (long seq = 1; seq <= stream.LastSequence; seq++)
            {
                var eventId = DistributedCacheEventIdFormatter.Format(sessionId, stream.StreamId, seq);
                var eventKey = CacheKeys.Event(eventId);
                await _cache.RemoveAsync(eventKey, cancellationToken).ConfigureAwait(false);
            }

            var metadataKey = CacheKeys.StreamMetadata(sessionId, stream.StreamId);
            await _cache.RemoveAsync(metadataKey, cancellationToken).ConfigureAwait(false);
        }

        // The index is removed last, so an exception part-way through leaves it in place and a later
        // call can still locate the remaining entries.
        await _cache.RemoveAsync(indexKey, cancellationToken).ConfigureAwait(false);

        LogStreamsDeletedForSession(sessionId, index.Streams.Count);
    }

    /// <inheritdoc />
    /// <remarks>
    /// Returns <see langword="null"/> when <paramref name="lastEventId"/> cannot be parsed, or when the
    /// stream's metadata is missing from the cache or cannot be deserialized. The returned reader
    /// starts at the sequence immediately after the one encoded in <paramref name="lastEventId"/>.
    /// </remarks>
    public async ValueTask<ISseEventStreamReader?> GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default)
    {
        Throw.IfNull(lastEventId);

        if (!DistributedCacheEventIdFormatter.TryParse(lastEventId, out var sessionId, out var streamId, out var sequence))
        {
            LogEventIdParsingFailed(lastEventId);
            return null;
        }

        // The stream's metadata entry is what establishes that the stream exists.
        var metadataKey = CacheKeys.StreamMetadata(sessionId, streamId);
        var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false);
        if (metadataBytes is null)
        {
            LogStreamMetadataNotFound(sessionId, streamId);
            return null;
        }

        var metadata = JsonSerializer.Deserialize(metadataBytes, DistributedCacheEventStreamStoreJsonUtilities.StreamMetadataJsonTypeInfo);
        if (metadata is null)
        {
            LogStreamMetadataDeserializationFailed(sessionId, streamId);
            return null;
        }

        // The client has already seen `sequence`, so replay begins with the next one.
        var startSequence = sequence + 1;

        LogStreamReaderCreated(sessionId, streamId, startSequence, metadata.LastSequence);

        return new DistributedCacheEventStreamReader(_cache, sessionId, streamId, startSequence, metadata, _options, _logger);
    }

    /// <summary>
    /// Provides methods for generating cache keys.
    /// </summary>
    /// <remarks>
    /// Cache keys are versioned to allow format changes without conflicts with existing entries.
    /// When the cache format changes, increment <see cref="Version"/> to invalidate old entries.
    /// Session and stream IDs are Base64-encoded before being embedded. Base64 output never contains
    /// the ':' separator, so IDs that themselves contain ':' cannot produce ambiguous keys.
    /// </remarks>
    internal static class CacheKeys
    {
        /// <summary>
        /// The current cache key version. Increment this when changing the cache format
        /// to ensure old entries are ignored.
        /// </summary>
        private const string Version = "v1";

        private const string Prefix = $"mcp:sse:{Version}:";

        /// <summary>Gets the key of the metadata entry for one stream of a session.</summary>
        public static string StreamMetadata(string sessionId, string streamId)
        {
            var sessionIdBase64 = Encode(sessionId);
            var streamIdBase64 = Encode(streamId);
            return $"{Prefix}meta:{sessionIdBase64}:{streamIdBase64}";
        }

        /// <summary>Gets the key of the index entry that lists all streams of a session.</summary>
        public static string SessionIndex(string sessionId)
        {
            var sessionIdBase64 = Encode(sessionId);
            return $"{Prefix}idx:{sessionIdBase64}";
        }

        /// <summary>Gets the key of a single stored event. The event ID is embedded as-is.</summary>
        public static string Event(string eventId)
            => $"{Prefix}event:{eventId}";

        /// <summary>Base64-encodes the UTF-8 bytes of <paramref name="value"/>.</summary>
        private static string Encode(string value)
            => Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(value));
    }

    /// <summary>
    /// Metadata about a stream stored in the cache.
    /// </summary>
    internal sealed class StreamMetadata
    {
        /// <summary>The stream's current mode, as last set by the writer.</summary>
        public SseEventStreamMode Mode { get; set; }

        /// <summary>Whether the writer has been disposed.</summary>
        public bool IsCompleted { get; set; }

        /// <summary>The highest sequence number the writer had assigned when the metadata was written.</summary>
        public long LastSequence { get; set; }
    }

    /// <summary>
    /// Serialized representation of an SSE event stored in the cache.
    /// </summary>
    internal sealed class StoredEvent
    {
        public string? EventType { get; set; }

        public string? EventId { get; set; }

        /// <summary>The reconnection interval in whole milliseconds, or <see langword="null"/> if none was set.</summary>
        public int? ReconnectionIntervalMs { get; set; }

        public JsonRpcMessage? Data { get; set; }
    }

    /// <summary>
    /// Index of all streams belonging to a session, stored in the cache.
    /// </summary>
    internal sealed class SessionIndex
    {
        public List<SessionStreamEntry> Streams { get; set; } = [];
    }

    /// <summary>
    /// Entry in the session index representing a single stream.
    /// </summary>
    internal sealed class SessionStreamEntry
    {
        public string StreamId { get; set; } = string.Empty;

        /// <summary>The highest sequence number recorded for the stream; used to enumerate its event keys on deletion.</summary>
        public long LastSequence { get; set; }
    }

    /// <summary>
    /// Writes events for a single stream. Each event gets an ID built from a sequence number and is
    /// stored in the cache. After every write, the stream's metadata and the session index are rewritten.
    /// </summary>
    private sealed partial class DistributedCacheEventStreamWriter : ISseEventStreamWriter
    {
        private readonly IDistributedCache _cache;
        private readonly string _sessionId;
        private readonly string _streamId;
        private readonly DistributedCacheEventStreamStoreOptions _options;
        private readonly ILogger _logger;

        // Changed by SetModeAsync and persisted with every metadata update.
        private SseEventStreamMode _mode;

        // Last sequence number handed out; Interlocked.Increment makes the first event's sequence 1.
        private long _sequence;
        private bool _disposed;

        public DistributedCacheEventStreamWriter(
            IDistributedCache cache,
            string sessionId,
            string streamId,
            SseEventStreamMode mode,
            DistributedCacheEventStreamStoreOptions options,
            ILogger logger)
        {
            _cache = cache;
            _sessionId = sessionId;
            _streamId = streamId;
            _mode = mode;
            _options = options;
            _logger = logger;
        }

        /// <summary>Changes the stream mode and persists it so that readers observe it on their next metadata refresh.</summary>
        public async ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default)
        {
            ThrowIfDisposed();

            LogStreamModeChanged(_sessionId, _streamId, mode);
            _mode = mode;
            await UpdateMetadataAsync(isCompleted: false, cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Assigns the next event ID to <paramref name="sseItem"/>, stores it, and updates the stream
        /// metadata. Items that already carry an event ID are returned unchanged and are not stored.
        /// </summary>
        public async ValueTask<SseItem<JsonRpcMessage?>> WriteEventAsync(SseItem<JsonRpcMessage?> sseItem, CancellationToken cancellationToken = default)
        {
            ThrowIfDisposed();

            if (sseItem.EventId is not null)
            {
                LogEventAlreadyHasId(_sessionId, _streamId, sseItem.EventId);
                return sseItem;
            }

            var sequence = Interlocked.Increment(ref _sequence);
            var eventId = DistributedCacheEventIdFormatter.Format(_sessionId, _streamId, sequence);
            var newItem = sseItem with { EventId = eventId };

            // Clamp before the double-to-int cast. Out-of-range conversions are unspecified in an
            // unchecked context and could otherwise persist a negative interval.
            int? reconnectionIntervalMs = newItem.ReconnectionInterval.HasValue
                ? (int)Math.Min(newItem.ReconnectionInterval.Value.TotalMilliseconds, int.MaxValue)
                : null;

            var storedEvent = new StoredEvent
            {
                EventType = newItem.EventType,
                EventId = eventId,
                ReconnectionIntervalMs = reconnectionIntervalMs,
                Data = newItem.Data,
            };

            var eventBytes = JsonSerializer.SerializeToUtf8Bytes(storedEvent, DistributedCacheEventStreamStoreJsonUtilities.StoredEventJsonTypeInfo);
            var eventKey = CacheKeys.Event(eventId);

            await _cache.SetAsync(eventKey, eventBytes, new DistributedCacheEntryOptions
            {
                SlidingExpiration = _options.EventSlidingExpiration,
                AbsoluteExpirationRelativeToNow = _options.EventAbsoluteExpiration,
            }, cancellationToken).ConfigureAwait(false);

            // The event is stored before LastSequence is advanced, so a reader that sees the new
            // LastSequence will find this event.
            await UpdateMetadataAsync(isCompleted: false, cancellationToken).ConfigureAwait(false);

            LogEventWritten(_sessionId, _streamId, eventId, sequence);

            return newItem;
        }

        /// <summary>Writes the stream metadata entry and then records the current last sequence in the session index.</summary>
        private async ValueTask UpdateMetadataAsync(bool isCompleted, CancellationToken cancellationToken)
        {
            var metadata = new StreamMetadata
            {
                Mode = _mode,
                IsCompleted = isCompleted,
                LastSequence = Interlocked.Read(ref _sequence),
            };

            var metadataBytes = JsonSerializer.SerializeToUtf8Bytes(metadata, DistributedCacheEventStreamStoreJsonUtilities.StreamMetadataJsonTypeInfo);
            var metadataKey = CacheKeys.StreamMetadata(_sessionId, _streamId);

            await _cache.SetAsync(metadataKey, metadataBytes, new DistributedCacheEntryOptions
            {
                SlidingExpiration = _options.MetadataSlidingExpiration,
                AbsoluteExpirationRelativeToNow = _options.MetadataAbsoluteExpiration,
            }, cancellationToken).ConfigureAwait(false);

            await UpdateSessionIndexAsync(metadata.LastSequence, cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Adds this stream to the session index, or updates its recorded last sequence if it is already listed.
        /// </summary>
        /// <remarks>
        /// This is an unsynchronized read-modify-write of an entry shared by every stream in the session.
        /// Concurrent updates from writers of other streams in the same session can overwrite each other.
        /// </remarks>
        private async ValueTask UpdateSessionIndexAsync(long lastSequence, CancellationToken cancellationToken)
        {
            var indexKey = CacheKeys.SessionIndex(_sessionId);
            var indexBytes = await _cache.GetAsync(indexKey, cancellationToken).ConfigureAwait(false);

            var index = indexBytes is not null
                ? JsonSerializer.Deserialize(indexBytes, DistributedCacheEventStreamStoreJsonUtilities.SessionIndexJsonTypeInfo) ?? new SessionIndex()
                : new SessionIndex();

            // A stored index such as {"Streams":null} can deserialize with a null list despite the
            // non-nullable annotation. Normalize it rather than throwing NullReferenceException on every write;
            // DeleteStreamsForSessionAsync guards against the same case.
            if (index.Streams is null)
            {
                index.Streams = [];
            }

            var existingEntry = index.Streams.Find(s => s.StreamId == _streamId);
            if (existingEntry is not null)
            {
                existingEntry.LastSequence = lastSequence;
            }
            else
            {
                index.Streams.Add(new SessionStreamEntry { StreamId = _streamId, LastSequence = lastSequence });
            }

            var updatedIndexBytes = JsonSerializer.SerializeToUtf8Bytes(index, DistributedCacheEventStreamStoreJsonUtilities.SessionIndexJsonTypeInfo);

            await _cache.SetAsync(indexKey, updatedIndexBytes, new DistributedCacheEntryOptions
            {
                SlidingExpiration = _options.MetadataSlidingExpiration,
                AbsoluteExpirationRelativeToNow = _options.MetadataAbsoluteExpiration,
            }, cancellationToken).ConfigureAwait(false);
        }

        private void ThrowIfDisposed()
        {
#if NET
            ObjectDisposedException.ThrowIf(_disposed, this);
#else
            if (_disposed)
            {
                throw new ObjectDisposedException(nameof(DistributedCacheEventStreamWriter));
            }
#endif
        }

        /// <summary>Marks the stream as completed in its metadata. Subsequent calls do nothing.</summary>
        public async ValueTask DisposeAsync()
        {
            if (_disposed)
            {
                return;
            }

            _disposed = true;

            // Readers stop polling once they have read every event and see IsCompleted.
            await UpdateMetadataAsync(isCompleted: true, CancellationToken.None).ConfigureAwait(false);

            LogStreamWriterDisposed(_sessionId, _streamId, Interlocked.Read(ref _sequence));
        }

        [LoggerMessage(Level = LogLevel.Debug, Message = "Stream mode changed for session '{SessionId}', stream '{StreamId}' to {Mode}.")]
        private partial void LogStreamModeChanged(string sessionId, string streamId, SseEventStreamMode mode);

        [LoggerMessage(Level = LogLevel.Trace, Message = "Event already has ID '{EventId}' for session '{SessionId}', stream '{StreamId}'. Skipping ID generation.")]
        private partial void LogEventAlreadyHasId(string sessionId, string streamId, string eventId);

        [LoggerMessage(Level = LogLevel.Debug, Message = "Event written to session '{SessionId}', stream '{StreamId}' with ID '{EventId}' (sequence {Sequence}).")]
        private partial void LogEventWritten(string sessionId, string streamId, string eventId, long sequence);

        [LoggerMessage(Level = LogLevel.Information, Message = "Stream writer disposed for session '{SessionId}', stream '{StreamId}'. Total events written: {TotalEvents}.")]
        private partial void LogStreamWriterDisposed(string sessionId, string streamId, long totalEvents);
    }

    /// <summary>
    /// Replays a stream's events from the cache, starting at a given sequence number. The stream's
    /// metadata is polled for new events until the stream is completed or switches to polling mode.
    /// </summary>
    private sealed partial class DistributedCacheEventStreamReader : ISseEventStreamReader
    {
        private readonly IDistributedCache _cache;
        private readonly long _startSequence;
        private readonly StreamMetadata _initialMetadata;
        private readonly DistributedCacheEventStreamStoreOptions _options;
        private readonly ILogger _logger;

        public DistributedCacheEventStreamReader(
            IDistributedCache cache,
            string sessionId,
            string streamId,
            long startSequence,
            StreamMetadata initialMetadata,
            DistributedCacheEventStreamStoreOptions options,
            ILogger logger)
        {
            _cache = cache;
            SessionId = sessionId;
            StreamId = streamId;
            _startSequence = startSequence;
            _initialMetadata = initialMetadata;
            _options = options;
            _logger = logger;
        }

        public string SessionId { get; }

        public string StreamId { get; }

        /// <summary>
        /// Yields stored events from the start sequence onward.
        /// </summary>
        /// <remarks>
        /// In polling mode, enumeration ends after the currently available events. Otherwise the
        /// metadata is re-read every <c>StreamReaderPollingInterval</c> until the stream is completed.
        /// An <see cref="McpException"/> is thrown if an expected event or the metadata is missing from
        /// the cache, or if the metadata cannot be deserialized.
        /// </remarks>
        public async IAsyncEnumerable<SseItem<JsonRpcMessage?>> ReadEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var currentSequence = _startSequence;

            // The first pass uses the metadata already fetched when the reader was created.
            var lastSequence = _initialMetadata.LastSequence;
            var isCompleted = _initialMetadata.IsCompleted;
            var mode = _initialMetadata.Mode;

            LogReadingEventsStarted(SessionId, StreamId, _startSequence, lastSequence);

            while (!cancellationToken.IsCancellationRequested)
            {
                // Read every event from currentSequence through lastSequence (inclusive).
                for (; currentSequence <= lastSequence; currentSequence++)
                {
                    cancellationToken.ThrowIfCancellationRequested();

                    var eventId = DistributedCacheEventIdFormatter.Format(SessionId, StreamId, currentSequence);
                    var eventKey = CacheKeys.Event(eventId);
                    var eventBytes = await _cache.GetAsync(eventKey, cancellationToken).ConfigureAwait(false)
                        ?? throw new McpException($"SSE event with ID '{eventId}' was not found in the cache. The event may have expired.");

                    var storedEvent = JsonSerializer.Deserialize(eventBytes, DistributedCacheEventStreamStoreJsonUtilities.StoredEventJsonTypeInfo);

                    // An entry that deserializes to null (e.g. the JSON literal null) is skipped.
                    if (storedEvent is not null)
                    {
                        LogEventRead(SessionId, StreamId, eventId, currentSequence);

                        yield return new SseItem<JsonRpcMessage?>(storedEvent.Data, storedEvent.EventType)
                        {
                            EventId = storedEvent.EventId,
                            ReconnectionInterval = storedEvent.ReconnectionIntervalMs.HasValue
                                ? TimeSpan.FromMilliseconds(storedEvent.ReconnectionIntervalMs.Value)
                                : null,
                        };
                    }
                }

                if (mode == SseEventStreamMode.Polling)
                {
                    LogReadingEventsCompletedPolling(SessionId, StreamId, currentSequence - 1);
                    yield break;
                }

                if (isCompleted)
                {
                    LogReadingEventsCompletedStreamEnded(SessionId, StreamId, currentSequence - 1);
                    yield break;
                }

                LogWaitingForNewEvents(SessionId, StreamId, _options.StreamReaderPollingInterval);
                await Task.Delay(_options.StreamReaderPollingInterval, cancellationToken).ConfigureAwait(false);

                // Refresh the latest sequence, completion status, and mode written by the writer.
                var metadataKey = CacheKeys.StreamMetadata(SessionId, StreamId);
                var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false)
                    ?? throw new McpException($"Stream metadata for session '{SessionId}' and stream '{StreamId}' was not found in the cache. The metadata may have expired.");

                var currentMetadata = JsonSerializer.Deserialize(metadataBytes, DistributedCacheEventStreamStoreJsonUtilities.StreamMetadataJsonTypeInfo)
                    ?? throw new McpException($"Stream metadata for session '{SessionId}' and stream '{StreamId}' could not be deserialized.");

                lastSequence = currentMetadata.LastSequence;
                isCompleted = currentMetadata.IsCompleted;
                mode = currentMetadata.Mode;
            }
        }

        [LoggerMessage(Level = LogLevel.Debug, Message = "Starting to read events for session '{SessionId}', stream '{StreamId}' starting at sequence {StartSequence}. Last available sequence: {LastSequence}.")]
        private partial void LogReadingEventsStarted(string sessionId, string streamId, long startSequence, long lastSequence);

        [LoggerMessage(Level = LogLevel.Trace, Message = "Event read from session '{SessionId}', stream '{StreamId}' with ID '{EventId}' (sequence {Sequence}).")]
        private partial void LogEventRead(string sessionId, string streamId, string eventId, long sequence);

        [LoggerMessage(Level = LogLevel.Debug, Message = "Reading events completed for session '{SessionId}', stream '{StreamId}' in polling mode. Last sequence read: {LastSequence}.")]
        private partial void LogReadingEventsCompletedPolling(string sessionId, string streamId, long lastSequence);

        [LoggerMessage(Level = LogLevel.Debug, Message = "Reading events completed for session '{SessionId}', stream '{StreamId}' as stream has ended. Last sequence read: {LastSequence}.")]
        private partial void LogReadingEventsCompletedStreamEnded(string sessionId, string streamId, long lastSequence);

        [LoggerMessage(Level = LogLevel.Trace, Message = "Waiting for new events on session '{SessionId}', stream '{StreamId}'. Polling interval: {PollingInterval}.")]
        private partial void LogWaitingForNewEvents(string sessionId, string streamId, TimeSpan pollingInterval);
    }

    [LoggerMessage(Level = LogLevel.Information, Message = "Stream created for session '{SessionId}', stream '{StreamId}' with mode {Mode}.")]
    private partial void LogStreamCreated(string sessionId, string streamId, SseEventStreamMode mode);

    [LoggerMessage(Level = LogLevel.Debug, Message = "Stream reader created for session '{SessionId}', stream '{StreamId}' starting at sequence {StartSequence}. Last available sequence: {LastSequence}.")]
    private partial void LogStreamReaderCreated(string sessionId, string streamId, long startSequence, long lastSequence);

    [LoggerMessage(Level = LogLevel.Warning, Message = "Failed to parse event ID '{EventId}'. Unable to create stream reader.")]
    private partial void LogEventIdParsingFailed(string eventId);

    [LoggerMessage(Level = LogLevel.Debug, Message = "Stream metadata not found for session '{SessionId}', stream '{StreamId}'.")]
    private partial void LogStreamMetadataNotFound(string sessionId, string streamId);

    [LoggerMessage(Level = LogLevel.Warning, Message = "Failed to deserialize stream metadata for session '{SessionId}', stream '{StreamId}'.")]
    private partial void LogStreamMetadataDeserializationFailed(string sessionId, string streamId);

    [LoggerMessage(Level = LogLevel.Debug, Message = "Session index not found for session '{SessionId}'. No streams to delete.")]
    private partial void LogSessionIndexNotFound(string sessionId);

    [LoggerMessage(Level = LogLevel.Warning, Message = "Failed to deserialize session index for session '{SessionId}'.")]
    private partial void LogSessionIndexDeserializationFailed(string sessionId);

    [LoggerMessage(Level = LogLevel.Information, Message = "Deleted {StreamCount} stream(s) for session '{SessionId}'.")]
    private partial void LogStreamsDeletedForSession(string sessionId, int streamCount);
}