Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/ModelContextProtocol.Core/Server/ISseEventStreamStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,16 @@ public interface ISseEventStreamStore
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A reader for the event stream, or <c>null</c> if no matching stream is found.</returns>
ValueTask<ISseEventStreamReader?> GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default);

/// <summary>
/// Deletes all stored event streams and their associated events for the specified session.
/// </summary>
/// <param name="sessionId">The ID of the session whose streams should be deleted.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task representing the asynchronous operation.</returns>
/// <remarks>
/// This method is a best-effort operation. If the session does not exist or has no stored streams,
/// the method completes without error.
/// </remarks>
ValueTask DeleteStreamsForSessionAsync(string sessionId, CancellationToken cancellationToken = default);
}
104 changes: 104 additions & 0 deletions src/ModelContextProtocol/Server/DistributedCacheEventStreamStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,48 @@ public ValueTask<ISseEventStreamWriter> CreateStreamAsync(SseEventStreamOptions
return new ValueTask<ISseEventStreamWriter>(writer);
}

/// <inheritdoc />
public async ValueTask DeleteStreamsForSessionAsync(string sessionId, CancellationToken cancellationToken = default)
{
Throw.IfNull(sessionId);

// Read the session index to find all streams for this session
var indexKey = CacheKeys.SessionIndex(sessionId);
var indexBytes = await _cache.GetAsync(indexKey, cancellationToken).ConfigureAwait(false);
if (indexBytes is null)
{
LogSessionIndexNotFound(sessionId);
return;
}

var index = JsonSerializer.Deserialize(indexBytes, DistributedCacheEventStreamStoreJsonUtilities.SessionIndexJsonTypeInfo);
if (index?.Streams is null)
{
LogSessionIndexDeserializationFailed(sessionId);
return;
}

// Delete all events and metadata for each stream
foreach (var stream in index.Streams)
{
// Delete all event keys for this stream
for (long seq = 1; seq <= stream.LastSequence; seq++)
{
var eventId = DistributedCacheEventIdFormatter.Format(sessionId, stream.StreamId, seq);
var eventKey = CacheKeys.Event(eventId);
await _cache.RemoveAsync(eventKey, cancellationToken).ConfigureAwait(false);
}

// Delete the stream metadata
var metadataKey = CacheKeys.StreamMetadata(sessionId, stream.StreamId);
await _cache.RemoveAsync(metadataKey, cancellationToken).ConfigureAwait(false);
}

// Delete the session index itself
await _cache.RemoveAsync(indexKey, cancellationToken).ConfigureAwait(false);
LogStreamsDeletedForSession(sessionId, index.Streams.Count);
}

/// <inheritdoc />
public async ValueTask<ISseEventStreamReader?> GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -107,6 +149,12 @@ public static string StreamMetadata(string sessionId, string streamId)
return $"{Prefix}meta:{sessionIdBase64}:{streamIdBase64}";
}

public static string SessionIndex(string sessionId)
{
var sessionIdBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(sessionId));
return $"{Prefix}idx:{sessionIdBase64}";
}

public static string Event(string eventId)
=> $"{Prefix}event:{eventId}";
}
Expand All @@ -132,6 +180,23 @@ internal sealed class StoredEvent
public JsonRpcMessage? Data { get; set; }
}

/// <summary>
/// Index of all streams belonging to a session, stored in the cache.
/// </summary>
internal sealed class SessionIndex
{
public List<SessionStreamEntry> Streams { get; set; } = [];
}

/// <summary>
/// Entry in the session index representing a single stream.
/// </summary>
internal sealed class SessionStreamEntry
{
public string StreamId { get; set; } = string.Empty;
public long LastSequence { get; set; }
}

private sealed partial class DistributedCacheEventStreamWriter : ISseEventStreamWriter
{
private readonly IDistributedCache _cache;
Expand Down Expand Up @@ -228,6 +293,36 @@ private async ValueTask UpdateMetadataAsync(bool isCompleted, CancellationToken
SlidingExpiration = _options.MetadataSlidingExpiration,
AbsoluteExpirationRelativeToNow = _options.MetadataAbsoluteExpiration,
}, cancellationToken).ConfigureAwait(false);

// Update the session index with this stream's latest sequence
await UpdateSessionIndexAsync(metadata.LastSequence, cancellationToken).ConfigureAwait(false);
}

private async ValueTask UpdateSessionIndexAsync(long lastSequence, CancellationToken cancellationToken)
{
var indexKey = CacheKeys.SessionIndex(_sessionId);
var indexBytes = await _cache.GetAsync(indexKey, cancellationToken).ConfigureAwait(false);

var index = indexBytes is not null
? JsonSerializer.Deserialize(indexBytes, DistributedCacheEventStreamStoreJsonUtilities.SessionIndexJsonTypeInfo) ?? new SessionIndex()
: new SessionIndex();

var existingEntry = index.Streams.Find(s => s.StreamId == _streamId);
if (existingEntry is not null)
{
existingEntry.LastSequence = lastSequence;
}
else
{
index.Streams.Add(new SessionStreamEntry { StreamId = _streamId, LastSequence = lastSequence });
}

var updatedIndexBytes = JsonSerializer.SerializeToUtf8Bytes(index, DistributedCacheEventStreamStoreJsonUtilities.SessionIndexJsonTypeInfo);
await _cache.SetAsync(indexKey, updatedIndexBytes, new DistributedCacheEntryOptions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how we can easily prevent concurrent calls to UpdateSessionIndexAsync for concurrent streams from clobbering the SessionIndex causing streams to be lost. This might be tractable if we knew a given session could only be owned by this process by using a lock, but with ISessionMigrationHandler we have to consider we might have the same session concurrently handled by different machines.

{
SlidingExpiration = _options.MetadataSlidingExpiration,
AbsoluteExpirationRelativeToNow = _options.MetadataAbsoluteExpiration,
}, cancellationToken).ConfigureAwait(false);
}

private void ThrowIfDisposed()
Expand Down Expand Up @@ -398,4 +493,13 @@ public DistributedCacheEventStreamReader(

[LoggerMessage(Level = LogLevel.Warning, Message = "Failed to deserialize stream metadata for session '{SessionId}', stream '{StreamId}'.")]
private partial void LogStreamMetadataDeserializationFailed(string sessionId, string streamId);

[LoggerMessage(Level = LogLevel.Debug, Message = "Session index not found for session '{SessionId}'. No streams to delete.")]
private partial void LogSessionIndexNotFound(string sessionId);

[LoggerMessage(Level = LogLevel.Warning, Message = "Failed to deserialize session index for session '{SessionId}'.")]
private partial void LogSessionIndexDeserializationFailed(string sessionId);

[LoggerMessage(Level = LogLevel.Information, Message = "Deleted {StreamCount} stream(s) for session '{SessionId}'.")]
private partial void LogStreamsDeletedForSession(string sessionId, int streamCount);
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ internal static partial class DistributedCacheEventStreamStoreJsonUtilities
public static JsonTypeInfo<DistributedCacheEventStreamStore.StoredEvent> StoredEventJsonTypeInfo { get; } =
(JsonTypeInfo<DistributedCacheEventStreamStore.StoredEvent>)DefaultOptions.GetTypeInfo(typeof(DistributedCacheEventStreamStore.StoredEvent));

/// <summary>
/// Gets the <see cref="JsonTypeInfo{T}"/> for <see cref="DistributedCacheEventStreamStore.SessionIndex"/>.
/// </summary>
public static JsonTypeInfo<DistributedCacheEventStreamStore.SessionIndex> SessionIndexJsonTypeInfo { get; } =
(JsonTypeInfo<DistributedCacheEventStreamStore.SessionIndex>)DefaultOptions.GetTypeInfo(typeof(DistributedCacheEventStreamStore.SessionIndex));

private static JsonSerializerOptions CreateDefaultOptions()
{
// Copy the configuration from McpJsonUtilities.DefaultOptions.
Expand All @@ -56,5 +62,6 @@ private static JsonSerializerOptions CreateDefaultOptions()
GenerationMode = JsonSourceGenerationMode.Metadata)]
[JsonSerializable(typeof(DistributedCacheEventStreamStore.StreamMetadata))]
[JsonSerializable(typeof(DistributedCacheEventStreamStore.StoredEvent))]
[JsonSerializable(typeof(DistributedCacheEventStreamStore.SessionIndex))]
private sealed partial class JsonContext : JsonSerializerContext;
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ public ValueTask<ISseEventStreamWriter> CreateStreamAsync(SseEventStreamOptions
public ValueTask<ISseEventStreamReader?> GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default)
=> throw new NotSupportedException("This test store does not support reading streams.");

public ValueTask DeleteStreamsForSessionAsync(string sessionId, CancellationToken cancellationToken = default)
=> default;

private sealed class BlockingEventStreamWriter : ISseEventStreamWriter
{
private readonly BlockingEventStreamStore _store;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,28 @@ private void TrackEvent(string eventId, StreamState stream, long sequence, TimeS
Interlocked.Increment(ref _storeEventCallCount);
}

/// <inheritdoc />
public ValueTask DeleteStreamsForSessionAsync(string sessionId, CancellationToken cancellationToken = default)
{
// Find all streams belonging to this session
var keysToRemove = _streams.Keys.Where(k => k.StartsWith($"{sessionId}:", StringComparison.Ordinal)).ToList();

foreach (var key in keysToRemove)
{
if (_streams.TryRemove(key, out var state))
{
// Remove all events belonging to this stream from the event lookup
var eventKeysToRemove = _eventLookup.Where(kvp => kvp.Value.Stream == state).Select(kvp => kvp.Key).ToList();
foreach (var eventKey in eventKeysToRemove)
{
_eventLookup.TryRemove(eventKey, out _);
}
}
}

return default;
}

private static string GetStreamKey(string sessionId, string streamId) => $"{sessionId}:{streamId}";

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1668,6 +1668,133 @@ public void EventIdFormatter_TryParse_ReturnsFalse_ForNonNumericSequence()
Assert.False(parsed);
}

[Fact]
public async Task DeleteStreamsForSessionAsync_ThrowsArgumentNullException_WhenSessionIdIsNull()
{
// Arrange
var cache = CreateMemoryCache();
var store = new DistributedCacheEventStreamStore(cache);

// Act & Assert
await Assert.ThrowsAsync<ArgumentNullException>("sessionId",
async () => await store.DeleteStreamsForSessionAsync(null!, CancellationToken));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outside of these unit tests, I don't see anything that calls DeleteStreamsForSessionAsync. It'd be easy enough to do so in StreamableHttpServerTransport.DisposeAsync, but that's going to run into problems once we add ISessionMigrationHandler from #1270. It also isn't helpful in stateless mode if/when we add support for that.

}

[Fact]
public async Task DeleteStreamsForSessionAsync_NoOp_WhenSessionDoesNotExist()
{
// Arrange
var cache = CreateMemoryCache();
var store = new DistributedCacheEventStreamStore(cache);

// Act - should not throw
await store.DeleteStreamsForSessionAsync("nonexistent-session", CancellationToken);
}

[Fact]
public async Task DeleteStreamsForSessionAsync_RemovesAllStreamsAndEvents()
{
// Arrange
var cache = CreateMemoryCache();
var store = new DistributedCacheEventStreamStore(cache);

var writer = await store.CreateStreamAsync(new SseEventStreamOptions
{
SessionId = "session-1",
StreamId = "stream-1",
Mode = SseEventStreamMode.Streaming
}, CancellationToken);

var item1 = await writer.WriteEventAsync(new SseItem<JsonRpcMessage?>(null), CancellationToken);
var item2 = await writer.WriteEventAsync(new SseItem<JsonRpcMessage?>(null), CancellationToken);
await writer.DisposeAsync();

// Verify events are readable before deletion
var reader = await store.GetStreamReaderAsync(item1.EventId!, CancellationToken);
Assert.NotNull(reader);

// Act
await store.DeleteStreamsForSessionAsync("session-1", CancellationToken);

// Assert - events should no longer be readable
var readerAfterDelete = await store.GetStreamReaderAsync(item1.EventId!, CancellationToken);
Assert.Null(readerAfterDelete);

var readerAfterDelete2 = await store.GetStreamReaderAsync(item2.EventId!, CancellationToken);
Assert.Null(readerAfterDelete2);
}

[Fact]
public async Task DeleteStreamsForSessionAsync_DoesNotAffectOtherSessions()
{
// Arrange
var cache = CreateMemoryCache();
var store = new DistributedCacheEventStreamStore(cache);

// Create streams for two sessions
var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions
{
SessionId = "session-1",
StreamId = "stream-1",
Mode = SseEventStreamMode.Streaming
}, CancellationToken);
var item1 = await writer1.WriteEventAsync(new SseItem<JsonRpcMessage?>(null), CancellationToken);
await writer1.DisposeAsync();

var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions
{
SessionId = "session-2",
StreamId = "stream-1",
Mode = SseEventStreamMode.Streaming
}, CancellationToken);
var item2 = await writer2.WriteEventAsync(new SseItem<JsonRpcMessage?>(null), CancellationToken);
await writer2.DisposeAsync();

// Act - delete only session-1
await store.DeleteStreamsForSessionAsync("session-1", CancellationToken);

// Assert - session-1 events should be gone
var reader1 = await store.GetStreamReaderAsync(item1.EventId!, CancellationToken);
Assert.Null(reader1);

// Assert - session-2 events should still be readable
var reader2 = await store.GetStreamReaderAsync(item2.EventId!, CancellationToken);
Assert.NotNull(reader2);
}

[Fact]
public async Task DeleteStreamsForSessionAsync_RemovesMultipleStreamsInSameSession()
{
// Arrange
var cache = CreateMemoryCache();
var store = new DistributedCacheEventStreamStore(cache);

var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions
{
SessionId = "session-1",
StreamId = "stream-a",
Mode = SseEventStreamMode.Streaming
}, CancellationToken);
var itemA = await writer1.WriteEventAsync(new SseItem<JsonRpcMessage?>(null), CancellationToken);
await writer1.DisposeAsync();

var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions
{
SessionId = "session-1",
StreamId = "stream-b",
Mode = SseEventStreamMode.Streaming
}, CancellationToken);
var itemB = await writer2.WriteEventAsync(new SseItem<JsonRpcMessage?>(null), CancellationToken);
await writer2.DisposeAsync();

// Act
await store.DeleteStreamsForSessionAsync("session-1", CancellationToken);

// Assert - both streams should be gone
Assert.Null(await store.GetStreamReaderAsync(itemA.EventId!, CancellationToken));
Assert.Null(await store.GetStreamReaderAsync(itemB.EventId!, CancellationToken));
}

/// <summary>
/// A distributed cache that tracks all operations for verification in tests.
/// Supports tracking Set calls, counting metadata reads, and simulating metadata/event expiration.
Expand Down
Loading