diff --git a/.claude/settings.local.json b/.claude/settings.local.json
new file mode 100644
index 0000000..a736846
--- /dev/null
+++ b/.claude/settings.local.json
@@ -0,0 +1,12 @@
+{
+ "permissions": {
+ "allow": [
+ "Bash(tree:*)",
+ "Bash(dotnet build:*)",
+ "Bash(dotnet test:*)",
+ "Bash(git checkout:*)",
+ "Bash(git push:*)",
+ "Bash(git remote add:*)"
+ ]
+ }
+}
diff --git a/.editorconfig b/.editorconfig
index 2087cec..2024237 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -453,6 +453,20 @@ dotnet_diagnostic.IDE0200.severity = warning
dotnet_style_allow_multiple_blank_lines_experimental = false
dotnet_diagnostic.IDE2000.severity = warning
+# SignalR Hub methods are called by string name from clients - suppress async naming rule
+[**/Hubs/*.cs]
+dotnet_naming_rule.async_methods_should_end_with_async.severity = none
+
+# ASP.NET Controller actions don't typically follow async naming conventions
+[**/Controllers/*.cs]
+dotnet_naming_rule.async_methods_should_end_with_async.severity = none
+
+# Orleans grains with [PersistentState] attributes cannot use primary constructors
+# because attributes on constructor parameters cannot be applied to primary constructor parameters
+[**/SignalRConnectionHeartbeatGrain.cs]
+csharp_style_prefer_primary_constructors = false:none
+dotnet_diagnostic.IDE0290.severity = none
+
# Verify settings for test files
[*.{received,verified}.{txt,xml,json}]
charset = utf-8-bom
diff --git a/.gitignore b/.gitignore
index ea566a0..8b3d25e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -646,4 +646,7 @@ MigrationBackup/
# Ionide (cross platform F# VS Code tools) working folder
.ionide/
-# End of https://www.toptal.com/developers/gitignore/api/intellij,intellij+all,macos,linux,windows,visualstudio,visualstudiocode,rider
\ No newline at end of file
+# End of https://www.toptal.com/developers/gitignore/api/intellij,intellij+all,macos,linux,windows,visualstudio,visualstudiocode,rider
+
+# Claude Code temporary directories
+tmpclaude-*/
\ No newline at end of file
diff --git a/ManagedCode.Orleans.SignalR.Client/Properties/launchSettings.json b/ManagedCode.Orleans.SignalR.Client/Properties/launchSettings.json
new file mode 100644
index 0000000..015eafc
--- /dev/null
+++ b/ManagedCode.Orleans.SignalR.Client/Properties/launchSettings.json
@@ -0,0 +1,12 @@
+{
+ "profiles": {
+ "ManagedCode.Orleans.SignalR.Client": {
+ "commandName": "Project",
+ "launchBrowser": true,
+ "environmentVariables": {
+ "ASPNETCORE_ENVIRONMENT": "Development"
+ },
+ "applicationUrl": "https://localhost:56460;http://localhost:56463"
+ }
+ }
+}
\ No newline at end of file
diff --git a/ManagedCode.Orleans.SignalR.Core/Config/OrleansSignalROptions.cs b/ManagedCode.Orleans.SignalR.Core/Config/OrleansSignalROptions.cs
index 919798c..09ee66d 100644
--- a/ManagedCode.Orleans.SignalR.Core/Config/OrleansSignalROptions.cs
+++ b/ManagedCode.Orleans.SignalR.Core/Config/OrleansSignalROptions.cs
@@ -54,4 +54,102 @@ public class OrleansSignalROptions
/// Used as a hint when determining how many partitions to allocate dynamically.
///
public int GroupsPerPartitionHint { get; set; } = 1_000;
+
+    /// <summary>
+    /// Maximum number of messages to queue per user when they are disconnected.
+    /// Oldest messages are dropped when the limit is exceeded.
+    /// The default value is 100.
+    /// </summary>
+    public int MaxQueuedMessagesPerUser { get; set; } = 100;
+
+    /// <summary>
+    /// Number of consecutive failures before an observer is considered dead and removed.
+    /// Set to 0 to disable failure tracking.
+    /// The default value is 3.
+    /// </summary>
+    public int ObserverFailureThreshold { get; set; } = 3;
+
+    /// <summary>
+    /// Time window for counting observer failures. Failures older than this are forgotten.
+    /// The default value is 30 seconds.
+    /// </summary>
+    public TimeSpan ObserverFailureWindow { get; set; } = TimeSpan.FromSeconds(30);
+
+    /// <summary>
+    /// Enables circuit breaker pattern for observers to prevent cascade failures.
+    /// When enabled, failing observers are temporarily blocked from receiving messages.
+    /// The default value is true.
+    /// </summary>
+    public bool EnableCircuitBreaker { get; set; } = true;
+
+    /// <summary>
+    /// Duration to keep the circuit open (blocking requests) after failure threshold is reached.
+    /// After this duration, the circuit transitions to half-open state for testing.
+    /// The default value is 30 seconds.
+    /// </summary>
+    public TimeSpan CircuitBreakerOpenDuration { get; set; } = TimeSpan.FromSeconds(30);
+
+    /// <summary>
+    /// Interval between test requests when circuit is in half-open state.
+    /// The default value is 5 seconds.
+    /// </summary>
+    public TimeSpan CircuitBreakerHalfOpenTestInterval { get; set; } = TimeSpan.FromSeconds(5);
+
+    /// <summary>
+    /// Grace period before an observer is hard-removed after a failure.
+    /// During this period, messages are buffered and replayed if the observer recovers.
+    /// This handles timing edge cases like GC pauses, network latency, or silo overload.
+    /// Set to TimeSpan.Zero to disable grace period buffering.
+    /// The default value is 10 seconds.
+    /// </summary>
+    public TimeSpan ObserverGracePeriod { get; set; } = TimeSpan.FromSeconds(10);
+
+    /// <summary>
+    /// Maximum number of messages to buffer per observer during the grace period.
+    /// Oldest messages are dropped when the limit is exceeded.
+    /// The default value is 50.
+    /// </summary>
+    public int MaxBufferedMessagesPerObserver { get; set; } = 50;
+
+    /// <summary>
+    /// Maximum number of connections allowed per partition grain.
+    /// New connections are rejected when the limit is exceeded.
+    /// Set to 0 to disable connection limits (not recommended for production).
+    /// The default value is 100,000.
+    /// </summary>
+    public int MaxConnectionsPerPartition { get; set; } = 100_000;
+
+    /// <summary>
+    /// Maximum number of groups per partition grain.
+    /// New groups are rejected when the limit is exceeded.
+    /// Set to 0 to disable group limits.
+    /// The default value is 50,000.
+    /// </summary>
+    public int MaxGroupsPerPartition { get; set; } = 50_000;
+
+    /// <summary>
+    /// Timeout for slow client message delivery.
+    /// Connections that cannot receive messages within this time may be terminated.
+    /// The default value is 10 seconds.
+    /// </summary>
+    public TimeSpan SlowClientTimeout { get; set; } = TimeSpan.FromSeconds(10);
+
+    /// <summary>
+    /// Enables backpressure handling for slow clients.
+    /// When enabled, messages to slow clients are dropped or the connection is terminated.
+    /// The default value is true.
+    /// </summary>
+    public bool EnableSlowClientHandling { get; set; } = true;
+
+    /// <summary>
+    /// Maximum number of pending messages allowed per connection before backpressure is applied.
+    /// The default value is 1000.
+    /// </summary>
+    public int MaxPendingMessagesPerConnection { get; set; } = 1000;
+
+    /// <summary>
+    /// Enables metrics collection for monitoring and diagnostics.
+    /// The default value is true.
+    /// </summary>
+    public bool EnableMetrics { get; set; } = true;
}
diff --git a/ManagedCode.Orleans.SignalR.Core/Diagnostics/SignalRMetrics.cs b/ManagedCode.Orleans.SignalR.Core/Diagnostics/SignalRMetrics.cs
new file mode 100644
index 0000000..0501848
--- /dev/null
+++ b/ManagedCode.Orleans.SignalR.Core/Diagnostics/SignalRMetrics.cs
@@ -0,0 +1,365 @@
+using System;
+using System.Diagnostics;
+using System.Diagnostics.Metrics;
+using System.Threading;
+
+namespace ManagedCode.Orleans.SignalR.Core.Diagnostics;
+
+/// <summary>
+/// Provides metrics for monitoring Orleans SignalR backplane performance.
+/// Uses System.Diagnostics.Metrics for .NET 10 compatibility with OpenTelemetry.
+/// </summary>
+public sealed class SignalRMetrics : IDisposable
+{
+    /// <summary>
+    /// The meter name used for all Orleans SignalR metrics.
+    /// </summary>
+    public const string MeterName = "ManagedCode.Orleans.SignalR";
+
+    private readonly Meter _meter;
+
+    // Connection metrics
+    private readonly Counter<long> _connectionsTotal;
+    private readonly Counter<long> _disconnectionsTotal;
+    private readonly UpDownCounter<long> _activeConnections;
+
+    // Message metrics
+    private readonly Counter<long> _messagesSentTotal;
+    private readonly Counter<long> _messagesReceivedTotal;
+    private readonly Counter<long> _messagesDroppedTotal;
+    private readonly Counter<long> _messagesBufferedTotal;
+    private readonly Histogram<double> _messageDeliveryDuration;
+
+    // Observer health metrics
+    private readonly Counter<long> _observerFailuresTotal;
+    private readonly Counter<long> _observersMarkedDeadTotal;
+    private readonly Counter<long> _circuitBreakersOpenedTotal;
+    private readonly Counter<long> _circuitBreakersClosedTotal;
+    private readonly UpDownCounter<long> _observersInGracePeriod;
+
+    // Partition metrics (kept in fields so the gauges stay referenced by this instance)
+    private readonly ObservableGauge<int> _connectionPartitionCount;
+    private readonly ObservableGauge<int> _groupPartitionCount;
+
+    // Backing values read by the observable gauge callbacks
+    private int _currentConnectionPartitionCount;
+    private int _currentGroupPartitionCount;
+
+    /// <summary>
+    /// Gets the singleton instance of SignalRMetrics.
+    /// </summary>
+    public static SignalRMetrics Instance { get; } = new();
+
+    private SignalRMetrics()
+    {
+        _meter = new Meter(MeterName, "1.0.0");
+
+        // Connection metrics
+        _connectionsTotal = _meter.CreateCounter<long>(
+            "signalr.connections.total",
+            unit: "{connection}",
+            description: "Total number of SignalR connections established");
+
+        _disconnectionsTotal = _meter.CreateCounter<long>(
+            "signalr.disconnections.total",
+            unit: "{connection}",
+            description: "Total number of SignalR connections closed");
+
+        _activeConnections = _meter.CreateUpDownCounter<long>(
+            "signalr.connections.active",
+            unit: "{connection}",
+            description: "Number of currently active SignalR connections");
+
+        // Message metrics
+        _messagesSentTotal = _meter.CreateCounter<long>(
+            "signalr.messages.sent.total",
+            unit: "{message}",
+            description: "Total number of messages sent to clients");
+
+        _messagesReceivedTotal = _meter.CreateCounter<long>(
+            "signalr.messages.received.total",
+            unit: "{message}",
+            description: "Total number of messages received from clients");
+
+        _messagesDroppedTotal = _meter.CreateCounter<long>(
+            "signalr.messages.dropped.total",
+            unit: "{message}",
+            description: "Total number of messages dropped due to errors or backpressure");
+
+        _messagesBufferedTotal = _meter.CreateCounter<long>(
+            "signalr.messages.buffered.total",
+            unit: "{message}",
+            description: "Total number of messages buffered during grace periods");
+
+        _messageDeliveryDuration = _meter.CreateHistogram<double>(
+            "signalr.message.delivery.duration",
+            unit: "ms",
+            description: "Time taken to deliver a message to clients");
+
+        // Observer health metrics
+        _observerFailuresTotal = _meter.CreateCounter<long>(
+            "signalr.observer.failures.total",
+            unit: "{failure}",
+            description: "Total number of observer delivery failures");
+
+        _observersMarkedDeadTotal = _meter.CreateCounter<long>(
+            "signalr.observer.dead.total",
+            unit: "{observer}",
+            description: "Total number of observers marked as dead");
+
+        _circuitBreakersOpenedTotal = _meter.CreateCounter<long>(
+            "signalr.circuit_breaker.opened.total",
+            unit: "{circuit}",
+            description: "Total number of times circuit breakers were opened");
+
+        _circuitBreakersClosedTotal = _meter.CreateCounter<long>(
+            "signalr.circuit_breaker.closed.total",
+            unit: "{circuit}",
+            description: "Total number of times circuit breakers were closed");
+
+        _observersInGracePeriod = _meter.CreateUpDownCounter<long>(
+            "signalr.observer.grace_period",
+            unit: "{observer}",
+            description: "Number of observers currently in grace period");
+
+        // Partition metrics: the callbacks read the latest value stored by the Set*PartitionCount methods
+        _connectionPartitionCount = _meter.CreateObservableGauge(
+            "signalr.partitions.connection.count",
+            () => Volatile.Read(ref _currentConnectionPartitionCount),
+            unit: "{partition}",
+            description: "Current number of connection partitions");
+
+        _groupPartitionCount = _meter.CreateObservableGauge(
+            "signalr.partitions.group.count",
+            () => Volatile.Read(ref _currentGroupPartitionCount),
+            unit: "{partition}",
+            description: "Current number of group partitions");
+    }
+
+    /// <summary>
+    /// Records a new connection: increments both the total and the active connection counters.
+    /// </summary>
+    public void RecordConnectionEstablished(string hubName)
+    {
+        var tags = new TagList { { "hub", hubName } };
+        _connectionsTotal.Add(1, tags);
+        _activeConnections.Add(1, tags);
+    }
+
+    /// <summary>
+    /// Records a closed connection: increments the disconnection counter and decrements the active one.
+    /// </summary>
+    public void RecordConnectionClosed(string hubName)
+    {
+        var tags = new TagList { { "hub", hubName } };
+        _disconnectionsTotal.Add(1, tags);
+        _activeConnections.Add(-1, tags);
+    }
+
+    /// <summary>
+    /// Records a message sent to clients; the sent counter grows by <paramref name="recipientCount"/>.
+    /// </summary>
+    public void RecordMessageSent(string hubName, string targetType, int recipientCount = 1)
+    {
+        var tags = new TagList
+        {
+            { "hub", hubName },
+            { "target", targetType }
+        };
+        _messagesSentTotal.Add(recipientCount, tags);
+    }
+
+    /// <summary>
+    /// Records a message received from a client.
+    /// </summary>
+    public void RecordMessageReceived(string hubName)
+    {
+        var tags = new TagList { { "hub", hubName } };
+        _messagesReceivedTotal.Add(1, tags);
+    }
+
+    /// <summary>
+    /// Records a dropped message, tagged with the supplied <paramref name="reason"/>.
+    /// </summary>
+    public void RecordMessageDropped(string hubName, string reason)
+    {
+        var tags = new TagList
+        {
+            { "hub", hubName },
+            { "reason", reason }
+        };
+        _messagesDroppedTotal.Add(1, tags);
+    }
+
+    /// <summary>
+    /// Records a buffered message during grace period.
+    /// </summary>
+    public void RecordMessageBuffered(string hubName)
+    {
+        var tags = new TagList { { "hub", hubName } };
+        _messagesBufferedTotal.Add(1, tags);
+    }
+
+    /// <summary>
+    /// Records the duration of message delivery, in milliseconds.
+    /// </summary>
+    public void RecordMessageDeliveryDuration(string hubName, double durationMs)
+    {
+        var tags = new TagList { { "hub", hubName } };
+        _messageDeliveryDuration.Record(durationMs, tags);
+    }
+
+    /// <summary>
+    /// Records an observer failure, tagged with the supplied <paramref name="failureType"/>.
+    /// </summary>
+    public void RecordObserverFailure(string hubName, string failureType)
+    {
+        var tags = new TagList
+        {
+            { "hub", hubName },
+            { "failure_type", failureType }
+        };
+        _observerFailuresTotal.Add(1, tags);
+    }
+
+    /// <summary>
+    /// Records an observer marked as dead.
+    /// </summary>
+    public void RecordObserverDead(string hubName)
+    {
+        var tags = new TagList { { "hub", hubName } };
+        _observersMarkedDeadTotal.Add(1, tags);
+    }
+
+    /// <summary>
+    /// Records a circuit breaker opening.
+    /// </summary>
+    public void RecordCircuitBreakerOpened(string hubName)
+    {
+        var tags = new TagList { { "hub", hubName } };
+        _circuitBreakersOpenedTotal.Add(1, tags);
+    }
+
+    /// <summary>
+    /// Records a circuit breaker closing.
+    /// </summary>
+    public void RecordCircuitBreakerClosed(string hubName)
+    {
+        var tags = new TagList { { "hub", hubName } };
+        _circuitBreakersClosedTotal.Add(1, tags);
+    }
+
+    /// <summary>
+    /// Records an observer entering grace period.
+    /// </summary>
+    public void RecordGracePeriodStarted(string hubName)
+    {
+        var tags = new TagList { { "hub", hubName } };
+        _observersInGracePeriod.Add(1, tags);
+    }
+
+    /// <summary>
+    /// Records an observer exiting grace period.
+    /// </summary>
+    public void RecordGracePeriodEnded(string hubName)
+    {
+        var tags = new TagList { { "hub", hubName } };
+        _observersInGracePeriod.Add(-1, tags);
+    }
+
+    /// <summary>
+    /// Updates the current connection partition count.
+    /// </summary>
+    public void SetConnectionPartitionCount(int count)
+    {
+        Volatile.Write(ref _currentConnectionPartitionCount, count);
+    }
+
+    /// <summary>
+    /// Updates the current group partition count.
+    /// </summary>
+    public void SetGroupPartitionCount(int count)
+    {
+        Volatile.Write(ref _currentGroupPartitionCount, count);
+    }
+
+    /// <summary>
+    /// Creates a scope that records the delivery duration for the hub when it is disposed.
+    /// </summary>
+    public MessageDeliveryScope StartMessageDelivery(string hubName)
+    {
+        return new MessageDeliveryScope(this, hubName);
+    }
+
+    /// <summary>
+    /// Disposes the underlying meter. Note that <see cref="Instance"/> is shared process-wide.
+    /// </summary>
+    public void Dispose()
+    {
+        _meter.Dispose();
+    }
+
+    /// <summary>
+    /// Scope for measuring message delivery duration.
+    /// </summary>
+    public readonly struct MessageDeliveryScope : IDisposable
+    {
+        private readonly SignalRMetrics _metrics;
+        private readonly string _hubName;
+        private readonly long _startTimestamp;
+
+        internal MessageDeliveryScope(SignalRMetrics metrics, string hubName)
+        {
+            _metrics = metrics;
+            _hubName = hubName;
+            _startTimestamp = Stopwatch.GetTimestamp();
+        }
+
+        /// <summary>
+        /// Completes the measurement and records the duration; a default-initialized scope records nothing.
+        /// </summary>
+        public void Dispose()
+        {
+            var elapsed = Stopwatch.GetElapsedTime(_startTimestamp);
+            _metrics?.RecordMessageDeliveryDuration(_hubName, elapsed.TotalMilliseconds);
+        }
+    }
+}
+
+/// <summary>
+/// Entry points for creating distributed-tracing activities for SignalR operations.
+/// </summary>
+public static class SignalRActivitySource
+{
+    /// <summary>
+    /// Name under which <see cref="Source"/> is registered.
+    /// </summary>
+    public const string SourceName = "ManagedCode.Orleans.SignalR";
+
+    /// <summary>
+    /// The shared activity source used by every helper in this class.
+    /// </summary>
+    public static ActivitySource Source { get; } = new(SourceName, "1.0.0");
+
+    /// <summary>
+    /// Starts a producer activity for a message send; the result is null when no activity is created.
+    /// </summary>
+    public static Activity? StartSendMessage(string hubName, string targetType)
+    {
+        return Source
+            .StartActivity("SignalR.SendMessage", ActivityKind.Producer)
+            ?.SetTag("signalr.hub", hubName)
+            .SetTag("signalr.target_type", targetType);
+    }
+
+    /// <summary>
+    /// Starts an internal activity for a grain operation; the result is null when no activity is created.
+    /// </summary>
+    public static Activity? StartGrainOperation(string grainType, string operation)
+    {
+        var activityName = $"SignalR.{grainType}.{operation}";
+        return Source.StartActivity(activityName, ActivityKind.Internal)
+            ?.SetTag("signalr.grain_type", grainType)
+            .SetTag("signalr.operation", operation);
+    }
+}
diff --git a/ManagedCode.Orleans.SignalR.Core/Helpers/CollectionPool.cs b/ManagedCode.Orleans.SignalR.Core/Helpers/CollectionPool.cs
new file mode 100644
index 0000000..a4dc090
--- /dev/null
+++ b/ManagedCode.Orleans.SignalR.Core/Helpers/CollectionPool.cs
@@ -0,0 +1,173 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+
+namespace ManagedCode.Orleans.SignalR.Core.Helpers;
+
+/// <summary>
+/// Provides pooling for common collection types to reduce allocations in hot paths.
+/// Each pool is a thread-safe concurrent bag that stops accepting items at MaxPoolSize entries.
+/// </summary>
+public static class CollectionPool
+{
+    private const int MaxPoolSize = 256;
+
+    private static readonly ConcurrentBag<HashSet<string>> _stringHashSetPool = new();
+    private static readonly ConcurrentBag<List<string>> _stringListPool = new();
+    private static readonly ConcurrentBag<Dictionary<int, List<string>>> _intListDictionaryPool = new();
+
+    /// <summary>
+    /// Gets a HashSet&lt;string&gt; from the pool or creates a new one using ordinal comparison.
+    /// </summary>
+    public static HashSet<string> GetStringHashSet()
+    {
+        if (_stringHashSetPool.TryTake(out var set))
+        {
+            return set;
+        }
+
+        return new HashSet<string>(StringComparer.Ordinal);
+    }
+
+    /// <summary>
+    /// Clears a HashSet&lt;string&gt; and pools it; null, non-ordinal sets, or a full pool are ignored.
+    /// </summary>
+    public static void Return(HashSet<string> set)
+    {
+        if (set is null || !StringComparer.Ordinal.Equals(set.Comparer) || _stringHashSetPool.Count >= MaxPoolSize)
+        {
+            return;
+        }
+
+        set.Clear();
+        _stringHashSetPool.Add(set);
+    }
+
+    /// <summary>
+    /// Gets a List&lt;string&gt; from the pool or creates a new one.
+    /// </summary>
+    public static List<string> GetStringList()
+    {
+        if (_stringListPool.TryTake(out var list))
+        {
+            return list;
+        }
+
+        return new List<string>();
+    }
+
+    /// <summary>
+    /// Gets a List&lt;string&gt; whose capacity is at least <paramref name="capacity"/>.
+    /// </summary>
+    public static List<string> GetStringList(int capacity)
+    {
+        if (_stringListPool.TryTake(out var list))
+        {
+            if (list.Capacity < capacity)
+            {
+                list.Capacity = capacity;
+            }
+            return list;
+        }
+
+        return new List<string>(capacity);
+    }
+
+    /// <summary>
+    /// Clears a List&lt;string&gt; and pools it; null lists or a full pool are ignored.
+    /// </summary>
+    public static void Return(List<string> list)
+    {
+        if (list is null || _stringListPool.Count >= MaxPoolSize)
+        {
+            return;
+        }
+
+        list.Clear();
+        _stringListPool.Add(list);
+    }
+
+    /// <summary>
+    /// Gets a Dictionary&lt;int, List&lt;string&gt;&gt; from the pool or creates a new one.
+    /// </summary>
+    public static Dictionary<int, List<string>> GetIntListDictionary()
+    {
+        if (_intListDictionaryPool.TryTake(out var dict))
+        {
+            return dict;
+        }
+
+        return new Dictionary<int, List<string>>();
+    }
+
+    /// <summary>
+    /// Returns a Dictionary&lt;int, List&lt;string&gt;&gt; to the pool.
+    /// The inner lists are handed to the list pool before the dictionary is cleared.
+    /// </summary>
+    public static void Return(Dictionary<int, List<string>> dict)
+    {
+        if (dict is null || _intListDictionaryPool.Count >= MaxPoolSize)
+        {
+            return;
+        }
+
+        // Return inner lists to their pool
+        foreach (var list in dict.Values)
+        {
+            Return(list);
+        }
+
+        dict.Clear();
+        _intListDictionaryPool.Add(dict);
+    }
+
+    /// <summary>
+    /// A scope that returns its HashSet to the pool when disposed; dispose each scope only once.
+    /// </summary>
+    public readonly struct HashSetScope(HashSet<string> set) : IDisposable
+    {
+        public HashSet<string> Set { get; } = set;
+
+        public void Dispose()
+        {
+            Return(Set);
+        }
+    }
+
+    /// <summary>
+    /// A scope that returns its List to the pool when disposed; dispose each scope only once.
+    /// </summary>
+    public readonly struct ListScope(List<string> list) : IDisposable
+    {
+        public List<string> List { get; } = list;
+
+        public void Dispose()
+        {
+            Return(List);
+        }
+    }
+
+    /// <summary>
+    /// Creates a scoped HashSet that is automatically returned to the pool.
+    /// </summary>
+    public static HashSetScope GetScopedStringHashSet()
+    {
+        return new HashSetScope(GetStringHashSet());
+    }
+
+    /// <summary>
+    /// Creates a scoped List that is automatically returned to the pool.
+    /// </summary>
+    public static ListScope GetScopedStringList()
+    {
+        return new ListScope(GetStringList());
+    }
+
+    /// <summary>
+    /// Creates a scoped List with capacity that is automatically returned to the pool.
+    /// </summary>
+    public static ListScope GetScopedStringList(int capacity)
+    {
+        return new ListScope(GetStringList(capacity));
+    }
+}
diff --git a/ManagedCode.Orleans.SignalR.Core/Helpers/PartitionHelper.cs b/ManagedCode.Orleans.SignalR.Core/Helpers/PartitionHelper.cs
index 29b434b..8d625fb 100644
--- a/ManagedCode.Orleans.SignalR.Core/Helpers/PartitionHelper.cs
+++ b/ManagedCode.Orleans.SignalR.Core/Helpers/PartitionHelper.cs
@@ -1,8 +1,12 @@
using System;
+using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
+using System.Globalization;
+using System.IO.Hashing;
using System.Linq;
-using System.Security.Cryptography;
+using System.Numerics;
+using System.Runtime.CompilerServices;
using System.Text;
namespace ManagedCode.Orleans.SignalR.Core.Helpers;
@@ -10,65 +14,80 @@ namespace ManagedCode.Orleans.SignalR.Core.Helpers;
public static class PartitionHelper
{
private const int VirtualNodesPerPartition = 150; // Number of virtual nodes per physical partition
- private static readonly ConcurrentDictionary RingCache = new();
+ private const int MaxStackAllocSize = 256; // Max bytes for stackalloc
+ private static readonly ConcurrentDictionary<_ringCacheKey, ConsistentHashRing> _ringCache = new();
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int GetPartitionId(string connectionId, uint partitionCount)
{
- if (string.IsNullOrEmpty(connectionId))
- {
- throw new ArgumentException("Connection ID cannot be null or empty", nameof(connectionId));
- }
-
- if (partitionCount <= 0)
- {
- throw new ArgumentException("Partition count must be greater than 0", nameof(partitionCount));
- }
+ ArgumentException.ThrowIfNullOrEmpty(connectionId);
+ ArgumentOutOfRangeException.ThrowIfZero(partitionCount);
- var ring = RingCache.GetOrAdd(new RingCacheKey((int)partitionCount, VirtualNodesPerPartition),
- key => new ConsistentHashRing(key.PartitionCount, key.VirtualNodes));
+ var ring = _ringCache.GetOrAdd(new _ringCacheKey((int)partitionCount, VirtualNodesPerPartition),
+ static key => new ConsistentHashRing(key.PartitionCount, key.VirtualNodes));
return ring.GetPartition(connectionId);
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int GetOptimalPartitionCount(int expectedConnections)
{
return GetOptimalPartitionCount(expectedConnections, 10_000);
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int GetOptimalPartitionCount(int expectedConnections, int connectionsPerPartition)
{
var perPartition = Math.Max(1, connectionsPerPartition);
var partitions = Math.Max(1, (expectedConnections + perPartition - 1) / perPartition);
- return ToPowerOfTwo(partitions);
+ return (int)BitOperations.RoundUpToPowerOf2((uint)partitions);
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int GetOptimalGroupPartitionCount(int expectedGroups)
{
return GetOptimalGroupPartitionCount(expectedGroups, 1_000);
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int GetOptimalGroupPartitionCount(int expectedGroups, int groupsPerPartition)
{
var perPartition = Math.Max(1, groupsPerPartition);
var partitions = Math.Max(1, (expectedGroups + perPartition - 1) / perPartition);
- return ToPowerOfTwo(partitions);
+ return (int)BitOperations.RoundUpToPowerOf2((uint)partitions);
}
- private static int ToPowerOfTwo(int value)
+ ///
+ /// Computes hash using stack allocation for small strings, ArrayPool for larger ones.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ internal static uint ComputeHash(ReadOnlySpan key)
{
- if (value <= 1)
+ var maxByteCount = Encoding.UTF8.GetMaxByteCount(key.Length);
+
+ if (maxByteCount <= MaxStackAllocSize)
{
- return 1;
+ Span buffer = stackalloc byte[maxByteCount];
+ var bytesWritten = Encoding.UTF8.GetBytes(key, buffer);
+ return unchecked((uint)XxHash64.HashToUInt64(buffer[..bytesWritten]));
}
- var power = (int)Math.Ceiling(Math.Log(value, 2));
- return (int)Math.Pow(2, power);
+ var rentedBuffer = ArrayPool.Shared.Rent(maxByteCount);
+ try
+ {
+ var bytesWritten = Encoding.UTF8.GetBytes(key, rentedBuffer);
+ return unchecked((uint)XxHash64.HashToUInt64(rentedBuffer.AsSpan(0, bytesWritten)));
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(rentedBuffer);
+ }
}
- private readonly record struct RingCacheKey(int PartitionCount, int VirtualNodes);
+ private readonly record struct _ringCacheKey(int PartitionCount, int VirtualNodes);
}
-public class ConsistentHashRing
+public sealed class ConsistentHashRing
{
private readonly uint[] _keys;
private readonly int[] _partitions;
@@ -76,10 +95,7 @@ public class ConsistentHashRing
public ConsistentHashRing(int partitionCount, int virtualNodes = 150)
{
- if (partitionCount <= 0)
- {
- throw new ArgumentOutOfRangeException(nameof(partitionCount), "Partition count must be greater than zero.");
- }
+ ArgumentOutOfRangeException.ThrowIfNegativeOrZero(partitionCount);
_partitionCount = partitionCount;
@@ -92,12 +108,24 @@ private static SortedList InitializeRing(int partitionCount, int virt
{
var ring = new SortedList(partitionCount * virtualNodes);
+ Span keyBuffer = stackalloc char[64]; // "partition-XXXX-vnode-XXXX" max ~25 chars
+
for (var partition = 0; partition < partitionCount; partition++)
{
for (var vnode = 0; vnode < virtualNodes; vnode++)
{
- var virtualNodeKey = $"partition-{partition}-vnode-{vnode}";
- var hash = GetHash(virtualNodeKey);
+ // Build key without allocation using TryFormat
+ var written = 0;
+ "partition-".AsSpan().CopyTo(keyBuffer);
+ written += 10;
+ partition.TryFormat(keyBuffer[written..], out var partitionChars, default, CultureInfo.InvariantCulture);
+ written += partitionChars;
+ "-vnode-".AsSpan().CopyTo(keyBuffer[written..]);
+ written += 7;
+ vnode.TryFormat(keyBuffer[written..], out var vnodeChars, default, CultureInfo.InvariantCulture);
+ written += vnodeChars;
+
+ var hash = PartitionHelper.ComputeHash(keyBuffer[..written]);
ring[hash] = partition;
}
}
@@ -105,6 +133,7 @@ private static SortedList InitializeRing(int partitionCount, int virt
return ring;
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public int GetPartition(string key)
{
if (_keys.Length == 0)
@@ -112,7 +141,7 @@ public int GetPartition(string key)
return 0;
}
- var hash = GetHash(key);
+ var hash = PartitionHelper.ComputeHash(key.AsSpan());
var index = Array.BinarySearch(_keys, hash);
if (index < 0)
@@ -128,16 +157,9 @@ public int GetPartition(string key)
return _partitions[index];
}
- private static uint GetHash(string key)
- {
- using var md5 = MD5.Create();
- var hash = md5.ComputeHash(Encoding.UTF8.GetBytes(key));
- return BitConverter.ToUInt32(hash, 0);
- }
-
public Dictionary GetDistribution(IEnumerable keys)
{
- var distribution = new Dictionary();
+ var distribution = new Dictionary(_partitionCount);
for (var i = 0; i < _partitionCount; i++)
{
distribution[i] = 0;
diff --git a/ManagedCode.Orleans.SignalR.Core/Helpers/RetryHelper.cs b/ManagedCode.Orleans.SignalR.Core/Helpers/RetryHelper.cs
new file mode 100644
index 0000000..123df24
--- /dev/null
+++ b/ManagedCode.Orleans.SignalR.Core/Helpers/RetryHelper.cs
@@ -0,0 +1,225 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Orleans.Runtime;
+
+namespace ManagedCode.Orleans.SignalR.Core.Helpers;
+
+/// <summary>
+/// Provides retry functionality with exponential backoff for transient failures.
+/// </summary>
+public static class RetryHelper
+{
+    /// <summary>
+    /// Default configuration for retry operations.
+    /// </summary>
+    public static readonly RetryPolicy DefaultPolicy = new(
+        maxAttempts: 3,
+        initialDelay: TimeSpan.FromMilliseconds(100),
+        maxDelay: TimeSpan.FromSeconds(5),
+        exponentialBase: 2.0);
+
+    /// <summary>Runs <paramref name="action"/>, retrying transient failures with exponential backoff.</summary>
+    /// <param name="policy">Retry settings; <see cref="DefaultPolicy"/> is used when null.</param>
+    /// <param name="cancellationToken">Checked before every attempt and observed while waiting between attempts.</param>
+    public static async Task ExecuteWithRetryAsync(
+        Func<Task> action,
+        RetryPolicy? policy = null,
+        CancellationToken cancellationToken = default)
+    {
+        ArgumentNullException.ThrowIfNull(action);
+        policy ??= DefaultPolicy;
+
+        var (attempt, delay) = (0, policy.InitialDelay);
+        while (true)
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+            try
+            {
+                await action();
+                return;
+            }
+            catch (Exception ex) when (IsTransient(ex) && attempt < policy.MaxAttempts - 1)
+            {
+                attempt++;
+                await Task.Delay(delay, cancellationToken);
+                delay = CalculateNextDelay(delay, policy);
+            }
+        }
+    }
+
+    /// <summary>Runs <paramref name="func"/> and returns its result, retrying transient failures with backoff.</summary>
+    /// <param name="policy">Retry settings; <see cref="DefaultPolicy"/> is used when null.</param>
+    /// <param name="cancellationToken">Checked before every attempt and observed while waiting between attempts.</param>
+    public static async Task<T> ExecuteWithRetryAsync<T>(
+        Func<Task<T>> func,
+        RetryPolicy? policy = null,
+        CancellationToken cancellationToken = default)
+    {
+        ArgumentNullException.ThrowIfNull(func);
+        policy ??= DefaultPolicy;
+
+        var (attempt, delay) = (0, policy.InitialDelay);
+        while (true)
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+            try
+            {
+                return await func();
+            }
+            catch (Exception ex) when (IsTransient(ex) && attempt < policy.MaxAttempts - 1)
+            {
+                attempt++;
+                await Task.Delay(delay, cancellationToken);
+                delay = CalculateNextDelay(delay, policy);
+            }
+        }
+    }
+
+    /// <summary>Runs a grain call, retrying only the Orleans-specific transient failures.</summary>
+    /// <param name="policy">Retry settings; <see cref="DefaultPolicy"/> is used when null.</param>
+    /// <param name="cancellationToken">Checked before every attempt and observed while waiting between attempts.</param>
+    public static async Task ExecuteGrainCallAsync(
+        Func<Task> grainCall,
+        RetryPolicy? policy = null,
+        CancellationToken cancellationToken = default)
+    {
+        ArgumentNullException.ThrowIfNull(grainCall);
+        policy ??= DefaultPolicy;
+
+        var (attempt, delay) = (0, policy.InitialDelay);
+        while (true)
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+            try
+            {
+                await grainCall();
+                return;
+            }
+            catch (Exception ex) when (IsOrleansTransient(ex) && attempt < policy.MaxAttempts - 1)
+            {
+                attempt++;
+                await Task.Delay(delay, cancellationToken);
+                delay = CalculateNextDelay(delay, policy);
+            }
+        }
+    }
+
+    /// <summary>Runs a grain call and returns its result, retrying only Orleans-specific transient failures.</summary>
+    /// <param name="policy">Retry settings; <see cref="DefaultPolicy"/> is used when null.</param>
+    /// <param name="cancellationToken">Checked before every attempt and observed while waiting between attempts.</param>
+    public static async Task<T> ExecuteGrainCallAsync<T>(
+        Func<Task<T>> grainCall,
+        RetryPolicy? policy = null,
+        CancellationToken cancellationToken = default)
+    {
+        ArgumentNullException.ThrowIfNull(grainCall);
+        policy ??= DefaultPolicy;
+
+        var (attempt, delay) = (0, policy.InitialDelay);
+        while (true)
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+            try
+            {
+                return await grainCall();
+            }
+            catch (Exception ex) when (IsOrleansTransient(ex) && attempt < policy.MaxAttempts - 1)
+            {
+                attempt++;
+                await Task.Delay(delay, cancellationToken);
+                delay = CalculateNextDelay(delay, policy);
+            }
+        }
+    }
+
+    private static TimeSpan CalculateNextDelay(TimeSpan currentDelay, RetryPolicy policy)
+    {
+        // Calculate next delay with exponential backoff
+        var nextDelay = TimeSpan.FromTicks((long)(currentDelay.Ticks * policy.ExponentialBase));
+
+        // Add jitter (±10%) to prevent thundering herd
+        var jitterRange = nextDelay.Ticks / 10;
+        var jitter = Random.Shared.NextInt64(-jitterRange, jitterRange);
+        nextDelay = TimeSpan.FromTicks(nextDelay.Ticks + jitter);
+
+        // Ensure we don't exceed max delay
+        return nextDelay > policy.MaxDelay ? policy.MaxDelay : nextDelay;
+    }
+
+    private static bool IsTransient(Exception ex)
+    {
+        return ex is TimeoutException
+            or TaskCanceledException
+            or OperationCanceledException
+            or OrleansException;
+    }
+
+    private static bool IsOrleansTransient(Exception ex)
+    {
+        // Handle Orleans-specific transient exceptions
+        return ex is TimeoutException
+            or TaskCanceledException
+            or OrleansMessageRejectionException
+            or SiloUnavailableException
+            or GatewayTooBusyException;
+    }
+}
+
+/// <summary>
+/// Configuration for retry operations.
+/// Instances are immutable; every value is normalized by the constructor.
+/// </summary>
+public sealed class RetryPolicy
+{
+    /// <summary>Total number of attempts, including the initial call (at least 1).</summary>
+    public int MaxAttempts { get; }
+
+    /// <summary>Delay before the first retry.</summary>
+    public TimeSpan InitialDelay { get; }
+
+    /// <summary>Upper bound for the delay between attempts; never below <see cref="InitialDelay"/>.</summary>
+    public TimeSpan MaxDelay { get; }
+
+    /// <summary>Factor applied to the delay after each failed attempt (at least 1.1).</summary>
+    public double ExponentialBase { get; }
+
+    /// <summary>Creates a policy, replacing out-of-range arguments with usable values.</summary>
+    /// <param name="maxAttempts">Total attempts; values below 1 become 1.</param>
+    /// <param name="initialDelay">First retry delay; zero or negative values become 100 ms.</param>
+    /// <param name="maxDelay">Delay cap; values below the initial delay become the larger of 30 s and the initial delay.</param>
+    /// <param name="exponentialBase">Backoff factor; values below 1.1 become 1.1.</param>
+    public RetryPolicy(int maxAttempts, TimeSpan initialDelay, TimeSpan maxDelay, double exponentialBase = 2.0)
+    {
+        MaxAttempts = Math.Max(1, maxAttempts);
+        InitialDelay = initialDelay > TimeSpan.Zero ? initialDelay : TimeSpan.FromMilliseconds(100);
+        // A cap equal to the initial delay is a valid fixed-delay setup and is kept as given.
+        // Only a cap below the initial delay is replaced, and the replacement never undercuts InitialDelay.
+        MaxDelay = maxDelay >= InitialDelay ? maxDelay : TimeSpan.FromTicks(Math.Max(InitialDelay.Ticks, TimeSpan.TicksPerSecond * 30));
+        ExponentialBase = Math.Max(1.1, exponentialBase);
+    }
+
+    /// <summary>
+    /// Creates a policy optimized for fast operations.
+    /// </summary>
+    public static RetryPolicy Fast => new(
+        maxAttempts: 3,
+        initialDelay: TimeSpan.FromMilliseconds(50),
+        maxDelay: TimeSpan.FromMilliseconds(500));
+
+    /// <summary>
+    /// Creates a policy for slow operations with longer delays.
+    /// </summary>
+    public static RetryPolicy Slow => new(
+        maxAttempts: 5,
+        initialDelay: TimeSpan.FromMilliseconds(500),
+        maxDelay: TimeSpan.FromSeconds(30));
+
+    /// <summary>
+    /// Creates a policy for aggressive retrying of critical operations.
+    /// </summary>
+    public static RetryPolicy Aggressive => new(
+        maxAttempts: 10,
+        initialDelay: TimeSpan.FromMilliseconds(100),
+        maxDelay: TimeSpan.FromSeconds(60));
+}
diff --git a/ManagedCode.Orleans.SignalR.Core/Helpers/TimeIntervalHelper.cs b/ManagedCode.Orleans.SignalR.Core/Helpers/TimeIntervalHelper.cs
index bd1299a..5abd333 100644
--- a/ManagedCode.Orleans.SignalR.Core/Helpers/TimeIntervalHelper.cs
+++ b/ManagedCode.Orleans.SignalR.Core/Helpers/TimeIntervalHelper.cs
@@ -1,8 +1,8 @@
using System;
+using System.Threading;
using ManagedCode.Orleans.SignalR.Core.Config;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Options;
-using System.Threading;
namespace ManagedCode.Orleans.SignalR.Core.Helpers;
diff --git a/ManagedCode.Orleans.SignalR.Core/HubContext/TypedClientBuilder.cs b/ManagedCode.Orleans.SignalR.Core/HubContext/TypedClientBuilder.cs
index 4310918..8c71bcc 100644
--- a/ManagedCode.Orleans.SignalR.Core/HubContext/TypedClientBuilder.cs
+++ b/ManagedCode.Orleans.SignalR.Core/HubContext/TypedClientBuilder.cs
@@ -16,12 +16,12 @@ internal static class TypedClientBuilder
// There is one static instance of _builder per T
private static readonly Lazy> _builder = new(GenerateClientBuilder);
- private static readonly PropertyInfo CancellationTokenNoneProperty =
+ private static readonly PropertyInfo _cancellationTokenNoneProperty =
typeof(CancellationToken).GetProperty("None", BindingFlags.Public | BindingFlags.Static)!;
- private static readonly ConstructorInfo ObjectConstructor = typeof(object).GetConstructors().Single();
+ private static readonly ConstructorInfo _objectConstructor = typeof(object).GetConstructors().Single();
- private static readonly Type[] ParameterTypes = [typeof(IClientProxy)];
+ private static readonly Type[] _parameterTypes = [typeof(IClientProxy)];
public static T Build(IClientProxy proxy)
{
@@ -89,13 +89,13 @@ private static IEnumerable GetAllInterfaceMethods(Type interfaceType
private static ConstructorInfo BuildConstructor(TypeBuilder type, FieldInfo proxyField)
{
- var ctor = type.DefineConstructor(MethodAttributes.Public, CallingConventions.Standard, ParameterTypes);
+ var ctor = type.DefineConstructor(MethodAttributes.Public, CallingConventions.Standard, _parameterTypes);
var generator = ctor.GetILGenerator();
// Call object constructor
generator.Emit(OpCodes.Ldarg_0);
- generator.Emit(OpCodes.Call, ObjectConstructor);
+ generator.Emit(OpCodes.Call, _objectConstructor);
// Assign constructor argument to the proxyField
generator.Emit(OpCodes.Ldarg_0); // type
@@ -217,7 +217,7 @@ private static void BuildMethod(TypeBuilder type, MethodInfo interfaceMethodInfo
else
{
// Get 'CancellationToken.None' and put it on the stack, for when method does not have CancellationToken
- generator.Emit(OpCodes.Call, CancellationTokenNoneProperty.GetMethod!);
+ generator.Emit(OpCodes.Call, _cancellationTokenNoneProperty.GetMethod!);
}
// Send!
@@ -229,7 +229,7 @@ private static void BuildMethod(TypeBuilder type, MethodInfo interfaceMethodInfo
private static void BuildFactoryMethod(TypeBuilder type, ConstructorInfo ctor)
{
var method = type.DefineMethod(nameof(Build), MethodAttributes.Public | MethodAttributes.Static,
- CallingConventions.Standard, typeof(T), ParameterTypes);
+ CallingConventions.Standard, typeof(T), _parameterTypes);
var generator = method.GetILGenerator();
diff --git a/ManagedCode.Orleans.SignalR.Core/Interfaces/ISignalRConnectionCoordinatorGrain.cs b/ManagedCode.Orleans.SignalR.Core/Interfaces/ISignalRConnectionCoordinatorGrain.cs
index 6bc7884..52822cc 100644
--- a/ManagedCode.Orleans.SignalR.Core/Interfaces/ISignalRConnectionCoordinatorGrain.cs
+++ b/ManagedCode.Orleans.SignalR.Core/Interfaces/ISignalRConnectionCoordinatorGrain.cs
@@ -11,7 +11,6 @@ public interface ISignalRConnectionCoordinatorGrain : IGrainWithStringKey
[AlwaysInterleave]
Task GetPartitionCount();
- [ReadOnly]
[AlwaysInterleave]
Task GetPartitionForConnection(string connectionId);
diff --git a/ManagedCode.Orleans.SignalR.Core/Models/ConnectionCoordinatorState.cs b/ManagedCode.Orleans.SignalR.Core/Models/ConnectionCoordinatorState.cs
index aa1662c..fd1f300 100644
--- a/ManagedCode.Orleans.SignalR.Core/Models/ConnectionCoordinatorState.cs
+++ b/ManagedCode.Orleans.SignalR.Core/Models/ConnectionCoordinatorState.cs
@@ -8,8 +8,14 @@ namespace ManagedCode.Orleans.SignalR.Core.Models;
public sealed class ConnectionCoordinatorState
{
[Id(0)]
- public Dictionary ConnectionPartitions { get; set; } = new(StringComparer.Ordinal);
+ public Dictionary ConnectionPartitions { get; set; } = new(StringComparer.Ordinal);
[Id(1)]
public int CurrentPartitionCount { get; set; }
+
+ ///
+ /// Epoch increments each time partition count changes, enabling detection of stale assignments.
+ ///
+ [Id(2)]
+ public int PartitionEpoch { get; set; } = 1;
}
diff --git a/ManagedCode.Orleans.SignalR.Core/Models/GroupCoordinatorState.cs b/ManagedCode.Orleans.SignalR.Core/Models/GroupCoordinatorState.cs
index 13b0349..900cbf2 100644
--- a/ManagedCode.Orleans.SignalR.Core/Models/GroupCoordinatorState.cs
+++ b/ManagedCode.Orleans.SignalR.Core/Models/GroupCoordinatorState.cs
@@ -8,11 +8,17 @@ namespace ManagedCode.Orleans.SignalR.Core.Models;
public sealed class GroupCoordinatorState
{
[Id(0)]
- public Dictionary GroupPartitions { get; set; } = new(StringComparer.Ordinal);
+ public Dictionary GroupPartitions { get; set; } = new(StringComparer.Ordinal);
[Id(1)]
public Dictionary GroupMembership { get; set; } = new(StringComparer.Ordinal);
[Id(2)]
public int CurrentPartitionCount { get; set; }
+
+ ///
+ /// Epoch increments each time partition count changes, enabling detection of stale assignments.
+ ///
+ [Id(3)]
+ public int PartitionEpoch { get; set; } = 1;
}
diff --git a/ManagedCode.Orleans.SignalR.Core/Models/PartitionAssignment.cs b/ManagedCode.Orleans.SignalR.Core/Models/PartitionAssignment.cs
new file mode 100644
index 0000000..d9e132f
--- /dev/null
+++ b/ManagedCode.Orleans.SignalR.Core/Models/PartitionAssignment.cs
@@ -0,0 +1,18 @@
+using Orleans;
+
+namespace ManagedCode.Orleans.SignalR.Core.Models;
+
+///
+/// Represents a partition assignment with epoch tracking for consistency during scaling.
+///
+[GenerateSerializer]
+[Immutable]
+public readonly record struct PartitionAssignment(
+ [property: Id(0)] int PartitionId,
+ [property: Id(1)] int Epoch)
+{
+ ///
+ /// Creates an assignment for the current epoch.
+ ///
+ public static PartitionAssignment Create(int partitionId, int epoch) => new(partitionId, epoch);
+}
diff --git a/ManagedCode.Orleans.SignalR.Core/Properties/launchSettings.json b/ManagedCode.Orleans.SignalR.Core/Properties/launchSettings.json
new file mode 100644
index 0000000..11ab422
--- /dev/null
+++ b/ManagedCode.Orleans.SignalR.Core/Properties/launchSettings.json
@@ -0,0 +1,12 @@
+{
+ "profiles": {
+ "ManagedCode.Orleans.SignalR.Core": {
+ "commandName": "Project",
+ "launchBrowser": true,
+ "environmentVariables": {
+ "ASPNETCORE_ENVIRONMENT": "Development"
+ },
+ "applicationUrl": "https://localhost:56458;http://localhost:56462"
+ }
+ }
+}
\ No newline at end of file
diff --git a/ManagedCode.Orleans.SignalR.Core/SignalR/NameHelperGenerator.cs b/ManagedCode.Orleans.SignalR.Core/SignalR/NameHelperGenerator.cs
index bf33ba0..ee03478 100644
--- a/ManagedCode.Orleans.SignalR.Core/SignalR/NameHelperGenerator.cs
+++ b/ManagedCode.Orleans.SignalR.Core/SignalR/NameHelperGenerator.cs
@@ -1,5 +1,8 @@
-using System.IO.Hashing;
-using System.Text;
+using System;
+using System.Buffers;
+using System.Collections.Concurrent;
+using System.Runtime.CompilerServices;
+using ManagedCode.Orleans.SignalR.Core.Helpers;
using ManagedCode.Orleans.SignalR.Core.Interfaces;
using Orleans;
@@ -7,69 +10,89 @@ namespace ManagedCode.Orleans.SignalR.Core.SignalR;
public static class NameHelperGenerator
{
+ // Cache cleaned type names to avoid repeated allocations
+ private static readonly ConcurrentDictionary _typeNameCache = new();
+
+ // SearchValues for allowed characters (optimized for .NET 8+)
+ private static readonly SearchValues _allowedChars =
+ SearchValues.Create("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-:.");
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ISignalRConnectionHolderGrain GetConnectionHolderGrain(IGrainFactory grainFactory)
{
- return grainFactory.GetGrain(CleanString(typeof(THub).FullName!));
+ return grainFactory.GetGrain(GetCleanedTypeName());
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ISignalRConnectionHolderGrain GetConnectionHolderGrain(IGrainFactory grainFactory, string hubKey)
{
return grainFactory.GetGrain(CleanString(hubKey));
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ISignalRConnectionCoordinatorGrain GetConnectionCoordinatorGrain(IGrainFactory grainFactory)
{
- return grainFactory.GetGrain(CleanString(typeof(THub).FullName!));
+ return grainFactory.GetGrain(GetCleanedTypeName());
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ISignalRConnectionPartitionGrain GetConnectionPartitionGrain(IGrainFactory grainFactory, int partitionId)
{
- var key = GetPartitionGrainKey(typeof(THub).FullName!, partitionId, alreadyCleaned: false);
+ var key = GetPartitionGrainKey(partitionId);
return grainFactory.GetGrain(key);
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ISignalRConnectionPartitionGrain GetConnectionPartitionGrain(IGrainFactory grainFactory, string hubKey, int partitionId)
{
var key = GetPartitionGrainKey(hubKey, partitionId, alreadyCleaned: true);
return grainFactory.GetGrain(key);
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ISignalRInvocationGrain GetInvocationGrain(IGrainFactory grainFactory, string? invocationId)
{
- return grainFactory.GetGrain(CleanString(typeof(THub).FullName + "::" + invocationId ?? "unknown"));
+ var typeName = GetCleanedTypeName();
+ // Clean the invocation id as well: the removed line cleaned the whole key, and the
+ // user/group key builders in this class clean their id part the same way.
+ var key = string.Concat(typeName, "::", CleanString(invocationId ?? "unknown"));
+ return grainFactory.GetGrain(key);
}
- // public static ISignalRGroupHolderGrain GetGroupHolderGrain(IGrainFactory grainFactory)
- // {
- // return grainFactory.GetGrain(typeof(THub).FullName);
- // }
-
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ISignalRUserGrain GetSignalRUserGrain(IGrainFactory grainFactory, string userId)
{
- return grainFactory.GetGrain(CleanString(typeof(THub).FullName + "::" + userId));
+ var typeName = GetCleanedTypeName();
+ var cleanUserId = CleanString(userId);
+ return grainFactory.GetGrain(string.Concat(typeName, "::", cleanUserId));
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ISignalRGroupGrain GetSignalRGroupGrain(IGrainFactory grainFactory, string groupId)
{
- return grainFactory.GetGrain(CleanString(typeof(THub).FullName + "::" + groupId));
+ var typeName = GetCleanedTypeName();
+ var cleanGroupId = CleanString(groupId);
+ return grainFactory.GetGrain(string.Concat(typeName, "::", cleanGroupId));
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ISignalRGroupCoordinatorGrain GetGroupCoordinatorGrain(IGrainFactory grainFactory)
{
- return grainFactory.GetGrain(CleanString(typeof(THub).FullName!));
+ return grainFactory.GetGrain(GetCleanedTypeName());
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ISignalRGroupCoordinatorGrain GetGroupCoordinatorGrain(IGrainFactory grainFactory, string hubKey)
{
return grainFactory.GetGrain(CleanString(hubKey));
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ISignalRGroupPartitionGrain GetGroupPartitionGrain(IGrainFactory grainFactory, int partitionId)
{
- var key = GetPartitionGrainKey(typeof(THub).FullName!, partitionId, alreadyCleaned: false);
+ var key = GetPartitionGrainKey(partitionId);
return grainFactory.GetGrain(key);
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ISignalRGroupPartitionGrain GetGroupPartitionGrain(IGrainFactory grainFactory, string hubKey, int partitionId)
{
var key = GetPartitionGrainKey(hubKey, partitionId, alreadyCleaned: true);
@@ -78,33 +101,73 @@ public static ISignalRGroupPartitionGrain GetGroupPartitionGrain(IGrainFactory g
public static ISignalRConnectionHeartbeatGrain GetConnectionHeartbeatGrain(IGrainFactory grainFactory, string hubKey, string connectionId)
{
- var normalizedConnection = CleanString(connectionId);
- var key = $"{CleanString(hubKey)}::{normalizedConnection}";
- return grainFactory.GetGrain(key);
+ var cleanedHub = CleanString(hubKey);
+ var cleanedConnection = CleanString(connectionId);
+ return grainFactory.GetGrain(string.Concat(cleanedHub, "::", cleanedConnection));
+ }
+
+ ///
+ /// Gets the cached cleaned type name for a hub type.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static string GetCleanedTypeName()
+ {
+ return _typeNameCache.GetOrAdd(typeof(THub), static t => CleanString(t.FullName!));
}
+ /// <summary>
+ /// Cleans a string by replacing every character that is not a letter, a digit, '-', ':' or '.' with ':'.
+ /// Uses SearchValues as an ASCII fast path and string.Create for allocation-efficient string building. Null or empty input is returned unchanged.
+ /// </summary>
public static string CleanString(string input)
{
- var builder = new StringBuilder();
- foreach (var c in input)
+ if (string.IsNullOrEmpty(input))
{
- if (char.IsLetterOrDigit(c) || c == '-' || c == ':' || c == '.')
- {
- builder.Append(c);
- }
- else
+ return input;
+ }
+
+ // Fast path: a string made only of allowed ASCII characters needs no rewriting
+ var inputSpan = input.AsSpan();
+ var firstInvalidIndex = inputSpan.IndexOfAnyExcept(_allowedChars);
+
+ if (firstInvalidIndex < 0)
+ {
+ // All characters are valid, return original string
+ return input;
+ }
+
+ // Slow path - use string.Create for efficient allocation. The per-character test keeps the
+ // char.IsLetterOrDigit rule of the implementation being replaced, so non-ASCII letters and digits
+ // survive and distinct non-ASCII identifiers do not collapse into the same grain key.
+ return string.Create(input.Length, input, static (span, src) =>
+ {
+ for (var i = 0; i < src.Length; i++)
{
- builder.Append(':');
+ var c = src[i];
+ span[i] = (char.IsLetterOrDigit(c) || c is '-' or ':' or '.') ? c : ':';
}
- }
- return builder.ToString();
+ });
}
+ ///
+ /// Gets partition grain key using cached type name.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static long GetPartitionGrainKey(int partitionId)
+ {
+ var cleanedName = GetCleanedTypeName();
+ return ComputePartitionKey(cleanedName.AsSpan(), partitionId);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static long GetPartitionGrainKey(string hubIdentity, int partitionId, bool alreadyCleaned)
{
var normalized = alreadyCleaned ? hubIdentity : CleanString(hubIdentity);
- var hubBytes = Encoding.UTF8.GetBytes(normalized);
- var hash = XxHash64.HashToUInt64(hubBytes);
+ return ComputePartitionKey(normalized.AsSpan(), partitionId);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static long ComputePartitionKey(ReadOnlySpan hubIdentity, int partitionId)
+ {
+ var hash = (ulong)PartitionHelper.ComputeHash(hubIdentity); // shifted left 16 bits below, then partitionId is XORed into the low bits (ids needing more than 16 bits overlap hash bits)
var composite = (hash << 16) ^ (uint)partitionId;
return unchecked((long)composite);
}
diff --git a/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/ExpiringObserverBuffer.cs b/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/ExpiringObserverBuffer.cs
new file mode 100644
index 0000000..fd8a8ac
--- /dev/null
+++ b/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/ExpiringObserverBuffer.cs
@@ -0,0 +1,257 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using Microsoft.AspNetCore.SignalR.Protocol;
+
+namespace ManagedCode.Orleans.SignalR.Core.SignalR.Observers;
+
+///
+/// Buffers messages for observers in the grace period before hard expiration.
+/// This handles timing edge cases where heartbeats are delayed due to GC pauses,
+/// network latency, or silo overload.
+///
+/// Note: This class is designed to be used within Orleans grains which provide single-threaded
+/// execution guarantees. No explicit locking is required.
+///
+public sealed class ExpiringObserverBuffer(TimeSpan gracePeriod, int maxBufferedMessages)
+{
+ private readonly Dictionary _buffers = new(StringComparer.Ordinal);
+ private readonly TimeSpan _gracePeriod = gracePeriod;
+ private readonly int _maxBufferedMessages = Math.Max(1, maxBufferedMessages);
+
+ ///
+ /// Gets whether the buffer is enabled (grace period > 0).
+ ///
+ public bool IsEnabled => _gracePeriod > TimeSpan.Zero;
+
+ ///
+ /// Starts the grace period for an observer, buffering messages until restored or expired.
+ ///
+ /// The connection ID.
+ /// True if grace period started, false if already in grace period.
+ public bool StartGracePeriod(string connectionId)
+ {
+ if (!IsEnabled)
+ {
+ return false;
+ }
+
+ if (_buffers.ContainsKey(connectionId))
+ {
+ return false; // Already in grace period
+ }
+
+ _buffers[connectionId] = new ObserverBufferState(_gracePeriod, _maxBufferedMessages);
+ return true;
+ }
+
+ ///
+ /// Checks if an observer is in the grace period.
+ ///
+ public bool IsInGracePeriod(string connectionId)
+ {
+ if (!_buffers.TryGetValue(connectionId, out var state))
+ {
+ return false;
+ }
+
+ // Check if grace period has expired
+ if (state.IsExpired)
+ {
+ _buffers.Remove(connectionId);
+ return false;
+ }
+
+ return true;
+ }
+
+ ///
+ /// Buffers a message for an observer in the grace period.
+ ///
+ /// True if buffered, false if not in grace period or buffer full.
+ public bool BufferMessage(string connectionId, HubMessage message)
+ {
+ if (!IsEnabled)
+ {
+ return false;
+ }
+
+ if (!_buffers.TryGetValue(connectionId, out var state))
+ {
+ return false;
+ }
+
+ if (state.IsExpired)
+ {
+ _buffers.Remove(connectionId);
+ return false;
+ }
+
+ return state.AddMessage(message);
+ }
+
+ ///
+ /// Restores an observer from the grace period and returns buffered messages.
+ ///
+ /// The connection ID.
+ /// Buffered messages, or empty if not in grace period.
+ public IReadOnlyList RestoreAndGetMessages(string connectionId)
+ {
+ if (!_buffers.Remove(connectionId, out var state))
+ {
+ return Array.Empty();
+ }
+
+ return state.GetMessages();
+ }
+
+ ///
+ /// Expires an observer's grace period and discards buffered messages.
+ ///
+ /// Number of messages discarded.
+ public int Expire(string connectionId)
+ {
+ if (!_buffers.Remove(connectionId, out var state))
+ {
+ return 0;
+ }
+
+ return state.MessageCount;
+ }
+
+ ///
+ /// Checks and removes expired grace periods.
+ ///
+ /// List of connection IDs that expired.
+ public List CleanupExpired()
+ {
+ var expired = new List();
+
+ foreach (var (connectionId, state) in _buffers)
+ {
+ if (state.IsExpired)
+ {
+ expired.Add(connectionId);
+ }
+ }
+
+ foreach (var connectionId in expired)
+ {
+ _buffers.Remove(connectionId);
+ }
+
+ return expired;
+ }
+
+ ///
+ /// Gets the remaining grace period time for a connection.
+ ///
+ public TimeSpan? GetRemainingGracePeriod(string connectionId)
+ {
+ if (_buffers.TryGetValue(connectionId, out var state) && !state.IsExpired)
+ {
+ return state.RemainingTime;
+ }
+
+ return null;
+ }
+
+ ///
+ /// Gets statistics about the buffer.
+ ///
+ public BufferStatistics GetStatistics()
+ {
+ var stats = new BufferStatistics();
+
+ foreach (var state in _buffers.Values)
+ {
+ if (state.IsExpired)
+ {
+ continue;
+ }
+
+ stats.ObserversInGracePeriod++;
+ stats.TotalBufferedMessages += state.MessageCount;
+ }
+
+ return stats;
+ }
+
+ ///
+ /// Clears all buffers.
+ ///
+ public void Clear()
+ {
+ _buffers.Clear();
+ }
+
+ ///
+ /// Circular buffer state for a single observer, optimized for O(1) enqueue/dequeue.
+ ///
+ private sealed class ObserverBufferState(TimeSpan gracePeriod, int maxMessages)
+ {
+ private readonly long _createdAtTimestamp = Stopwatch.GetTimestamp();
+ private readonly TimeSpan _gracePeriod = gracePeriod;
+ private readonly HubMessage[] _messages = new HubMessage[maxMessages];
+ private int _head; // Index of first (oldest) message
+ public bool IsExpired => Stopwatch.GetElapsedTime(_createdAtTimestamp) >= _gracePeriod;
+
+ public TimeSpan RemainingTime
+ {
+ get
+ {
+ var remaining = _gracePeriod - Stopwatch.GetElapsedTime(_createdAtTimestamp);
+ return remaining > TimeSpan.Zero ? remaining : TimeSpan.Zero;
+ }
+ }
+
+ public int MessageCount { get; private set; } // Number of messages in buffer
+
+ public bool AddMessage(HubMessage message)
+ {
+ if (MessageCount >= _messages.Length)
+ {
+ // Buffer is full - overwrite oldest message (drop oldest)
+ // The head points to the oldest, so we overwrite it and advance head
+ _messages[_head] = message;
+ _head = (_head + 1) % _messages.Length;
+ // MessageCount stays the same since we're replacing
+ }
+ else
+ {
+ // Buffer has space - add at tail position
+ var tail = (_head + MessageCount) % _messages.Length;
+ _messages[tail] = message;
+ MessageCount++;
+ }
+
+ return true;
+ }
+
+ public IReadOnlyList GetMessages()
+ {
+ if (MessageCount == 0)
+ {
+ return Array.Empty();
+ }
+
+ // Return messages in order (oldest to newest)
+ var result = new HubMessage[MessageCount];
+ for (var i = 0; i < MessageCount; i++)
+ {
+ result[i] = _messages[(_head + i) % _messages.Length];
+ }
+
+ return result;
+ }
+ }
+}
+
+///
+/// Statistics about the expiring observer buffer.
+///
+public sealed class BufferStatistics
+{
+ public int ObserversInGracePeriod { get; set; }
+ public int TotalBufferedMessages { get; set; }
+}
diff --git a/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/ObserverCircuitBreaker.cs b/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/ObserverCircuitBreaker.cs
new file mode 100644
index 0000000..2db3181
--- /dev/null
+++ b/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/ObserverCircuitBreaker.cs
@@ -0,0 +1,230 @@
+using System;
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using System.Threading;
+
+namespace ManagedCode.Orleans.SignalR.Core.SignalR.Observers;
+
+///
+/// Circuit breaker states following the standard pattern.
+///
+public enum CircuitState
+{
+ ///
+ /// Circuit is closed, requests flow through normally.
+ ///
+ Closed,
+
+ ///
+ /// Circuit is open, requests are blocked to prevent cascade failures.
+ ///
+ Open,
+
+ ///
+ /// Circuit is testing if the observer has recovered.
+ /// One request is allowed through to test connectivity.
+ ///
+ HalfOpen
+}
+
+///
+/// Circuit breaker for an individual observer to prevent cascade failures.
+/// Thread-safe implementation using lock-free operations where possible.
+///
+public sealed class ObserverCircuitBreaker(int failureThreshold, TimeSpan openDuration, TimeSpan halfOpenTestInterval)
+{
+ private readonly int _failureThreshold = Math.Max(1, failureThreshold);
+ private readonly TimeSpan _openDuration = openDuration;
+ private readonly TimeSpan _halfOpenTestInterval = halfOpenTestInterval;
+
+ private int _failureCount;
+ private int _state = (int)CircuitState.Closed; // CircuitState as int for Interlocked operations
+ private long _lastFailureTimestamp;
+ private long _lastHalfOpenTestTimestamp;
+ private long _openedAtTimestamp;
+ private readonly object _lock = new();
+
+ ///
+ /// Gets the current state of the circuit breaker.
+ ///
+ public CircuitState State
+ {
+ get
+ {
+ var currentState = (CircuitState)Volatile.Read(ref _state);
+
+ // Check if we should transition from Open to HalfOpen
+ if (currentState == CircuitState.Open)
+ {
+ if (Stopwatch.GetElapsedTime(_openedAtTimestamp) >= _openDuration)
+ {
+ TryTransitionToHalfOpen();
+ return (CircuitState)Volatile.Read(ref _state);
+ }
+ }
+
+ return currentState;
+ }
+ }
+
+ ///
+ /// Gets the number of consecutive failures.
+ ///
+ public int FailureCount => Volatile.Read(ref _failureCount);
+
+ ///
+ /// Gets the last exception that caused a failure.
+ ///
+ public Exception? LastException { get; private set; }
+
+ ///
+ /// Gets whether the circuit allows requests through.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public bool AllowRequest()
+ {
+ var currentState = State; // This handles Open -> HalfOpen transition
+
+ switch (currentState)
+ {
+ case CircuitState.Closed:
+ return true;
+
+ case CircuitState.Open:
+ return false;
+
+ case CircuitState.HalfOpen:
+ // In half-open state, allow one test request periodically
+ return ShouldAllowHalfOpenTest();
+
+ default:
+ return false;
+ }
+ }
+
+ ///
+ /// Records a successful operation, potentially closing the circuit.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void RecordSuccess()
+ {
+ var currentState = (CircuitState)Volatile.Read(ref _state);
+
+ if (currentState == CircuitState.HalfOpen)
+ {
+ // Success in half-open state closes the circuit
+ Close();
+ }
+ else if (currentState == CircuitState.Closed)
+ {
+ // Reset failure count on success
+ Interlocked.Exchange(ref _failureCount, 0);
+ LastException = null;
+ }
+ }
+
+ ///
+ /// Records a failed operation, potentially opening the circuit.
+ /// Returns true if the circuit just transitioned to Open state.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public bool RecordFailure(Exception? exception = null)
+ {
+ LastException = exception;
+ _lastFailureTimestamp = Stopwatch.GetTimestamp();
+
+ var currentState = (CircuitState)Volatile.Read(ref _state);
+
+ if (currentState == CircuitState.HalfOpen)
+ {
+ // Failure in half-open state reopens the circuit
+ Open();
+ return true;
+ }
+
+ if (currentState == CircuitState.Closed)
+ {
+ var newCount = Interlocked.Increment(ref _failureCount);
+ if (newCount >= _failureThreshold)
+ {
+ Open();
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ ///
+ /// Manually opens the circuit.
+ ///
+ public void Open()
+ {
+ lock (_lock)
+ {
+ _state = (int)CircuitState.Open;
+ _openedAtTimestamp = Stopwatch.GetTimestamp();
+ }
+ }
+
+ ///
+ /// Manually closes the circuit and resets failure count.
+ ///
+ public void Close()
+ {
+ lock (_lock)
+ {
+ _state = (int)CircuitState.Closed;
+ _failureCount = 0;
+ LastException = null;
+ }
+ }
+
+ ///
+ /// Resets the circuit breaker to its initial state.
+ ///
+ public void Reset()
+ {
+ lock (_lock)
+ {
+ _state = (int)CircuitState.Closed;
+ _failureCount = 0;
+ LastException = null;
+ _openedAtTimestamp = 0;
+ _lastFailureTimestamp = 0;
+ _lastHalfOpenTestTimestamp = 0;
+ }
+ }
+
+ private void TryTransitionToHalfOpen()
+ {
+ lock (_lock)
+ {
+ if (_state == (int)CircuitState.Open && Stopwatch.GetElapsedTime(_openedAtTimestamp) >= _openDuration)
+ {
+ _state = (int)CircuitState.HalfOpen;
+ _lastHalfOpenTestTimestamp = 0; // Allow immediate test
+ }
+ }
+ }
+
+ private bool ShouldAllowHalfOpenTest()
+ {
+ lock (_lock)
+ {
+ if (_state != (int)CircuitState.HalfOpen)
+ {
+ return false;
+ }
+
+ var now = Stopwatch.GetTimestamp();
+ if (_lastHalfOpenTestTimestamp == 0 || Stopwatch.GetElapsedTime(_lastHalfOpenTestTimestamp, now) >= _halfOpenTestInterval)
+ {
+ _lastHalfOpenTestTimestamp = now;
+ return true;
+ }
+
+ return false;
+ }
+ }
+}
diff --git a/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/ObserverHealthTracker.cs b/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/ObserverHealthTracker.cs
new file mode 100644
index 0000000..8500412
--- /dev/null
+++ b/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/ObserverHealthTracker.cs
@@ -0,0 +1,459 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using Microsoft.AspNetCore.SignalR.Protocol;
+
+namespace ManagedCode.Orleans.SignalR.Core.SignalR.Observers;
+
+///
+/// Tracks observer health by monitoring delivery failures with circuit breaker support.
+/// Observers exceeding the failure threshold have their circuit opened to prevent cascade failures.
+/// Supports graceful expiration with message buffering for timing edge cases.
+///
+/// Note: This class is designed to be used within Orleans grains which provide single-threaded
+/// execution guarantees. No explicit locking is required.
+///
+public sealed class ObserverHealthTracker(
+ int failureThreshold,
+ TimeSpan failureWindow,
+ bool circuitBreakerEnabled = true,
+ TimeSpan? circuitOpenDuration = null,
+ TimeSpan? halfOpenTestInterval = null,
+ TimeSpan? gracePeriod = null,
+ int maxBufferedMessages = 50)
+{
+ private readonly Dictionary _healthStates = new(StringComparer.Ordinal);
+ // Allow 0 to disable health tracking (as documented)
+ private readonly int _failureThreshold = Math.Max(0, failureThreshold);
+ private readonly TimeSpan _failureWindow = failureWindow;
+ private readonly TimeSpan _circuitOpenDuration = circuitOpenDuration ?? TimeSpan.FromSeconds(30);
+ private readonly TimeSpan _halfOpenTestInterval = halfOpenTestInterval ?? TimeSpan.FromSeconds(5);
+ private readonly ExpiringObserverBuffer _gracePeriodBuffer = new(gracePeriod ?? TimeSpan.Zero, maxBufferedMessages);
+
+ ///
+ /// Gets whether health tracking is enabled.
+ ///
+ public bool IsEnabled => _failureThreshold > 0;
+
+ ///
+ /// Gets whether circuit breaker is enabled.
+ ///
+ public bool CircuitBreakerEnabled { get; } = circuitBreakerEnabled;
+
+ ///
+ /// Gets whether grace period buffering is enabled.
+ ///
+ public bool GracePeriodEnabled => _gracePeriodBuffer.IsEnabled;
+
+ ///
+ /// Records a successful delivery to an observer, resetting its failure count and closing circuit.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void RecordSuccess(string connectionId)
+ {
+ if (!IsEnabled)
+ {
+ return;
+ }
+
+ if (_healthStates.TryGetValue(connectionId, out var state))
+ {
+ state.RecordSuccess();
+ }
+ }
+
+ ///
+ /// Records a delivery failure for an observer.
+ /// Returns a result indicating whether the observer is dead or circuit is open.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public FailureResult RecordFailure(string connectionId, Exception? exception = null)
+ {
+ if (!IsEnabled)
+ {
+ return FailureResult.Healthy;
+ }
+
+ if (!_healthStates.TryGetValue(connectionId, out var state))
+ {
+ state = new ObserverHealthState(
+ _failureWindow,
+ CircuitBreakerEnabled,
+ _failureThreshold,
+ _circuitOpenDuration,
+ _halfOpenTestInterval);
+ _healthStates[connectionId] = state;
+ }
+
+ return state.RecordFailure(exception);
+ }
+
+ ///
+ /// Checks if an observer allows requests (healthy and circuit not open).
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public bool AllowRequest(string connectionId)
+ {
+ if (!IsEnabled)
+ {
+ return true;
+ }
+
+ if (!_healthStates.TryGetValue(connectionId, out var state))
+ {
+ return true;
+ }
+
+ return state.AllowRequest();
+ }
+
+ ///
+ /// Checks if an observer is healthy (not exceeding failure threshold).
+ /// Note: Use AllowRequest() for circuit breaker awareness.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public bool IsHealthy(string connectionId)
+ {
+ if (!IsEnabled)
+ {
+ return true;
+ }
+
+ if (!_healthStates.TryGetValue(connectionId, out var state))
+ {
+ return true;
+ }
+
+ return state.IsHealthy;
+ }
+
+ ///
+ /// Gets the circuit breaker state for a connection.
+ ///
+ public CircuitState GetCircuitState(string connectionId)
+ {
+ if (_healthStates.TryGetValue(connectionId, out var state))
+ {
+ return state.CircuitState;
+ }
+
+ return CircuitState.Closed;
+ }
+
+ ///
+ /// Gets the current failure count for an observer.
+ ///
+ public int GetFailureCount(string connectionId)
+ {
+ if (_healthStates.TryGetValue(connectionId, out var state))
+ {
+ return state.FailureCount;
+ }
+
+ return 0;
+ }
+
+ ///
+ /// Removes health tracking state for a connection.
+ ///
+ public void RemoveConnection(string connectionId)
+ {
+ _healthStates.Remove(connectionId);
+ _gracePeriodBuffer.Expire(connectionId);
+ }
+
+ ///
+ /// Starts a grace period for an observer, allowing message buffering until restored or expired.
+ /// Call this when an observer fails but might recover (e.g., heartbeat timeout).
+ ///
+ /// True if grace period started, false if already in grace period or disabled.
+ public bool StartGracePeriod(string connectionId)
+ {
+ return _gracePeriodBuffer.StartGracePeriod(connectionId);
+ }
+
+ ///
+ /// Checks if an observer is currently in the grace period.
+ ///
+ public bool IsInGracePeriod(string connectionId)
+ {
+ return _gracePeriodBuffer.IsInGracePeriod(connectionId);
+ }
+
+ ///
+ /// Buffers a message for an observer that is in the grace period.
+ ///
+ /// True if buffered, false if not in grace period or buffer full.
+ public bool BufferMessage(string connectionId, HubMessage message)
+ {
+ return _gracePeriodBuffer.BufferMessage(connectionId, message);
+ }
+
+ ///
+ /// Restores an observer from the grace period, returning any buffered messages.
+ /// Call this when an observer reconnects or sends a heartbeat during the grace period.
+ ///
+ public IReadOnlyList RestoreFromGracePeriod(string connectionId)
+ {
+ var messages = _gracePeriodBuffer.RestoreAndGetMessages(connectionId);
+
+ // Also record a success so the observer's failure history is cleared now that it has recovered
+ if (_healthStates.TryGetValue(connectionId, out var state))
+ {
+ state.RecordSuccess();
+ }
+
+ return messages;
+ }
+
+ ///
+ /// Gets the remaining grace period time for a connection.
+ ///
+ public TimeSpan? GetRemainingGracePeriod(string connectionId)
+ {
+ return _gracePeriodBuffer.GetRemainingGracePeriod(connectionId);
+ }
+
+ ///
+ /// Cleans up expired grace periods and returns the connection IDs that expired.
+ ///
+ public List CleanupExpiredGracePeriods()
+ {
+ return _gracePeriodBuffer.CleanupExpired();
+ }
+
+ ///
+ /// Clears all health tracking state.
+ ///
+ public void Clear()
+ {
+ _healthStates.Clear();
+ _gracePeriodBuffer.Clear();
+ }
+
+ ///
+ /// Gets all connection IDs whose observers have reached the failure threshold and been marked dead.
+ ///
+ public List GetDeadObservers()
+ {
+ var dead = new List();
+
+ foreach (var (connectionId, state) in _healthStates)
+ {
+ if (state.IsDead)
+ {
+ dead.Add(connectionId);
+ }
+ }
+
+ return dead;
+ }
+
+ ///
+ /// Gets all connection IDs with open circuits.
+ ///
+ public List GetOpenCircuits()
+ {
+ var open = new List();
+
+ foreach (var (connectionId, state) in _healthStates)
+ {
+ if (state.CircuitState == CircuitState.Open)
+ {
+ open.Add(connectionId);
+ }
+ }
+
+ return open;
+ }
+
+ ///
+ /// Gets statistics about observer health.
+ ///
+ public HealthStatistics GetStatistics()
+ {
+ var stats = new HealthStatistics();
+
+ foreach (var state in _healthStates.Values)
+ {
+ stats.TotalTracked++;
+
+ switch (state.CircuitState)
+ {
+ case CircuitState.Closed:
+ stats.ClosedCircuits++;
+ break;
+ case CircuitState.Open:
+ stats.OpenCircuits++;
+ break;
+ case CircuitState.HalfOpen:
+ stats.HalfOpenCircuits++;
+ break;
+ }
+
+ if (state.IsDead)
+ {
+ stats.DeadObservers++;
+ }
+ }
+
+ // Add grace period stats
+ var bufferStats = _gracePeriodBuffer.GetStatistics();
+ stats.ObserversInGracePeriod = bufferStats.ObserversInGracePeriod;
+ stats.TotalBufferedMessages = bufferStats.TotalBufferedMessages;
+
+ return stats;
+ }
+
+ private sealed class ObserverHealthState
+ {
+ private readonly TimeSpan _failureWindow;
+ private readonly bool _circuitBreakerEnabled;
+ private readonly int _failureThreshold;
+ private readonly List _failureTimestamps = new();
+ private readonly ObserverCircuitBreaker? _circuitBreaker;
+
+ public ObserverHealthState(
+ TimeSpan failureWindow,
+ bool circuitBreakerEnabled,
+ int failureThreshold,
+ TimeSpan circuitOpenDuration,
+ TimeSpan halfOpenTestInterval)
+ {
+ _failureWindow = failureWindow;
+ _circuitBreakerEnabled = circuitBreakerEnabled;
+ _failureThreshold = failureThreshold;
+
+ if (circuitBreakerEnabled)
+ {
+ _circuitBreaker = new ObserverCircuitBreaker(
+ failureThreshold,
+ circuitOpenDuration,
+ halfOpenTestInterval);
+ }
+ }
+
+ public int FailureCount
+ {
+ get
+ {
+ PruneOldFailures();
+ return _failureTimestamps.Count;
+ }
+ }
+
+ public bool IsHealthy => !IsDead && FailureCount < _failureThreshold;
+
+ public bool IsDead { get; private set; }
+
+ public CircuitState CircuitState => _circuitBreaker?.State ?? CircuitState.Closed;
+
+ public Exception? LastException { get; private set; }
+
+ public bool AllowRequest()
+ {
+ if (IsDead)
+ {
+ return false;
+ }
+
+ if (_circuitBreaker is not null)
+ {
+ return _circuitBreaker.AllowRequest();
+ }
+
+ return IsHealthy;
+ }
+
+ public FailureResult RecordFailure(Exception? exception)
+ {
+ PruneOldFailures();
+ _failureTimestamps.Add(Stopwatch.GetTimestamp());
+ LastException = exception;
+
+ var failureCount = _failureTimestamps.Count;
+ var circuitOpened = _circuitBreaker?.RecordFailure(exception) ?? false;
+
+ if (failureCount >= _failureThreshold)
+ {
+ IsDead = true;
+ return FailureResult.Dead;
+ }
+
+ if (circuitOpened)
+ {
+ return FailureResult.CircuitOpened;
+ }
+
+ return FailureResult.Healthy;
+ }
+
+ public void RecordSuccess()
+ {
+ _failureTimestamps.Clear();
+ LastException = null;
+ _circuitBreaker?.RecordSuccess();
+
+ // Recover from dead state only when a circuit breaker exists and is Closed after this success; without a breaker, only Reset() clears IsDead
+ if (IsDead && _circuitBreaker?.State == CircuitState.Closed)
+ {
+ IsDead = false;
+ }
+ }
+
+ public void Reset()
+ {
+ _failureTimestamps.Clear();
+ LastException = null;
+ IsDead = false;
+ _circuitBreaker?.Reset();
+ }
+
+ private void PruneOldFailures()
+ {
+ if (_failureTimestamps.Count == 0)
+ {
+ return;
+ }
+
+ var now = Stopwatch.GetTimestamp();
+ _failureTimestamps.RemoveAll(t => Stopwatch.GetElapsedTime(t, now) >= _failureWindow);
+ }
+ }
+}
+
+///
+/// Result of recording a failure.
+///
+public enum FailureResult
+{
+ ///
+ /// Observer is still healthy, failure recorded but below threshold.
+ ///
+ Healthy,
+
+ ///
+ /// Circuit breaker opened due to this failure.
+ ///
+ CircuitOpened,
+
+ ///
+ /// Observer reached the failure threshold and is marked dead.
+ ///
+ Dead
+}
+
+///
+/// Statistics about observer health tracking.
+///
+public sealed class HealthStatistics
+{
+ public int TotalTracked { get; set; }
+ public int ClosedCircuits { get; set; }
+ public int OpenCircuits { get; set; }
+ public int HalfOpenCircuits { get; set; }
+ public int DeadObservers { get; set; }
+ public int ObserversInGracePeriod { get; set; }
+ public int TotalBufferedMessages { get; set; }
+}
diff --git a/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/Subscription.cs b/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/Subscription.cs
index 468e480..a092f29 100644
--- a/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/Subscription.cs
+++ b/ManagedCode.Orleans.SignalR.Core/SignalR/Observers/Subscription.cs
@@ -6,17 +6,12 @@
namespace ManagedCode.Orleans.SignalR.Core.SignalR.Observers;
-public class Subscription(SignalRObserver observer) : IDisposable
+public sealed class Subscription(SignalRObserver observer) : IDisposable
{
private readonly HashSet _grains = new();
private readonly HashSet _heartbeatGrainIds = new();
private bool _disposed;
- ~Subscription()
- {
- Dispose();
- }
-
public ISignalRObserver Reference { get; private set; } = default!;
public string? HubKey { get; private set; }
@@ -56,6 +51,12 @@ public void RemoveGrain(IObserverConnectionManager grain)
_heartbeatGrainIds.Remove(((GrainReference)grain).GrainId);
}
+ public void ClearGrains()
+ {
+ _grains.Clear();
+ _heartbeatGrainIds.Clear();
+ }
+
public void SetReference(ISignalRObserver reference)
{
Reference = reference;
diff --git a/ManagedCode.Orleans.SignalR.Core/SignalR/OrleansHubLifetimeManager.cs b/ManagedCode.Orleans.SignalR.Core/SignalR/OrleansHubLifetimeManager.cs
index f2a1db4..468680d 100644
--- a/ManagedCode.Orleans.SignalR.Core/SignalR/OrleansHubLifetimeManager.cs
+++ b/ManagedCode.Orleans.SignalR.Core/SignalR/OrleansHubLifetimeManager.cs
@@ -17,6 +17,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Orleans;
+using Orleans.Runtime;
namespace ManagedCode.Orleans.SignalR.Core.SignalR;
@@ -51,31 +52,60 @@ public override async Task OnConnectedAsync(HubConnectionContext connection)
var usePartitions = _orleansSignalOptions.Value.ConnectionPartitionCount > 1;
var partitionId = 0;
- if (usePartitions)
+ // Retry logic for silo restart scenarios where grain directory has stale entries
+ const int maxRetries = 3;
+ for (var attempt = 1; attempt <= maxRetries; attempt++)
{
- var coordinatorGrain = NameHelperGenerator.GetConnectionCoordinatorGrain(_clusterClient);
- partitionId = await coordinatorGrain.GetPartitionForConnection(connection.ConnectionId);
- var partitionGrain = NameHelperGenerator.GetConnectionPartitionGrain(_clusterClient, partitionId);
- subscription.AddGrain(partitionGrain);
- await partitionGrain.AddConnection(connection.ConnectionId, subscription.Reference);
- await partitionGrain.Ping(subscription.Reference);
- }
- else
- {
- var connectionHolderGrain = NameHelperGenerator.GetConnectionHolderGrain(_clusterClient);
- subscription.AddGrain(connectionHolderGrain);
- await connectionHolderGrain.AddConnection(connection.ConnectionId, subscription.Reference);
- await connectionHolderGrain.Ping(subscription.Reference);
+ try
+ {
+ if (usePartitions)
+ {
+ var coordinatorGrain = NameHelperGenerator.GetConnectionCoordinatorGrain(_clusterClient);
+ partitionId = await coordinatorGrain.GetPartitionForConnection(connection.ConnectionId);
+ var partitionGrain = NameHelperGenerator.GetConnectionPartitionGrain(_clusterClient, partitionId);
+ subscription.AddGrain(partitionGrain);
+ await partitionGrain.AddConnection(connection.ConnectionId, subscription.Reference);
+ await partitionGrain.Ping(subscription.Reference);
+ }
+ else
+ {
+ var connectionHolderGrain = NameHelperGenerator.GetConnectionHolderGrain(_clusterClient);
+ subscription.AddGrain(connectionHolderGrain);
+ await connectionHolderGrain.AddConnection(connection.ConnectionId, subscription.Reference);
+ await connectionHolderGrain.Ping(subscription.Reference);
+ }
+
+ // Success - break out of retry loop
+ break;
+ }
+ catch (OrleansMessageRejectionException ex) when (attempt < maxRetries)
+ {
+ // Silo was restarted - grain directory has stale entries
+ // Wait briefly and retry as the new silo should activate fresh grains
+ _logger.LogWarning(ex,
+ "Grain call failed on attempt {Attempt}/{MaxRetries} for connection {ConnectionId}, retrying after delay",
+ attempt, maxRetries, connection.ConnectionId);
+ await Task.Delay(100 * attempt); // Linear backoff: 100ms, then 200ms
+ subscription.ClearGrains();
+ }
}
subscription.SetConnectionMetadata(hubKey, usePartitions, partitionId);
if (!string.IsNullOrEmpty(connection.UserIdentifier))
{
- var userGrain = NameHelperGenerator.GetSignalRUserGrain(_clusterClient, connection.UserIdentifier!);
- subscription.AddGrain(userGrain);
- await userGrain.AddConnection(connection.ConnectionId, subscription.Reference);
- _ = Task.Run(userGrain.RequestMessage);
+ try
+ {
+ var userGrain = NameHelperGenerator.GetSignalRUserGrain(_clusterClient, connection.UserIdentifier!);
+ subscription.AddGrain(userGrain);
+ await userGrain.AddConnection(connection.ConnectionId, subscription.Reference);
+ _ = Task.Run(userGrain.RequestMessage);
+ }
+ catch (OrleansMessageRejectionException ex)
+ {
+ _logger.LogWarning(ex, "Failed to register user grain for connection {ConnectionId}", connection.ConnectionId);
+ // Continue - connection can still work without user-specific messaging
+ }
}
await UpdateConnectionHeartbeatAsync(connection.ConnectionId, subscription);
@@ -89,30 +119,65 @@ public override async Task OnDisconnectedAsync(HubConnectionContext connection)
if (_orleansSignalOptions.Value.KeepEachConnectionAlive)
{
- var hubKey = NameHelperGenerator.CleanString(typeof(THub).FullName!);
- var heartbeatGrain = NameHelperGenerator.GetConnectionHeartbeatGrain(_clusterClient, hubKey, connection.ConnectionId);
- await heartbeatGrain.Stop();
+ try
+ {
+ var hubKey = NameHelperGenerator.CleanString(typeof(THub).FullName!);
+ var heartbeatGrain = NameHelperGenerator.GetConnectionHeartbeatGrain(_clusterClient, hubKey, connection.ConnectionId);
+ await heartbeatGrain.Stop();
+ }
+ catch (OrleansMessageRejectionException ex)
+ {
+ // Silo was restarted - heartbeat grain no longer exists
+ _logger.LogDebug(ex, "Heartbeat grain unavailable during disconnect for {ConnectionId}", connection.ConnectionId);
+ }
}
if (subscription is not null)
{
using (subscription)
{
- var removalTasks = subscription.Grains
- .Select(grain => grain.RemoveConnection(connection.ConnectionId, subscription.Reference))
- .ToArray();
+ try
+ {
+ var removalTasks = subscription.Grains
+ .Select(grain => SafeRemoveConnectionAsync(grain, connection.ConnectionId, subscription.Reference))
+ .ToArray();
- if (removalTasks.Length > 0)
+ if (removalTasks.Length > 0)
+ {
+ await Task.WhenAll(removalTasks);
+ }
+ }
+ catch (Exception ex)
{
- await Task.WhenAll(removalTasks);
+ _logger.LogDebug(ex, "Failed to remove connections from grains during disconnect for {ConnectionId}", connection.ConnectionId);
}
}
connection.Features.Set(null);
}
- var coordinator = NameHelperGenerator.GetConnectionCoordinatorGrain(_clusterClient);
- await coordinator.NotifyConnectionRemoved(connection.ConnectionId);
+ try
+ {
+ var coordinator = NameHelperGenerator.GetConnectionCoordinatorGrain(_clusterClient);
+ await coordinator.NotifyConnectionRemoved(connection.ConnectionId);
+ }
+ catch (OrleansMessageRejectionException ex)
+ {
+ // Silo was restarted - coordinator grain will be fresh anyway
+ _logger.LogDebug(ex, "Coordinator grain unavailable during disconnect for {ConnectionId}", connection.ConnectionId);
+ }
+ }
+
+ private static async Task SafeRemoveConnectionAsync(IObserverConnectionManager grain, string connectionId, ISignalRObserver reference)
+ {
+ try
+ {
+ await grain.RemoveConnection(connectionId, reference);
+ }
+ catch (OrleansMessageRejectionException)
+ {
+ // Grain was on old silo - nothing to clean up
+ }
}
public override Task SendAllAsync(string methodName, object?[] args, CancellationToken cancellationToken = new())
@@ -190,35 +255,34 @@ public override Task SendGroupAsync(string groupName, string methodName, object?
}
}
- public override Task SendGroupsAsync(IReadOnlyList groupNames, string methodName, object?[] args,
+ public override async Task SendGroupsAsync(IReadOnlyList groupNames, string methodName, object?[] args,
CancellationToken cancellationToken = new())
{
var message = new InvocationMessage(methodName, args);
if (_orleansSignalOptions.Value.GroupPartitionCount > 1)
{
- return Task.Run(() => NameHelperGenerator.GetGroupCoordinatorGrain(_clusterClient)
+ await Task.Run(() => NameHelperGenerator.GetGroupCoordinatorGrain(_clusterClient)
.SendToGroups(groupNames.ToArray(), message), cancellationToken);
+ return;
}
- // For potentially many groups, use fire-and-forget to avoid memory issues
- _ = Task.Run(async () =>
+ // Send to all groups in parallel for better performance
+ var tasks = new List(groupNames.Count);
+ foreach (var groupName in groupNames)
{
- foreach (var groupName in groupNames)
- {
- try
- {
- var groupGrain = NameHelperGenerator.GetSignalRGroupGrain(_clusterClient, groupName);
- await groupGrain.SendToGroup(message).ConfigureAwait(false);
- }
- catch (Exception ex)
- {
- _logger.LogError(ex, "Failed to send to group {GroupName}", groupName);
- }
- }
- }, cancellationToken);
+ var groupGrain = NameHelperGenerator.GetSignalRGroupGrain(_clusterClient, groupName);
+ tasks.Add(Task.Run(() => groupGrain.SendToGroup(message), cancellationToken));
+ }
- return Task.CompletedTask;
+ try
+ {
+ await Task.WhenAll(tasks);
+ }
+ catch (Exception ex)
+ {
+ _logger.LogError(ex, "Failed to send to one or more groups");
+ }
}
public override Task SendGroupExceptAsync(string groupName, string methodName, object?[] args,
@@ -244,29 +308,27 @@ public override Task SendUserAsync(string userId, string methodName, object?[] a
return Task.Run(() => NameHelperGenerator.GetSignalRUserGrain(_clusterClient, userId).SendToUser(message), cancellationToken);
}
- public override Task SendUsersAsync(IReadOnlyList userIds, string methodName, object?[] args,
+ public override async Task SendUsersAsync(IReadOnlyList userIds, string methodName, object?[] args,
CancellationToken cancellationToken = new())
{
var message = new InvocationMessage(methodName, args);
- // For potentially many users, use fire-and-forget to avoid memory issues
- _ = Task.Run(async () =>
+ // Send to all users in parallel for better performance
+ var tasks = new List(userIds.Count);
+ foreach (var userId in userIds)
{
- foreach (var userId in userIds)
- {
- try
- {
- var userGrain = NameHelperGenerator.GetSignalRUserGrain(_clusterClient, userId);
- await userGrain.SendToUser(message).ConfigureAwait(false);
- }
- catch (Exception ex)
- {
- _logger.LogError(ex, "Failed to send to user {UserId}", userId);
- }
- }
- }, cancellationToken);
+ var userGrain = NameHelperGenerator.GetSignalRUserGrain(_clusterClient, userId);
+ tasks.Add(Task.Run(() => userGrain.SendToUser(message), cancellationToken));
+ }
- return Task.CompletedTask;
+ try
+ {
+ await Task.WhenAll(tasks);
+ }
+ catch (Exception ex)
+ {
+ _logger.LogError(ex, "Failed to send to one or more users");
+ }
}
public override async Task AddToGroupAsync(string connectionId, string groupName,
@@ -474,15 +536,20 @@ await Task.Run(() => NameHelperGenerator.GetInvocationGrain(_clusterClient
public override bool TryGetReturnType(string invocationId, [NotNullWhen(true)] out Type? type)
{
- var returnType = NameHelperGenerator.GetInvocationGrain(_clusterClient, invocationId).TryGetReturnType();
+ var returnTypeTask = NameHelperGenerator.GetInvocationGrain(_clusterClient, invocationId).TryGetReturnType();
var timeSpan = TimeIntervalHelper.GetClientTimeoutInterval(_orleansSignalOptions, _globalHubOptions, _hubOptions);
- Task.WaitAny(returnType, Task.Delay(timeSpan * 0.8));
+ var timeout = TimeSpan.FromMilliseconds(timeSpan.TotalMilliseconds * 0.8);
+
+ // Block with a timeout instead of awaiting, because the base class method is synchronous.
+ // Note: Wait(timeout) blocks the calling thread and rethrows the exception of a faulted task.
+ var completed = returnTypeTask.Wait(timeout);
- if (returnType.IsCompleted)
+ if (completed && returnTypeTask.IsCompletedSuccessfully)
{
- type = returnType.Result.GetReturnType();
- return returnType.Result.Result;
+ var result = returnTypeTask.Result;
+ type = result.GetReturnType();
+ return result.Result;
}
type = null;
diff --git a/ManagedCode.Orleans.SignalR.Server/Extensions/OrleansDependencyInjectionExtensions.cs b/ManagedCode.Orleans.SignalR.Server/Extensions/OrleansDependencyInjectionExtensions.cs
index 68d5d61..3426f39 100644
--- a/ManagedCode.Orleans.SignalR.Server/Extensions/OrleansDependencyInjectionExtensions.cs
+++ b/ManagedCode.Orleans.SignalR.Server/Extensions/OrleansDependencyInjectionExtensions.cs
@@ -1,11 +1,9 @@
using System;
-using System.Reflection;
using ManagedCode.Orleans.SignalR.Core.Config;
using ManagedCode.Orleans.SignalR.Core.HubContext;
using ManagedCode.Orleans.SignalR.Core.SignalR;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.DependencyInjection;
-using Orleans;
using Orleans.Configuration;
using Orleans.Hosting;
diff --git a/ManagedCode.Orleans.SignalR.Server/Helpers/PersistentStateExtensions.cs b/ManagedCode.Orleans.SignalR.Server/Helpers/PersistentStateExtensions.cs
index 6b1eff1..704f4aa 100644
--- a/ManagedCode.Orleans.SignalR.Server/Helpers/PersistentStateExtensions.cs
+++ b/ManagedCode.Orleans.SignalR.Server/Helpers/PersistentStateExtensions.cs
@@ -1,4 +1,5 @@
using System;
+using System.Threading;
using System.Threading.Tasks;
using Orleans.Runtime;
using Orleans.Storage;
@@ -7,12 +8,19 @@ namespace ManagedCode.Orleans.SignalR.Server.Helpers;
internal static class PersistentStateExtensions
{
+ private const int MaxRetries = 5;
+
+ ///
+ /// Safely writes state with retry on ETag conflicts.
+ /// Handles both InconsistentStateException (persistent storage) and
+ /// MemoryStorageEtagMismatchException (memory storage) for development scenarios.
+ ///
public static async Task WriteStateSafeAsync(this IPersistentState state, Func applyChanges)
{
ArgumentNullException.ThrowIfNull(state);
ArgumentNullException.ThrowIfNull(applyChanges);
- while (true)
+ for (var retry = 0; retry < MaxRetries; retry++)
{
try
{
@@ -26,8 +34,85 @@ public static async Task WriteStateSafeAsync(this IPersistentState
}
catch (InconsistentStateException)
{
+ // Persistent storage ETag conflict
+ await state.ReadStateAsync();
+ }
+ catch (Exception ex) when (IsEtagMismatch(ex))
+ {
+ // Memory storage ETag conflict (development/testing)
await state.ReadStateAsync();
}
}
+
+ // Final attempt without catching - let it throw if still failing
+ if (!applyChanges(state.State))
+ {
+ return false;
+ }
+ await state.WriteStateAsync();
+ return true;
+ }
+
+ ///
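+ /// Writes state with retry on ETag conflicts (no-change-detection version).
+ /// Use this when state has already been modified and just needs to be persisted; note that each retry calls ReadStateAsync first, which may replace those unsaved modifications with the stored state.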
+ ///
+ public static async Task WriteStateSafeAsync(this IPersistentState state, CancellationToken cancellationToken = default)
+ {
+ ArgumentNullException.ThrowIfNull(state);
+
+ for (var retry = 0; retry < MaxRetries; retry++)
+ {
+ try
+ {
+ await state.WriteStateAsync(cancellationToken);
+ return;
+ }
+ catch (InconsistentStateException)
+ {
+ await state.ReadStateAsync(cancellationToken);
+ }
+ catch (Exception ex) when (IsEtagMismatch(ex))
+ {
+ await state.ReadStateAsync(cancellationToken);
+ }
+ }
+
+ // Final attempt - let it throw if still failing
+ await state.WriteStateAsync(cancellationToken);
+ }
+
+ ///
+ /// Safely clears state with retry on ETag conflicts.
+ ///
+ public static async Task ClearStateSafeAsync(this IPersistentState state, CancellationToken cancellationToken = default)
+ {
+ ArgumentNullException.ThrowIfNull(state);
+
+ for (var retry = 0; retry < MaxRetries; retry++)
+ {
+ try
+ {
+ await state.ClearStateAsync(cancellationToken);
+ return;
+ }
+ catch (InconsistentStateException)
+ {
+ await state.ReadStateAsync(cancellationToken);
+ }
+ catch (Exception ex) when (IsEtagMismatch(ex))
+ {
+ await state.ReadStateAsync(cancellationToken);
+ }
+ }
+
+ // Final attempt - let it throw if still failing
+ await state.ClearStateAsync(cancellationToken);
+ }
+
+ private static bool IsEtagMismatch(Exception ex)
+ {
+ // Check for MemoryStorageEtagMismatchException without taking a hard dependency
+ return ex.GetType().Name == "MemoryStorageEtagMismatchException";
}
}
diff --git a/ManagedCode.Orleans.SignalR.Server/Properties/launchSettings.json b/ManagedCode.Orleans.SignalR.Server/Properties/launchSettings.json
new file mode 100644
index 0000000..c7d31de
--- /dev/null
+++ b/ManagedCode.Orleans.SignalR.Server/Properties/launchSettings.json
@@ -0,0 +1,12 @@
+{
+ "profiles": {
+ "ManagedCode.Orleans.SignalR.Server": {
+ "commandName": "Project",
+ "launchBrowser": true,
+ "environmentVariables": {
+ "ASPNETCORE_ENVIRONMENT": "Development"
+ },
+ "applicationUrl": "https://localhost:56459;http://localhost:56461"
+ }
+ }
+}
\ No newline at end of file
diff --git a/ManagedCode.Orleans.SignalR.Server/SignalRConnectionCoordinatorGrain.cs b/ManagedCode.Orleans.SignalR.Server/SignalRConnectionCoordinatorGrain.cs
index 4fadee0..7aba4d5 100644
--- a/ManagedCode.Orleans.SignalR.Server/SignalRConnectionCoordinatorGrain.cs
+++ b/ManagedCode.Orleans.SignalR.Server/SignalRConnectionCoordinatorGrain.cs
@@ -1,7 +1,9 @@
using System;
+using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
-using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using ManagedCode.Orleans.SignalR.Core.Config;
@@ -9,12 +11,14 @@
using ManagedCode.Orleans.SignalR.Core.Interfaces;
using ManagedCode.Orleans.SignalR.Core.Models;
using ManagedCode.Orleans.SignalR.Core.SignalR;
+using ManagedCode.Orleans.SignalR.Server.Helpers;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Orleans;
using Orleans.Concurrency;
using Orleans.Runtime;
+using static ManagedCode.Orleans.SignalR.Core.Helpers.CollectionPool;
namespace ManagedCode.Orleans.SignalR.Server;
@@ -25,10 +29,12 @@ public sealed class SignalRConnectionCoordinatorGrain : Grain, ISignalRConnectio
private readonly ILogger _logger;
private readonly IOptions _options;
private readonly IPersistentState _state;
- private readonly Dictionary _connectionPartitions;
+ private readonly Dictionary _connectionPartitions;
+ private readonly HashSet _activePartitions;
private readonly int _connectionsPerPartitionHint;
private uint _basePartitionCount;
private int _currentPartitionCount;
+ private int _partitionEpoch;
public SignalRConnectionCoordinatorGrain(
ILogger logger,
@@ -39,7 +45,8 @@ public SignalRConnectionCoordinatorGrain(
_logger = logger;
_options = options;
_state = state;
- _connectionPartitions = new Dictionary(StringComparer.Ordinal);
+ _connectionPartitions = new Dictionary(StringComparer.Ordinal);
+ _activePartitions = new HashSet();
_connectionsPerPartitionHint = Math.Max(1, _options.Value.ConnectionsPerPartitionHint);
}
@@ -47,30 +54,46 @@ public override async Task OnActivateAsync(CancellationToken cancellationToken)
{
await _state.ReadStateAsync(cancellationToken);
_state.State ??= new ConnectionCoordinatorState();
+
var partitions = EnsureOrdinalDictionary(_state.State.ConnectionPartitions);
_connectionPartitions.Clear();
+ _activePartitions.Clear();
+
foreach (var kvp in partitions)
{
_connectionPartitions[kvp.Key] = kvp.Value;
+ _activePartitions.Add(kvp.Value.PartitionId);
}
+
_state.State.ConnectionPartitions = _connectionPartitions;
_basePartitionCount = Math.Max(1u, _options.Value.ConnectionPartitionCount);
_currentPartitionCount = _state.State.CurrentPartitionCount;
+ _partitionEpoch = Math.Max(1, _state.State.PartitionEpoch);
+
+ // Ensure partition count is at least base, but preserve higher counts to maintain routing consistency
if (_currentPartitionCount <= 0 || _currentPartitionCount < _basePartitionCount)
{
_currentPartitionCount = (int)_basePartitionCount;
_state.State.CurrentPartitionCount = _currentPartitionCount;
}
+ // Only reset to base if truly empty AND partition count was scaled up
+ // This preserves routing consistency for connections that might reconnect
else if (_connectionPartitions.Count == 0 && _currentPartitionCount > _basePartitionCount)
{
_currentPartitionCount = (int)_basePartitionCount;
_state.State.CurrentPartitionCount = _currentPartitionCount;
+ // Reset epoch when scaling back to base with no connections
+ _partitionEpoch = 1;
+ _state.State.PartitionEpoch = _partitionEpoch;
}
_logger.LogInformation(
- "Connection coordinator activated with base partition count {PartitionCount} and hint {ConnectionsPerPartition}",
+ "Connection coordinator activated with base partition count {PartitionCount}, current {CurrentPartitionCount}, epoch {Epoch}, hint {ConnectionsPerPartition}, tracked connections {TrackedConnections}",
_basePartitionCount,
- _connectionsPerPartitionHint);
+ _currentPartitionCount,
+ _partitionEpoch,
+ _connectionsPerPartitionHint,
+ _connectionPartitions.Count);
await base.OnActivateAsync(cancellationToken);
}
@@ -79,10 +102,10 @@ public Task GetPartitionCount()
return Task.FromResult(_currentPartitionCount);
}
- public Task GetPartitionForConnection(string connectionId)
+ public async Task GetPartitionForConnection(string connectionId)
{
var stopwatch = Stopwatch.StartNew();
- var partition = GetOrAssignPartition(connectionId);
+ var (partition, wasNew, wasReassigned) = GetOrAssignPartitionWithEpoch(connectionId);
stopwatch.Stop();
if (stopwatch.Elapsed > TimeSpan.FromMilliseconds(500))
@@ -94,157 +117,288 @@ public Task GetPartitionForConnection(string connectionId)
_connectionPartitions.Count);
}
- return Task.FromResult(partition);
+ // Persist state if a new partition was assigned or reassigned due to epoch change
+ // Use safe write with retry for both persistent and memory storage ETag conflicts
+ if (wasNew || wasReassigned)
+ {
+ await _state.WriteStateSafeAsync(state =>
+ {
+ // Re-sync local dictionaries to state on each retry (ReadStateAsync creates new state object)
+ state.ConnectionPartitions = _connectionPartitions;
+ state.CurrentPartitionCount = _currentPartitionCount;
+ state.PartitionEpoch = _partitionEpoch;
+ return true;
+ });
+ }
+
+ return partition;
}
public async Task SendToAll(HubMessage message)
{
- var partitions = GetActivePartitions();
- if (partitions.Count == 0)
+ var partitionCount = _activePartitions.Count;
+ if (partitionCount == 0)
{
return;
}
- var distribution = _connectionPartitions
- .GroupBy(static kvp => kvp.Value)
- .Select(group => $"{group.Key}:{group.Count()}")
- .ToArray();
- _logger.LogInformation("Sending to all partitions {Distribution}", string.Join(",", distribution));
+ // Use ArrayPool for task collection to reduce allocations
+ var tasks = ArrayPool.Shared.Rent(partitionCount);
+ try
+ {
+ var hubKey = this.GetPrimaryKeyString();
+ var taskIndex = 0;
+
+ foreach (var partitionId in _activePartitions)
+ {
+ var partitionGrain = NameHelperGenerator.GetConnectionPartitionGrain(GrainFactory, hubKey, partitionId);
+ tasks[taskIndex++] = partitionGrain.SendToPartition(message);
+ }
- var tasks = new List(partitions.Count);
- foreach (var partitionId in partitions)
+ await Task.WhenAll(tasks.AsSpan(0, taskIndex));
+ }
+ finally
{
- var partitionGrain = NameHelperGenerator.GetConnectionPartitionGrain(GrainFactory, this.GetPrimaryKeyString(), partitionId);
- tasks.Add(partitionGrain.SendToPartition(message));
+ ArrayPool.Shared.Return(tasks, clearArray: true);
}
-
- await Task.WhenAll(tasks);
}
public async Task SendToAllExcept(HubMessage message, string[] excludedConnectionIds)
{
- var excludedByPartition = new Dictionary>();
- foreach (var connectionId in excludedConnectionIds)
+ var partitionCount = _activePartitions.Count;
+ if (partitionCount == 0)
{
- var partition = GetOrAssignPartition(connectionId);
- if (!excludedByPartition.TryGetValue(partition, out var list))
- {
- list = new List();
- excludedByPartition[partition] = list;
- }
- list.Add(connectionId);
+ return;
}
- var partitions = GetActivePartitions();
- if (partitions.Count == 0)
+ // Group excluded connections by partition using CollectionsMarshal for efficient access
+ var excludedByPartition = CollectionPool.GetIntListDictionary();
+ try
{
- return;
- }
+ foreach (var connectionId in excludedConnectionIds)
+ {
+ var (partition, _, _) = GetOrAssignPartitionWithEpoch(connectionId);
+ ref var list = ref CollectionsMarshal.GetValueRefOrAddDefault(excludedByPartition, partition, out var exists);
+ if (!exists)
+ {
+ list = CollectionPool.GetStringList();
+ }
+ list!.Add(connectionId);
+ }
- var tasks = new List(partitions.Count);
- foreach (var partitionId in partitions)
+ // Use ArrayPool for task collection
+ var tasks = ArrayPool.Shared.Rent(partitionCount);
+ try
+ {
+ var hubKey = this.GetPrimaryKeyString();
+ var taskIndex = 0;
+
+ foreach (var partitionId in _activePartitions)
+ {
+ var partitionGrain = NameHelperGenerator.GetConnectionPartitionGrain(GrainFactory, hubKey, partitionId);
+ var excluded = excludedByPartition.TryGetValue(partitionId, out var list)
+ ? CollectionsMarshal.AsSpan(list).ToArray()
+ : [];
+ tasks[taskIndex++] = partitionGrain.SendToPartitionExcept(message, excluded);
+ }
+
+ await Task.WhenAll(tasks.AsSpan(0, taskIndex));
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(tasks, clearArray: true);
+ }
+ }
+ finally
{
- var partitionGrain = NameHelperGenerator.GetConnectionPartitionGrain(GrainFactory, this.GetPrimaryKeyString(), partitionId);
- var excluded = excludedByPartition.TryGetValue(partitionId, out var list)
- ? list.ToArray()
- : Array.Empty();
- tasks.Add(partitionGrain.SendToPartitionExcept(message, excluded));
+ CollectionPool.Return(excludedByPartition);
}
-
- await Task.WhenAll(tasks);
}
public async Task SendToConnection(HubMessage message, string connectionId)
{
- var partition = GetOrAssignPartition(connectionId);
+ var (partition, _, _) = GetOrAssignPartitionWithEpoch(connectionId);
var partitionGrain = NameHelperGenerator.GetConnectionPartitionGrain(GrainFactory, this.GetPrimaryKeyString(), partition);
return await partitionGrain.SendToConnection(message, connectionId);
}
public async Task SendToConnections(HubMessage message, string[] connectionIds)
{
- var connectionsByPartition = new Dictionary>();
- foreach (var connectionId in connectionIds)
+ if (connectionIds.Length == 0)
{
- var partition = GetOrAssignPartition(connectionId);
- if (!connectionsByPartition.TryGetValue(partition, out var list))
- {
- list = new List();
- connectionsByPartition[partition] = list;
- }
- list.Add(connectionId);
+ return;
}
- if (connectionsByPartition.Count == 0)
+ // Group connections by partition using CollectionsMarshal for efficient access
+ var connectionsByPartition = GetIntListDictionary();
+ try
{
- return;
- }
+ foreach (var connectionId in connectionIds)
+ {
+ var (partition, _, _) = GetOrAssignPartitionWithEpoch(connectionId);
+ ref var list = ref CollectionsMarshal.GetValueRefOrAddDefault(connectionsByPartition, partition, out var exists);
+ if (!exists)
+ {
+ list = GetStringList();
+ }
+ list!.Add(connectionId);
+ }
- var tasks = new List(connectionsByPartition.Count);
- foreach (var kvp in connectionsByPartition)
+ if (connectionsByPartition.Count == 0)
+ {
+ return;
+ }
+
+ // Use ArrayPool for task collection
+ var tasks = ArrayPool.Shared.Rent(connectionsByPartition.Count);
+ try
+ {
+ var hubKey = this.GetPrimaryKeyString();
+ var taskIndex = 0;
+
+ foreach (var kvp in connectionsByPartition)
+ {
+ var partitionGrain = NameHelperGenerator.GetConnectionPartitionGrain(GrainFactory, hubKey, kvp.Key);
+ tasks[taskIndex++] = partitionGrain.SendToConnections(message, CollectionsMarshal.AsSpan(kvp.Value).ToArray());
+ }
+
+ await Task.WhenAll(tasks.AsSpan(0, taskIndex));
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(tasks, clearArray: true);
+ }
+ }
+ finally
{
- var partitionGrain = NameHelperGenerator.GetConnectionPartitionGrain(GrainFactory, this.GetPrimaryKeyString(), kvp.Key);
- tasks.Add(partitionGrain.SendToConnections(message, kvp.Value.ToArray()));
+ Return(connectionsByPartition);
}
-
- await Task.WhenAll(tasks);
}
- public Task NotifyConnectionRemoved(string connectionId)
+ public async Task NotifyConnectionRemoved(string connectionId)
{
- if (_connectionPartitions.Remove(connectionId))
+ if (_connectionPartitions.Remove(connectionId, out var removedAssignment))
{
- _logger.LogDebug("Removed connection {ConnectionId} from coordinator mapping.", connectionId);
+ var removedPartition = removedAssignment.PartitionId;
+ _logger.LogDebug("Removed connection {ConnectionId} from coordinator mapping (partition {Partition}, epoch {Epoch}).",
+ connectionId, removedPartition, removedAssignment.Epoch);
+
+ // Check if any other connections are using this partition
+ var partitionStillActive = false;
+ foreach (var assignment in _connectionPartitions.Values)
+ {
+ if (assignment.PartitionId == removedPartition)
+ {
+ partitionStillActive = true;
+ break;
+ }
+ }
+
+ if (!partitionStillActive)
+ {
+ _activePartitions.Remove(removedPartition);
+ }
+
if (_connectionPartitions.Count == 0 && _currentPartitionCount != _basePartitionCount)
{
- _logger.LogDebug("Resetting partition count to base value {PartitionCount} as no active connections remain.", _basePartitionCount);
+ _logger.LogDebug("Resetting partition count to base value {PartitionCount} and epoch to 1 as no active connections remain.", _basePartitionCount);
_currentPartitionCount = (int)_basePartitionCount;
_state.State.CurrentPartitionCount = _currentPartitionCount;
+ _partitionEpoch = 1;
+ _state.State.PartitionEpoch = _partitionEpoch;
+ _activePartitions.Clear();
}
- }
- return Task.CompletedTask;
+ // Persist state changes with safe retry for ETag conflicts
+ await _state.WriteStateSafeAsync(state =>
+ {
+ // Re-sync local dictionaries to state on each retry (ReadStateAsync creates new state object)
+ state.ConnectionPartitions = _connectionPartitions;
+ state.CurrentPartitionCount = _currentPartitionCount;
+ state.PartitionEpoch = _partitionEpoch;
+ return true;
+ });
+ }
}
public override async Task OnDeactivateAsync(DeactivationReason reason, CancellationToken cancellationToken)
{
_state.State.CurrentPartitionCount = _currentPartitionCount;
- if (_connectionPartitions.Count == 0)
+ _state.State.PartitionEpoch = _partitionEpoch;
+
+ try
{
- await _state.ClearStateAsync(cancellationToken);
+ if (_connectionPartitions.Count == 0)
+ {
+ await _state.ClearStateSafeAsync(cancellationToken);
+ }
+ else
+ {
+ await _state.WriteStateSafeAsync(cancellationToken);
+ }
}
- else
+ catch (OrleansMessageRejectionException ex)
{
- await _state.WriteStateAsync(cancellationToken);
+ // Storage grains may be unavailable during silo shutdown
+ _logger.LogDebug(ex, "Unable to persist state during deactivation for coordinator {HubKey} - storage unavailable.", this.GetPrimaryKeyString());
}
}
- private List GetActivePartitions()
+ /// <summary>
+ /// Gets or assigns a partition for a connection, handling epoch-based reassignment.
+ /// Returns (partitionId, wasNew, wasReassigned).
+ /// </summary>
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private (int PartitionId, bool WasNew, bool WasReassigned) GetOrAssignPartitionWithEpoch(string connectionId)
{
- if (_connectionPartitions.Count == 0)
+ if (_connectionPartitions.TryGetValue(connectionId, out var existingAssignment))
{
- return Enumerable.Range(0, _currentPartitionCount).ToList();
- }
+ // Check if assignment is from current epoch
+ if (existingAssignment.Epoch == _partitionEpoch)
+ {
+ return (existingAssignment.PartitionId, false, false);
+ }
- return _connectionPartitions.Values
- .Distinct()
- .OrderBy(static partitionId => partitionId)
- .ToList();
- }
+ // Stale epoch - check if partition would be different with current partition count
+ var newPartition = PartitionHelper.GetPartitionId(connectionId, (uint)_currentPartitionCount);
- private int GetOrAssignPartition(string connectionId)
- {
- if (_connectionPartitions.TryGetValue(connectionId, out var partition))
- {
- return partition;
+ if (newPartition == existingAssignment.PartitionId)
+ {
+ // Same partition, just update epoch
+ var updatedAssignment = PartitionAssignment.Create(existingAssignment.PartitionId, _partitionEpoch);
+ _connectionPartitions[connectionId] = updatedAssignment;
+ _logger.LogDebug(
+ "Updated connection {ConnectionId} epoch from {OldEpoch} to {NewEpoch} (partition {Partition} unchanged)",
+ connectionId, existingAssignment.Epoch, _partitionEpoch, existingAssignment.PartitionId);
+ return (existingAssignment.PartitionId, false, true);
+ }
+
+ // Partition changed due to scaling - reassign
+ // Note: The old partition may still have this connection until cleanup
+ var reassignment = PartitionAssignment.Create(newPartition, _partitionEpoch);
+ _connectionPartitions[connectionId] = reassignment;
+ _activePartitions.Add(newPartition);
+
+ _logger.LogInformation(
+ "Reassigned connection {ConnectionId} from partition {OldPartition} (epoch {OldEpoch}) to partition {NewPartition} (epoch {NewEpoch}) due to scaling",
+ connectionId, existingAssignment.PartitionId, existingAssignment.Epoch, newPartition, _partitionEpoch);
+
+ return (newPartition, false, true);
}
+ // New connection - assign to partition with current epoch
var partitionCount = EnsurePartitionCapacity(_connectionPartitions.Count + 1);
- partition = PartitionHelper.GetPartitionId(connectionId, (uint)partitionCount);
- _connectionPartitions[connectionId] = partition;
+ var partition = PartitionHelper.GetPartitionId(connectionId, (uint)partitionCount);
+ var assignment = PartitionAssignment.Create(partition, _partitionEpoch);
- _logger.LogDebug("Assigned connection {ConnectionId} to partition {Partition} (partitionCount={PartitionCount})", connectionId, partition, partitionCount);
- return partition;
+ _connectionPartitions[connectionId] = assignment;
+ _activePartitions.Add(partition);
+
+ _logger.LogDebug(
+ "Assigned connection {ConnectionId} to partition {Partition} (epoch {Epoch}, partitionCount={PartitionCount})",
+ connectionId, partition, _partitionEpoch, partitionCount);
+
+ return (partition, true, false);
}
private int EnsurePartitionCapacity(int prospectiveConnections)
@@ -255,22 +409,27 @@ private int EnsurePartitionCapacity(int prospectiveConnections)
if (desired > _currentPartitionCount)
{
_logger.LogInformation(
- "Increasing connection partition count from {OldPartitionCount} to {NewPartitionCount} for {ConnectionCount} tracked connections.",
+ "Increasing connection partition count from {OldPartitionCount} to {NewPartitionCount} (epoch {OldEpoch} -> {NewEpoch}) for {ConnectionCount} tracked connections.",
_currentPartitionCount,
desired,
+ _partitionEpoch,
+ _partitionEpoch + 1,
prospectiveConnections);
+
_currentPartitionCount = desired;
+ _partitionEpoch++;
_state.State.CurrentPartitionCount = _currentPartitionCount;
+ _state.State.PartitionEpoch = _partitionEpoch;
}
return _currentPartitionCount;
}
- private static Dictionary EnsureOrdinalDictionary(Dictionary? dictionary)
+ private static Dictionary EnsureOrdinalDictionary(Dictionary? dictionary)
{
if (dictionary is null)
{
- return new Dictionary(StringComparer.Ordinal);
+ return new Dictionary(StringComparer.Ordinal);
}
if (dictionary.Comparer == StringComparer.Ordinal)
@@ -278,6 +437,6 @@ private static Dictionary EnsureOrdinalDictionary(Dictionary(dictionary, StringComparer.Ordinal);
+ return new Dictionary(dictionary, StringComparer.Ordinal);
}
}
diff --git a/ManagedCode.Orleans.SignalR.Server/SignalRConnectionHeartbeatGrain.cs b/ManagedCode.Orleans.SignalR.Server/SignalRConnectionHeartbeatGrain.cs
index 8ea85ad..1bac1a9 100644
--- a/ManagedCode.Orleans.SignalR.Server/SignalRConnectionHeartbeatGrain.cs
+++ b/ManagedCode.Orleans.SignalR.Server/SignalRConnectionHeartbeatGrain.cs
@@ -4,7 +4,7 @@
using ManagedCode.Orleans.SignalR.Core.Config;
using ManagedCode.Orleans.SignalR.Core.Interfaces;
using ManagedCode.Orleans.SignalR.Core.Models;
-using ManagedCode.Orleans.SignalR.Core.SignalR;
+using ManagedCode.Orleans.SignalR.Server.Helpers;
using Microsoft.Extensions.Logging;
using Orleans;
using Orleans.Concurrency;
@@ -48,32 +48,46 @@ public override async Task OnActivateAsync(CancellationToken cancellationToken)
public async Task Start(ConnectionHeartbeatRegistration registration)
{
_registration = registration;
- _state.State.Registration = registration;
ResetTimer(registration.Interval);
_logger.LogDebug("Heartbeat started for connection grain {Key} (hub={Hub}, partitioned={Partitioned}, partitionId={PartitionId}).",
this.GetPrimaryKeyString(), registration.HubKey, registration.UsePartitioning, registration.PartitionId);
- await _state.WriteStateAsync();
+ await _state.WriteStateSafeAsync(state =>
+ {
+ state.Registration = registration;
+ return true;
+ });
}
public async Task Stop()
{
ResetTimer(null);
- _state.State.Registration = null;
_registration = null;
_logger.LogDebug("Heartbeat stopped for connection grain {Key}.", this.GetPrimaryKeyString());
- await _state.WriteStateAsync();
+ await _state.WriteStateSafeAsync(state =>
+ {
+ state.Registration = null;
+ return true;
+ });
}
public override async Task OnDeactivateAsync(DeactivationReason reason, CancellationToken cancellationToken)
{
ResetTimer(null);
- if (_state.State.Registration is null)
+ try
{
- await _state.ClearStateAsync(cancellationToken);
+ if (_state.State.Registration is null)
+ {
+ await _state.ClearStateSafeAsync(cancellationToken);
+ }
+ else
+ {
+ await _state.WriteStateSafeAsync(cancellationToken);
+ }
}
- else
+ catch (OrleansMessageRejectionException ex)
{
- await _state.WriteStateAsync(cancellationToken);
+ // Storage grains may be unavailable during silo shutdown
+ _logger.LogDebug(ex, "Unable to persist state during deactivation for grain {Key} - storage unavailable.", this.GetPrimaryKeyString());
}
}
@@ -96,20 +110,23 @@ private void ResetTimer(TimeSpan? interval)
}
}
- private Task OnTimerTickAsync(object? state)
+ private Task OnTimerTickAsync(object? _)
{
- if (_registration is null)
+ // Capture registration to avoid null reference if Stop() is called during reentrant execution
+ var registration = _registration;
+ if (registration is null)
{
return Task.CompletedTask;
}
- var grainIds = _registration.GrainIds;
+ var grainIds = registration.GrainIds;
if (grainIds.IsDefaultOrEmpty)
{
return Task.CompletedTask;
}
- var connectionId = _registration.ConnectionId;
+ var connectionId = registration.ConnectionId;
+ var observer = registration.Observer;
try
{
foreach (var grainId in grainIds)
@@ -118,9 +135,9 @@ private Task OnTimerTickAsync(object? state)
var manager = grain.AsReference();
if (!string.IsNullOrEmpty(connectionId))
{
- _ = manager.AddConnection(connectionId, _registration.Observer);
+ _ = manager.AddConnection(connectionId, observer);
}
- _ = manager.Ping(_registration.Observer);
+ _ = manager.Ping(observer);
}
}
catch (Exception ex)
diff --git a/ManagedCode.Orleans.SignalR.Server/SignalRConnectionHolderGrain.cs b/ManagedCode.Orleans.SignalR.Server/SignalRConnectionHolderGrain.cs
index afdfc23..f2241e1 100644
--- a/ManagedCode.Orleans.SignalR.Server/SignalRConnectionHolderGrain.cs
+++ b/ManagedCode.Orleans.SignalR.Server/SignalRConnectionHolderGrain.cs
@@ -66,20 +66,21 @@ public async Task RemoveConnection(string connectionId, ISignalRObserver observe
}
}
- public async Task SendToAll(HubMessage message)
+ public Task SendToAll(HubMessage message)
{
Logs.SendToAll(Logger, nameof(SignalRConnectionHolderGrain), this.GetPrimaryKeyString());
if (LiveObservers.Count > 0)
{
DispatchToLiveObservers(LiveObservers.Values, message);
- return;
+ return Task.CompletedTask;
}
- await Task.Run(() => ObserverManager.Notify(s => s.OnNextAsync(message)));
+ ObserverManager.Notify(s => s.OnNextAsync(message));
+ return Task.CompletedTask;
}
- public async Task SendToAllExcept(HubMessage message, string[] excludedConnectionIds)
+ public Task SendToAllExcept(HubMessage message, string[] excludedConnectionIds)
{
Logs.SendToAllExcept(Logger, nameof(SignalRConnectionHolderGrain), this.GetPrimaryKeyString(), excludedConnectionIds);
@@ -88,7 +89,7 @@ public async Task SendToAllExcept(HubMessage message, string[] excludedConnectio
var excluded = new HashSet(excludedConnectionIds, StringComparer.Ordinal);
var targets = LiveObservers.Where(kvp => !excluded.Contains(kvp.Key)).Select(kvp => kvp.Value);
DispatchToLiveObservers(targets, message);
- return;
+ return Task.CompletedTask;
}
var hashSet = new HashSet();
@@ -100,32 +101,33 @@ public async Task SendToAllExcept(HubMessage message, string[] excludedConnectio
}
}
- await Task.Run(() => ObserverManager.Notify(s => s.OnNextAsync(message),
- connection => !hashSet.Contains(connection.GetPrimaryKeyString())));
+ ObserverManager.Notify(s => s.OnNextAsync(message),
+ connection => !hashSet.Contains(connection.GetPrimaryKeyString()));
+ return Task.CompletedTask;
}
- public async Task SendToConnection(HubMessage message, string connectionId)
+ public Task SendToConnection(HubMessage message, string connectionId)
{
Logs.SendToConnection(Logger, nameof(SignalRConnectionHolderGrain), this.GetPrimaryKeyString(), connectionId);
if (!stateStorage.State.ConnectionIds.TryGetValue(connectionId, out var observer))
{
- return false;
+ return Task.FromResult(false);
}
if (TryGetLiveObserver(connectionId, out var liveObserver))
{
_ = liveObserver.OnNextAsync(message);
- return true;
+ return Task.FromResult(true);
}
- await Task.Run(() => ObserverManager.Notify(s => s.OnNextAsync(message),
- connection => connection.GetPrimaryKeyString() == observer));
+ ObserverManager.Notify(s => s.OnNextAsync(message),
+ connection => connection.GetPrimaryKeyString() == observer);
- return true;
+ return Task.FromResult(true);
}
- public async Task SendToConnections(HubMessage message, string[] connectionIds)
+ public Task SendToConnections(HubMessage message, string[] connectionIds)
{
Logs.SendToConnections(Logger, nameof(SignalRConnectionHolderGrain), this.GetPrimaryKeyString(), connectionIds);
@@ -144,7 +146,7 @@ public async Task SendToConnections(HubMessage message, string[] connectionIds)
if (targets is not null)
{
DispatchToLiveObservers(targets, message);
- return;
+ return Task.CompletedTask;
}
}
@@ -157,8 +159,9 @@ public async Task SendToConnections(HubMessage message, string[] connectionIds)
}
}
- await Task.Run(() => ObserverManager.Notify(s => s.OnNextAsync(message),
- connection => hashSet.Contains(connection.GetPrimaryKeyString())));
+ ObserverManager.Notify(s => s.OnNextAsync(message),
+ connection => hashSet.Contains(connection.GetPrimaryKeyString()));
+ return Task.CompletedTask;
}
public Task Ping(ISignalRObserver observer)
@@ -176,11 +179,11 @@ public override async Task OnDeactivateAsync(DeactivationReason reason, Cancella
if (!hasConnections)
{
- await stateStorage.ClearStateAsync(cancellationToken);
+ await stateStorage.ClearStateSafeAsync(cancellationToken);
}
else
{
- await stateStorage.WriteStateAsync(cancellationToken);
+ await stateStorage.WriteStateSafeAsync(cancellationToken);
}
}
diff --git a/ManagedCode.Orleans.SignalR.Server/SignalRConnectionPartitionGrain.cs b/ManagedCode.Orleans.SignalR.Server/SignalRConnectionPartitionGrain.cs
index 6e63309..fa14443 100644
--- a/ManagedCode.Orleans.SignalR.Server/SignalRConnectionPartitionGrain.cs
+++ b/ManagedCode.Orleans.SignalR.Server/SignalRConnectionPartitionGrain.cs
@@ -67,20 +67,21 @@ public async Task RemoveConnection(string connectionId, ISignalRObserver observe
}
}
- public async Task SendToPartition(HubMessage message)
+ public Task SendToPartition(HubMessage message)
{
Logs.SendToAll(Logger, nameof(SignalRConnectionPartitionGrain), this.GetPrimaryKeyLong().ToString(CultureInfo.InvariantCulture));
if (LiveObservers.Count > 0)
{
DispatchToLiveObservers(LiveObservers.Values, message);
- return;
+ return Task.CompletedTask;
}
- await Task.Run(() => ObserverManager.Notify(s => s.OnNextAsync(message)));
+ ObserverManager.Notify(s => s.OnNextAsync(message));
+ return Task.CompletedTask;
}
- public async Task SendToPartitionExcept(HubMessage message, string[] excludedConnectionIds)
+ public Task SendToPartitionExcept(HubMessage message, string[] excludedConnectionIds)
{
Logs.SendToAllExcept(Logger, nameof(SignalRConnectionPartitionGrain), this.GetPrimaryKeyLong().ToString(CultureInfo.InvariantCulture), excludedConnectionIds);
@@ -89,7 +90,7 @@ public async Task SendToPartitionExcept(HubMessage message, string[] excludedCon
var excluded = new HashSet(excludedConnectionIds, StringComparer.Ordinal);
var targets = LiveObservers.Where(kvp => !excluded.Contains(kvp.Key)).Select(kvp => kvp.Value);
DispatchToLiveObservers(targets, message);
- return;
+ return Task.CompletedTask;
}
var hashSet = new HashSet();
@@ -101,11 +102,12 @@ public async Task SendToPartitionExcept(HubMessage message, string[] excludedCon
}
}
- await Task.Run(() => ObserverManager.Notify(s => s.OnNextAsync(message),
- connection => !hashSet.Contains(connection.GetPrimaryKeyString())));
+ ObserverManager.Notify(s => s.OnNextAsync(message),
+ connection => !hashSet.Contains(connection.GetPrimaryKeyString()));
+ return Task.CompletedTask;
}
- public async Task SendToConnection(HubMessage message, string connectionId)
+ public Task SendToConnection(HubMessage message, string connectionId)
{
Logs.SendToConnection(Logger, nameof(SignalRConnectionPartitionGrain), this.GetPrimaryKeyLong().ToString(CultureInfo.InvariantCulture), connectionId);
@@ -116,13 +118,13 @@ public async Task SendToConnection(HubMessage message, string connectionId
connectionId,
stateStorage.State.ConnectionIds.Count,
LiveObservers.Count);
- return false;
+ return Task.FromResult(false);
}
if (TryGetLiveObserver(connectionId, out var live))
{
_ = live.OnNextAsync(message);
- return true;
+ return Task.FromResult(true);
}
Logger.LogDebug("Partition {PartitionId} falling back to observer manager for {ConnectionId} (live={LiveObserversCount}).",
@@ -130,13 +132,13 @@ public async Task SendToConnection(HubMessage message, string connectionId
connectionId,
LiveObservers.Count);
- await Task.Run(() => ObserverManager.Notify(s => s.OnNextAsync(message),
- connection => connection.GetPrimaryKeyString() == observer));
+ ObserverManager.Notify(s => s.OnNextAsync(message),
+ connection => connection.GetPrimaryKeyString() == observer);
- return true;
+ return Task.FromResult(true);
}
- public async Task SendToConnections(HubMessage message, string[] connectionIds)
+ public Task SendToConnections(HubMessage message, string[] connectionIds)
{
Logs.SendToConnections(Logger, nameof(SignalRConnectionPartitionGrain), this.GetPrimaryKeyLong().ToString(CultureInfo.InvariantCulture), connectionIds);
@@ -155,7 +157,7 @@ public async Task SendToConnections(HubMessage message, string[] connectionIds)
if (targets is not null)
{
DispatchToLiveObservers(targets, message);
- return;
+ return Task.CompletedTask;
}
}
@@ -168,8 +170,9 @@ public async Task SendToConnections(HubMessage message, string[] connectionIds)
}
}
- await Task.Run(() => ObserverManager.Notify(s => s.OnNextAsync(message),
- connection => hashSet.Contains(connection.GetPrimaryKeyString())));
+ ObserverManager.Notify(s => s.OnNextAsync(message),
+ connection => hashSet.Contains(connection.GetPrimaryKeyString()));
+ return Task.CompletedTask;
}
public Task Ping(ISignalRObserver observer)
@@ -185,13 +188,21 @@ public override async Task OnDeactivateAsync(DeactivationReason reason, Cancella
var hasConnections = stateStorage.State.ConnectionIds.Count > 0;
ClearObserverTracking();
- if (!hasConnections)
+ try
{
- await stateStorage.ClearStateAsync(cancellationToken);
+ if (!hasConnections)
+ {
+ await stateStorage.ClearStateSafeAsync(cancellationToken);
+ }
+ else
+ {
+ await stateStorage.WriteStateSafeAsync(cancellationToken);
+ }
}
- else
+ catch (OrleansMessageRejectionException ex)
{
- await stateStorage.WriteStateAsync(cancellationToken);
+ // Storage grains may be unavailable during silo shutdown
+ Logger.LogDebug(ex, "Unable to persist state during deactivation for partition {PartitionId} - storage unavailable.", this.GetPrimaryKeyLong());
}
}
diff --git a/ManagedCode.Orleans.SignalR.Server/SignalRGroupCoordinatorGrain.cs b/ManagedCode.Orleans.SignalR.Server/SignalRGroupCoordinatorGrain.cs
index 358e0e5..1f90f98 100644
--- a/ManagedCode.Orleans.SignalR.Server/SignalRGroupCoordinatorGrain.cs
+++ b/ManagedCode.Orleans.SignalR.Server/SignalRGroupCoordinatorGrain.cs
@@ -1,5 +1,8 @@
using System;
+using System.Buffers;
using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using ManagedCode.Orleans.SignalR.Core.Config;
@@ -7,6 +10,7 @@
using ManagedCode.Orleans.SignalR.Core.Interfaces;
using ManagedCode.Orleans.SignalR.Core.Models;
using ManagedCode.Orleans.SignalR.Core.SignalR;
+using ManagedCode.Orleans.SignalR.Server.Helpers;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
@@ -23,10 +27,15 @@ public sealed class SignalRGroupCoordinatorGrain : Grain, ISignalRGroupCoordinat
private readonly ILogger _logger;
private readonly IOptions _options;
private readonly IPersistentState _state;
+ private Dictionary GroupPartitions { get; } = new(StringComparer.Ordinal);
+ private Dictionary GroupMembership { get; } = new(StringComparer.Ordinal);
+ private readonly HashSet _activePartitions = [];
private readonly int _groupsPerPartitionHint;
private uint _basePartitionCount;
private string? _hubKey;
private int _currentPartitionCount;
+ private int _partitionEpoch;
+ private bool _stateDirty;
public SignalRGroupCoordinatorGrain(
ILogger logger,
@@ -44,23 +53,60 @@ public override async Task OnActivateAsync(CancellationToken cancellationToken)
{
await _state.ReadStateAsync(cancellationToken);
_state.State ??= new GroupCoordinatorState();
- _state.State.GroupPartitions = EnsureOrdinalDictionary(_state.State.GroupPartitions);
- _state.State.GroupMembership = EnsureOrdinalDictionary(_state.State.GroupMembership);
+
+ // Copy persisted state to local dictionaries
+ var persistedPartitions = EnsureOrdinalDictionary(_state.State.GroupPartitions);
+ var persistedMembership = EnsureOrdinalMembershipDictionary(_state.State.GroupMembership);
+
+ GroupPartitions.Clear();
+ GroupMembership.Clear();
+ _activePartitions.Clear();
+
+ foreach (var kvp in persistedPartitions)
+ {
+ GroupPartitions[kvp.Key] = kvp.Value;
+ _activePartitions.Add(kvp.Value.PartitionId);
+ }
+
+ foreach (var kvp in persistedMembership)
+ {
+ GroupMembership[kvp.Key] = kvp.Value;
+ }
+
+ // Set state to reference local dictionaries
+ _state.State.GroupPartitions = GroupPartitions;
+ _state.State.GroupMembership = GroupMembership;
+
_basePartitionCount = Math.Max(1u, _options.Value.GroupPartitionCount);
_currentPartitionCount = _state.State.CurrentPartitionCount;
+ _partitionEpoch = Math.Max(1, _state.State.PartitionEpoch);
+
+ // Ensure partition count is at least base, but preserve higher counts to maintain routing consistency
if (_currentPartitionCount <= 0 || _currentPartitionCount < _basePartitionCount)
{
_currentPartitionCount = (int)_basePartitionCount;
_state.State.CurrentPartitionCount = _currentPartitionCount;
}
+ // Only reset to base if truly empty AND partition count was scaled up
else if (GroupPartitions.Count == 0 && _currentPartitionCount > _basePartitionCount)
{
_currentPartitionCount = (int)_basePartitionCount;
_state.State.CurrentPartitionCount = _currentPartitionCount;
+ // Reset epoch when scaling back to base with no groups
+ _partitionEpoch = 1;
+ _state.State.PartitionEpoch = _partitionEpoch;
}
+
_hubKey = this.GetPrimaryKeyString();
+ _stateDirty = false;
- _logger.LogInformation("Group coordinator activated with base partition count {PartitionCount} and hint {GroupsPerPartition}", _basePartitionCount, _groupsPerPartitionHint);
+ _logger.LogInformation(
+ "Group coordinator activated with base partition count {PartitionCount}, current {CurrentPartitionCount}, epoch {Epoch}, hint {GroupsPerPartition}, tracked groups {TrackedGroups}",
+ _basePartitionCount,
+ _currentPartitionCount,
+ _partitionEpoch,
+ _groupsPerPartitionHint,
+ GroupPartitions.Count);
await base.OnActivateAsync(cancellationToken);
}
@@ -77,84 +123,104 @@ public Task GetPartitionCount()
public Task GetPartitionForGroup(string groupName)
{
- var partition = GetOrAssignPartition(groupName);
+ var (partition, _, _) = GetOrAssignPartitionWithEpoch(groupName);
return Task.FromResult(partition);
}
public async Task SendToGroup(string groupName, HubMessage message)
{
- var partition = GetOrAssignPartition(groupName);
+ var (partition, _, _) = GetOrAssignPartitionWithEpoch(groupName);
var partitionGrain = await GetPartitionGrainAsync(partition);
await partitionGrain.SendToGroups(message, new[] { groupName });
}
public async Task SendToGroupExcept(string groupName, HubMessage message, string[] excludedConnectionIds)
{
- var partition = GetOrAssignPartition(groupName);
+ var (partition, _, _) = GetOrAssignPartitionWithEpoch(groupName);
var partitionGrain = await GetPartitionGrainAsync(partition);
await partitionGrain.SendToGroupsExcept(message, new[] { groupName }, excludedConnectionIds);
}
public async Task SendToGroups(string[] groupNames, HubMessage message)
{
+ // Group by partition using CollectionsMarshal for efficient access
var groupsByPartition = new Dictionary>();
foreach (var groupName in groupNames)
{
- var partition = GetOrAssignPartition(groupName);
- if (!groupsByPartition.TryGetValue(partition, out var list))
+ var (partition, _, _) = GetOrAssignPartitionWithEpoch(groupName);
+ ref var list = ref CollectionsMarshal.GetValueRefOrAddDefault(groupsByPartition, partition, out var exists);
+ if (!exists)
{
list = new List();
- groupsByPartition[partition] = list;
}
- list.Add(groupName);
+ list!.Add(groupName);
+ }
+
+ if (groupsByPartition.Count == 0)
+ {
+ return;
}
- if (groupsByPartition.Count < 100)
+ // Use ArrayPool for task collection
+ var tasks = ArrayPool.Shared.Rent(groupsByPartition.Count);
+ try
{
- var tasks = new List(groupsByPartition.Count);
+ var taskIndex = 0;
foreach (var kvp in groupsByPartition)
{
var partitionGrain = await GetPartitionGrainAsync(kvp.Key);
- tasks.Add(partitionGrain.SendToGroups(message, kvp.Value.ToArray()));
+ tasks[taskIndex++] = partitionGrain.SendToGroups(message, CollectionsMarshal.AsSpan(kvp.Value).ToArray());
}
- await Task.WhenAll(tasks);
+
+ await Task.WhenAll(tasks.AsSpan(0, taskIndex));
}
- else
+ catch (Exception ex)
{
- foreach (var kvp in groupsByPartition)
- {
- var partitionId = kvp.Key;
- _ = Task.Run(async () =>
- {
- try
- {
- var partitionGrain = await GetPartitionGrainAsync(partitionId);
- await partitionGrain.SendToGroups(message, kvp.Value.ToArray());
- }
- catch (Exception ex)
- {
- _logger.LogError(ex, "Failed to send to groups in partition {PartitionId}", partitionId);
- }
- });
- }
+ _logger.LogError(ex, "Failed to send to one or more group partitions");
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(tasks, clearArray: true);
}
}
public async Task AddConnectionToGroup(string groupName, string connectionId, ISignalRObserver observer)
{
- var partition = GetOrAssignPartition(groupName);
+ var (partition, _, _) = GetOrAssignPartitionWithEpoch(groupName);
var membership = GroupMembership.TryGetValue(groupName, out var count) ? count + 1 : 1;
GroupMembership[groupName] = membership;
var partitionGrain = await GetPartitionGrainAsync(partition);
await partitionGrain.AddConnectionToGroup(groupName, connectionId, observer);
+
+ // Persist state changes to ensure consistency after reactivation
+ if (_stateDirty)
+ {
+ await _state.WriteStateSafeAsync(state =>
+ {
+ // Re-sync local dictionaries to state on each retry (ReadStateAsync creates new state object)
+ state.GroupPartitions = GroupPartitions;
+ state.GroupMembership = GroupMembership;
+ state.CurrentPartitionCount = _currentPartitionCount;
+ state.PartitionEpoch = _partitionEpoch;
+ return true;
+ });
+ _stateDirty = false;
+ }
}
public async Task RemoveConnectionFromGroup(string groupName, string connectionId, ISignalRObserver observer)
{
- var partition = GroupPartitions.TryGetValue(groupName, out var existingPartition)
- ? existingPartition
- : PartitionHelper.GetPartitionId(groupName, (uint)_currentPartitionCount);
+ int partition;
+ if (GroupPartitions.TryGetValue(groupName, out var existingAssignment))
+ {
+ partition = existingAssignment.PartitionId;
+ }
+ else
+ {
+ partition = PartitionHelper.GetPartitionId(groupName, (uint)_currentPartitionCount);
+ }
+
var partitionGrain = await GetPartitionGrainAsync(partition);
await partitionGrain.RemoveConnectionFromGroup(groupName, connectionId, observer);
@@ -169,24 +235,54 @@ public async Task RemoveConnectionFromGroup(string groupName, string connectionI
GroupMembership[groupName] = count - 1;
}
}
+
+ // Persist state changes to ensure consistency after reactivation
+ if (_stateDirty)
+ {
+ await _state.WriteStateSafeAsync(state =>
+ {
+ // Re-sync local dictionaries to state on each retry (ReadStateAsync creates new state object)
+ state.GroupPartitions = GroupPartitions;
+ state.GroupMembership = GroupMembership;
+ state.CurrentPartitionCount = _currentPartitionCount;
+ state.PartitionEpoch = _partitionEpoch;
+ return true;
+ });
+ _stateDirty = false;
+ }
}
- public Task NotifyGroupRemoved(string groupName)
+ public async Task NotifyGroupRemoved(string groupName)
{
ReleaseGroup(groupName);
- return Task.CompletedTask;
+
+ if (_stateDirty)
+ {
+ await _state.WriteStateSafeAsync(state =>
+ {
+ // Re-sync local dictionaries to state on each retry (ReadStateAsync creates new state object)
+ state.GroupPartitions = GroupPartitions;
+ state.GroupMembership = GroupMembership;
+ state.CurrentPartitionCount = _currentPartitionCount;
+ state.PartitionEpoch = _partitionEpoch;
+ return true;
+ });
+ _stateDirty = false;
+ }
}
public override async Task OnDeactivateAsync(DeactivationReason reason, CancellationToken cancellationToken)
{
_state.State.CurrentPartitionCount = _currentPartitionCount;
+ _state.State.PartitionEpoch = _partitionEpoch;
+
if (GroupPartitions.Count == 0)
{
- await _state.ClearStateAsync(cancellationToken);
+ await _state.ClearStateSafeAsync(cancellationToken);
}
else
{
- await _state.WriteStateAsync(cancellationToken);
+ await _state.WriteStateSafeAsync(cancellationToken);
}
}
@@ -198,19 +294,63 @@ private async Task GetPartitionGrainAsync(int parti
return partitionGrain;
}
- private int GetOrAssignPartition(string groupName)
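+ /// <summary>
+ /// Gets or assigns a partition for a group, handling epoch-based reassignment.
+ /// Returns (PartitionId, WasNew, WasReassigned); WasReassigned is also true when only a stale epoch is refreshed and the partition is unchanged.
+ /// </summary>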
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private (int PartitionId, bool WasNew, bool WasReassigned) GetOrAssignPartitionWithEpoch(string groupName)
{
- if (GroupPartitions.TryGetValue(groupName, out var partition))
+ if (GroupPartitions.TryGetValue(groupName, out var existingAssignment))
{
- return partition;
+ // Check if assignment is from current epoch
+ if (existingAssignment.Epoch == _partitionEpoch)
+ {
+ return (existingAssignment.PartitionId, false, false);
+ }
+
+ // Stale epoch - check if partition would be different with current partition count
+ var newPartition = PartitionHelper.GetPartitionId(groupName, (uint)_currentPartitionCount);
+
+ if (newPartition == existingAssignment.PartitionId)
+ {
+ // Same partition, just update epoch
+ var updatedAssignment = PartitionAssignment.Create(existingAssignment.PartitionId, _partitionEpoch);
+ GroupPartitions[groupName] = updatedAssignment;
+ _stateDirty = true;
+ _logger.LogDebug(
+ "Updated group {GroupName} epoch from {OldEpoch} to {NewEpoch} (partition {Partition} unchanged)",
+ groupName, existingAssignment.Epoch, _partitionEpoch, existingAssignment.PartitionId);
+ return (existingAssignment.PartitionId, false, true);
+ }
+
+ // Partition changed due to scaling - reassign
+ var reassignment = PartitionAssignment.Create(newPartition, _partitionEpoch);
+ GroupPartitions[groupName] = reassignment;
+ _activePartitions.Add(newPartition);
+ _stateDirty = true;
+
+ _logger.LogInformation(
+ "Reassigned group {GroupName} from partition {OldPartition} (epoch {OldEpoch}) to partition {NewPartition} (epoch {NewEpoch}) due to scaling",
+ groupName, existingAssignment.PartitionId, existingAssignment.Epoch, newPartition, _partitionEpoch);
+
+ return (newPartition, false, true);
}
+ // New group - assign to partition with current epoch
var partitionCount = EnsurePartitionCapacity(GroupPartitions.Count + 1);
- partition = PartitionHelper.GetPartitionId(groupName, (uint)partitionCount);
- GroupPartitions[groupName] = partition;
+ var partition = PartitionHelper.GetPartitionId(groupName, (uint)partitionCount);
+ var assignment = PartitionAssignment.Create(partition, _partitionEpoch);
+
+ GroupPartitions[groupName] = assignment;
+ _activePartitions.Add(partition);
+ _stateDirty = true;
+
+ _logger.LogDebug(
+ "Assigned group {GroupName} to partition {Partition} (epoch {Epoch}, partitionCount={PartitionCount})",
+ groupName, partition, _partitionEpoch, partitionCount);
- _logger.LogDebug("Assigned group {GroupName} to partition {Partition} (partitionCount={PartitionCount})", groupName, partition, partitionCount);
- return partition;
+ return (partition, true, false);
}
private int EnsurePartitionCapacity(int prospectiveGroups)
@@ -221,34 +361,76 @@ private int EnsurePartitionCapacity(int prospectiveGroups)
if (desired > _currentPartitionCount)
{
_logger.LogInformation(
- "Increasing group partition count from {OldPartitionCount} to {NewPartitionCount} for {GroupCount} tracked groups.",
+ "Increasing group partition count from {OldPartitionCount} to {NewPartitionCount} (epoch {OldEpoch} -> {NewEpoch}) for {GroupCount} tracked groups.",
_currentPartitionCount,
desired,
+ _partitionEpoch,
+ _partitionEpoch + 1,
prospectiveGroups);
+
_currentPartitionCount = desired;
+ _partitionEpoch++;
_state.State.CurrentPartitionCount = _currentPartitionCount;
+ _state.State.PartitionEpoch = _partitionEpoch;
}
return _currentPartitionCount;
}
- private Dictionary GroupPartitions => _state.State.GroupPartitions!;
- private Dictionary GroupMembership => _state.State.GroupMembership!;
-
private void ReleaseGroup(string groupName)
{
var removedMembership = GroupMembership.Remove(groupName);
- var removedPartition = GroupPartitions.Remove(groupName);
+ var removedPartition = GroupPartitions.Remove(groupName, out var assignment);
+
+ if (removedPartition)
+ {
+ _stateDirty = true;
+ var partitionId = assignment.PartitionId;
+
+ // Check if any other groups are using this partition
+ var partitionStillActive = false;
+ foreach (var otherAssignment in GroupPartitions.Values)
+ {
+ if (otherAssignment.PartitionId == partitionId)
+ {
+ partitionStillActive = true;
+ break;
+ }
+ }
+
+ if (!partitionStillActive)
+ {
+ _activePartitions.Remove(partitionId);
+ }
+ }
if ((removedMembership || removedPartition) && GroupMembership.Count == 0 && _currentPartitionCount != _basePartitionCount)
{
- _logger.LogDebug("Resetting group partition count to base value {PartitionCount} as no active groups remain.", _basePartitionCount);
+ _logger.LogDebug("Resetting group partition count to base value {PartitionCount} and epoch to 1 as no active groups remain.", _basePartitionCount);
_currentPartitionCount = (int)_basePartitionCount;
_state.State.CurrentPartitionCount = _currentPartitionCount;
+ _partitionEpoch = 1;
+ _state.State.PartitionEpoch = _partitionEpoch;
+ _activePartitions.Clear();
}
}
- private static Dictionary EnsureOrdinalDictionary(Dictionary? dictionary)
+ private static Dictionary EnsureOrdinalDictionary(Dictionary? dictionary)
+ {
+ if (dictionary is null)
+ {
+ return new Dictionary(StringComparer.Ordinal);
+ }
+
+ if (dictionary.Comparer == StringComparer.Ordinal)
+ {
+ return dictionary;
+ }
+
+ return new Dictionary(dictionary, StringComparer.Ordinal);
+ }
+
+ private static Dictionary EnsureOrdinalMembershipDictionary(Dictionary? dictionary)
{
if (dictionary is null)
{
diff --git a/ManagedCode.Orleans.SignalR.Server/SignalRGroupGrain.cs b/ManagedCode.Orleans.SignalR.Server/SignalRGroupGrain.cs
index 4a47bd9..efd60d9 100644
--- a/ManagedCode.Orleans.SignalR.Server/SignalRGroupGrain.cs
+++ b/ManagedCode.Orleans.SignalR.Server/SignalRGroupGrain.cs
@@ -36,20 +36,21 @@ public override async Task OnActivateAsync(CancellationToken cancellationToken)
await base.OnActivateAsync(cancellationToken);
}
- public async Task SendToGroup(HubMessage message)
+ public Task SendToGroup(HubMessage message)
{
Logs.SendToGroup(Logger, nameof(SignalRGroupGrain), this.GetPrimaryKeyString());
if (LiveObservers.Count > 0)
{
DispatchToLiveObservers(LiveObservers.Values, message);
- return;
+ return Task.CompletedTask;
}
- await Task.Run(() => ObserverManager.Notify(s => s.OnNextAsync(message)));
+ ObserverManager.Notify(s => s.OnNextAsync(message));
+ return Task.CompletedTask;
}
- public async Task SendToGroupExcept(HubMessage message, string[] excludedConnectionIds)
+ public Task SendToGroupExcept(HubMessage message, string[] excludedConnectionIds)
{
Logs.SendToGroupExcept(Logger, nameof(SignalRGroupGrain), this.GetPrimaryKeyString(), excludedConnectionIds);
@@ -58,7 +59,7 @@ public async Task SendToGroupExcept(HubMessage message, string[] excludedConnect
var excluded = new HashSet(excludedConnectionIds, StringComparer.Ordinal);
var targets = LiveObservers.Where(kvp => !excluded.Contains(kvp.Key)).Select(kvp => kvp.Value);
DispatchToLiveObservers(targets, message);
- return;
+ return Task.CompletedTask;
}
var hashSet = new HashSet();
@@ -70,8 +71,9 @@ public async Task SendToGroupExcept(HubMessage message, string[] excludedConnect
}
}
- await Task.Run(() => ObserverManager.Notify(s => s.OnNextAsync(message),
- connection => !hashSet.Contains(connection.GetPrimaryKeyString())));
+ ObserverManager.Notify(s => s.OnNextAsync(message),
+ connection => !hashSet.Contains(connection.GetPrimaryKeyString()));
+ return Task.CompletedTask;
}
public async Task AddConnection(string connectionId, ISignalRObserver observer)
@@ -118,11 +120,11 @@ public override async Task OnDeactivateAsync(DeactivationReason reason, Cancella
if (!hasConnections)
{
- await stateStorage.ClearStateAsync(cancellationToken);
+ await stateStorage.ClearStateSafeAsync(cancellationToken);
}
else
{
- await stateStorage.WriteStateAsync(cancellationToken);
+ await stateStorage.WriteStateSafeAsync(cancellationToken);
}
}
diff --git a/ManagedCode.Orleans.SignalR.Server/SignalRGroupPartitionGrain.cs b/ManagedCode.Orleans.SignalR.Server/SignalRGroupPartitionGrain.cs
index cbaa1a6..b6c042c 100644
--- a/ManagedCode.Orleans.SignalR.Server/SignalRGroupPartitionGrain.cs
+++ b/ManagedCode.Orleans.SignalR.Server/SignalRGroupPartitionGrain.cs
@@ -1,6 +1,5 @@
using System;
using System.Collections.Generic;
-using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using ManagedCode.Orleans.SignalR.Core.Config;
@@ -38,7 +37,7 @@ public override async Task OnActivateAsync(CancellationToken cancellationToken)
await base.OnActivateAsync(cancellationToken);
}
- public async Task SendToGroups(HubMessage message, string[] groupNames)
+ public Task SendToGroups(HubMessage message, string[] groupNames)
{
Logger.LogDebug("SendToGroups invoked for partition {PartitionId} with groups {Groups} (keepAlive={KeepEachConnectionAlive}, liveObservers={LiveObserversCount}, trackedConnections={TrackedConnectionCount})",
this.GetPrimaryKeyLong(),
@@ -51,17 +50,18 @@ public async Task SendToGroups(HubMessage message, string[] groupNames)
{
var targetConnections = CollectConnectionIds(groupNames, excludedConnections: null);
DispatchToLiveObservers(GetLiveObservers(targetConnections), message);
- return;
+ return Task.CompletedTask;
}
var targetObservers = CollectObservers(groupNames, excludedConnections: null);
- await Task.Run(() => ObserverManager.Notify(
+ ObserverManager.Notify(
observer => observer.OnNextAsync(message),
- observer => targetObservers.Contains(observer.GetPrimaryKeyString())));
+ observer => targetObservers.Contains(observer.GetPrimaryKeyString()));
+ return Task.CompletedTask;
}
- public async Task SendToGroupsExcept(HubMessage message, string[] groupNames, string[] excludedConnectionIds)
+ public Task SendToGroupsExcept(HubMessage message, string[] groupNames, string[] excludedConnectionIds)
{
Logger.LogDebug("SendToGroupsExcept invoked for partition {PartitionId} with groups {Groups}, excluded {Excluded} (keepAlive={KeepEachConnectionAlive}, liveObservers={LiveObserversCount}, trackedConnections={TrackedConnectionCount})",
this.GetPrimaryKeyLong(),
@@ -75,15 +75,16 @@ public async Task SendToGroupsExcept(HubMessage message, string[] groupNames, st
{
var targetConnections = CollectConnectionIds(groupNames, new HashSet(excludedConnectionIds, StringComparer.Ordinal));
DispatchToLiveObservers(GetLiveObservers(targetConnections), message);
- return;
+ return Task.CompletedTask;
}
var excluded = new HashSet(excludedConnectionIds);
var targetObservers = CollectObservers(groupNames, excluded);
- await Task.Run(() => ObserverManager.Notify(
+ ObserverManager.Notify(
observer => observer.OnNextAsync(message),
- observer => targetObservers.Contains(observer.GetPrimaryKeyString())));
+ observer => targetObservers.Contains(observer.GetPrimaryKeyString()));
+ return Task.CompletedTask;
}
public async Task AddConnection(string connectionId, ISignalRObserver observer)
@@ -239,10 +240,10 @@ public override Task OnDeactivateAsync(DeactivationReason reason, CancellationTo
if (!hasState)
{
- return state.ClearStateAsync(cancellationToken);
+ return state.ClearStateSafeAsync(cancellationToken);
}
- return state.WriteStateAsync(cancellationToken);
+ return state.WriteStateSafeAsync(cancellationToken);
}
private HashSet CollectObservers(IEnumerable