diff --git a/API/Services/Account/AccountService.cs b/API/Services/Account/AccountService.cs index 019099c0..054d6d7e 100644 --- a/API/Services/Account/AccountService.cs +++ b/API/Services/Account/AccountService.cs @@ -259,10 +259,15 @@ public async Task x.Email == lowercaseUsernameOrEmail || x.Name == lowercaseUsernameOrEmail, cancellationToken); if (user is null) { - // TODO: Set appropriate time to match password hashing time, preventing timing attacks - await Task.Delay(100, cancellationToken); + await HashingUtils.VerifyPasswordFake(); return new NotFound(); } + + if (!await CheckPassword(password, user)) + { + return new NotFound(); + } + if (user.ActivatedAt is null) { return new AccountNotActivated(); @@ -272,7 +277,6 @@ public async Task (long)(ms * TimeSpan.TicksPerMillisecond); + + private static (double mean, double std) MeanStd(IEnumerable samples) + { + var arr = samples.Select(x => (double)x).ToArray(); + double n = arr.Length; + double mean = arr.Average(); + if (n <= 1) return (mean, 0); + + double sumSq = arr.Sum(x => x * x); + double variance = (sumSq - n * mean * mean) / (n - 1); + return (mean, Math.Sqrt(Math.Max(0, variance))); + } + + private const double Eps = 1e-7; + + // --- constructor & basic stats --- + + [Test] + public async Task Ctor_SeedsWithDefaultMs_StatsMatch() + { + var seedMs = 12; + var emu = new LatencyEmulator(capacity: 8, defaultMs: seedMs); + + var (mean, std) = emu.GetStats(); + await Assert.That(mean).IsEqualTo(seedMs).Within(0.5); + await Assert.That(std).IsEqualTo(0); + } + + [Test] + public void Ctor_DefaultMs_Negative_Throws() + { + Assert.Throws(() => new LatencyEmulator(capacity: 8, -1)); + } + + [Test] + public void Ctor_Capacity_OneOrLess_Throws() + { + Assert.Throws(() => new LatencyEmulator(1, 0)); + Assert.Throws(() => new LatencyEmulator(0, 0)); + Assert.Throws(() => new LatencyEmulator(-5, 0)); + } + + // --- Record: input validation --- + + [Test] + public void Record_Negative_Throws() + { + var emu = new LatencyEmulator(capacity: 4, defaultMs: 0); + Assert.Throws(() => emu.Record(-1)); + } + + // --- Record: growth phase (no eviction) --- + + [Test] + public async Task Record_Growing_NoEvictions_StatsMatchAllSamples() + { + var emu = new LatencyEmulator(capacity: 8, defaultMs: 0); + + // Add positive tick samples + long[] add = [ 1, 3, 5 ]; + foreach (var t in add) emu.Record(Ms(t)); + + // window should contain [0, 1 ms, 3 ms, 5 ms] (ticks) + var expected = new List { 0 }; + expected.AddRange(add); + + var (expMean, expStd) = MeanStd(expected); + var (mean, std) = emu.GetStats(); + + await Assert.That(mean).IsEqualTo(expMean).Within(Eps); + await Assert.That(std).IsEqualTo(expStd).Within(Eps); + } + + // --- Record: steady state (with eviction) --- + + [Test] + public async Task Record_EvictsOldest_MaintainsSlidingWindow() + { + // capacity 3, seed 0 ms => window starts [0] + var emu = new LatencyEmulator(capacity: 3, defaultMs: 0); + + // Fill to capacity: [0, 10 ms, 20 ms] + emu.Record(Ms(10)); + emu.Record(Ms(20)); + + var (m1, s1) = emu.GetStats(); + long[] expected1 = [ 0, 10, 20 ]; + var (expM1, expS1) = MeanStd(expected1); + await Assert.That(m1).IsEqualTo(expM1).Within(Eps); + await Assert.That(s1).IsEqualTo(expS1).Within(Eps); + + // Next insert 30 ms => evict the oldest (0), new window [10,20,30] + emu.Record(Ms(30)); + var (m2, s2) = emu.GetStats(); + long[] expected2 = [ 10, 20, 30 ]; + var (expM2, expS2) = MeanStd(expected2); + await Assert.That(m2).IsEqualTo(expM2).Within(Eps); + await Assert.That(s2).IsEqualTo(expS2).Within(Eps); + + // Next insert 40 ms => window [20,30,40] + emu.Record(Ms(40)); + var (m3, s3) = emu.GetStats(); + long[] expected3 = [ 20, 30, 40 ]; + var (expM3, expS3) = MeanStd(expected3); + await Assert.That(m3).IsEqualTo(expM3).Within(Eps); + await Assert.That(s3).IsEqualTo(expS3).Within(Eps); + } + + // --- GetFake() behavior --- + + [Test] + public async Task GetFake_WhenStdZero_ReturnsMeanExactly() + { + // Make all samples identical so std==0 + var emu = new LatencyEmulator(capacity: 5, defaultMs: 7); + var same = Ms(7.0); + emu.Record(same); + emu.Record(same); + emu.Record(same); + emu.Record(same); + + var (mean, std) = emu.GetStats(); + await Assert.That(std).IsEqualTo(0); + + // Without noise, fake should be exactly the mean (with rounding) + var fake = emu.GetFake(); + await Assert.That(fake.TotalMilliseconds).IsEqualTo(mean).Within(Eps); + } + + [Test] + public async Task GetFake_NonZeroStd_NonNegative_AndVaries() + { + var emu = new LatencyEmulator(capacity: 16, defaultMs: 0); + + // Create a spread so std > 0 + foreach (var ms in new[] { 1, 2, 3, 5, 8, 13, 21, 34 }) emu.Record(Ms(ms)); + + var (mean, std) = emu.GetStats(); + await Assert.That(std).IsGreaterThan(0); + + // Gather many samples; all must be non-negative, + // and at least one should differ from a rounded mean. + var fakes = new List(); + for (int i = 0; i < 200; i++) fakes.Add(emu.GetFake().Ticks); + + await Assert.That(fakes).DoesNotContain(x => x < 0); + await Assert.That(fakes).ContainsOnly(x => Math.Abs(x - Math.Round(mean)) > 0); + } + + // --- Numerical stability & precision --- + + [Test] + public async Task Stats_UnbiasedSampleStd_MatchesReference() + { + var emu = new LatencyEmulator(capacity: 8, defaultMs: 0); + long[] vals = [ Ms(10), Ms(20), Ms(30), Ms(40), Ms(50) ]; + foreach (var v in vals) emu.Record(Ms(v)); + + // Window: [0,10,20,30,40,50] (n=6) all in ms converted to ticks + var expected = new long[] { 0 }.Concat(vals).ToArray(); + var (expMean, expStd) = MeanStd(expected); + + var (mean, std) = emu.GetStats(); + await Assert.That(mean).IsEqualTo(expMean).Within(Eps); + await Assert.That(std).IsEqualTo(expStd).Within(Eps); + } + + // --- Concurrency sanity check (no exceptions, stats sane) --- + + [Test] + public async Task Record_IsThreadSafe_Sanity() + { + var emu = new LatencyEmulator(capacity: 128, defaultMs: 1); + + var tasks = Enumerable.Range(0, Environment.ProcessorCount) + .Select(i => Task.Run(() => + { + var rnd = new Random(i * 7919 + 17); + for (int k = 0; k < 5000; k++) + { + // Generate strictly positive tick values (~ up to 10 ms) + // Ensure >= 1 tick to satisfy ThrowIfNegativeOrZero. + long ticks = Math.Max(1, TimeSpan.FromMilliseconds(rnd.NextDouble() * 10).Ticks); + emu.Record(ticks); + } + })) + .ToArray(); + + await Task.WhenAll(tasks); + + var (mean, std) = emu.GetStats(); + // Just make sure we didn’t corrupt numeric state. + await Assert.That(double.IsNaN(mean) || double.IsInfinity(mean)).IsFalse(); + await Assert.That(double.IsNaN(std) || double.IsInfinity(std)).IsFalse(); + await Assert.That(mean).IsGreaterThan(0); + await Assert.That(std).IsGreaterThanOrEqualTo(0); + } +} diff --git a/Common/Utils/HashingUtils.cs b/Common/Utils/HashingUtils.cs index 3e33788a..7387c3b2 100644 --- a/Common/Utils/HashingUtils.cs +++ b/Common/Utils/HashingUtils.cs @@ -1,4 +1,5 @@ using System.Buffers; +using System.Diagnostics; using System.Security.Cryptography; using System.Text; using BCrypt.Net; @@ -11,6 +12,8 @@ public static class HashingUtils private const string BCryptPrefix = "bcrypt"; private const string Pbkdf2Prefix = "pbkdf2"; private const HashType BCryptHashType = HashType.SHA512; + + private static readonly LatencyEmulator VerifyTiming = new(200, 100); public readonly record struct VerifyHashResult(bool Verified, bool NeedsRehash); private static readonly VerifyHashResult VerifyHashFailureResult = new(false, false); @@ -73,6 +76,7 @@ public static string HashPassword(string password) { return $"{BCryptPrefix}:{BCrypt.Net.BCrypt.EnhancedHashPassword(password, BCryptHashType)}"; } + public static VerifyHashResult VerifyPassword(string password, string combinedHash) { int index = combinedHash.IndexOf(':'); @@ -82,9 +86,14 @@ public static VerifyHashResult VerifyPassword(string password, string combinedHa if (algorithm == PasswordHashingAlgorithm.BCrypt) { + var start = Stopwatch.GetTimestamp(); + var verified = BCrypt.Net.BCrypt.EnhancedVerify(password, combinedHash[(index + 1)..], BCryptHashType); + var stop = Stopwatch.GetTimestamp(); + VerifyTiming.Record(stop - start); + return new VerifyHashResult { - Verified = BCrypt.Net.BCrypt.EnhancedVerify(password, combinedHash[(index + 1)..], BCryptHashType), + Verified = verified, NeedsRehash = false }; } @@ -103,6 +112,11 @@ public static VerifyHashResult VerifyPassword(string password, string combinedHa return VerifyHashFailureResult; } + public static Task VerifyPasswordFake() + { + return Task.Delay(VerifyTiming.GetFake()); + } + public static string HashToken(string token) { // BE CAREFUL, changing this will break leaked token reporting. diff --git a/Common/Utils/LatencyEmulator.cs b/Common/Utils/LatencyEmulator.cs new file mode 100644 index 00000000..ace3d52a --- /dev/null +++ b/Common/Utils/LatencyEmulator.cs @@ -0,0 +1,130 @@ +namespace OpenShock.Common.Utils; + +public sealed class LatencyEmulator +{ + // Use object for broad framework compat; replace with `Lock` if desired. + private readonly Lock _gate = new(); + private readonly long[] _buf; + private int _count; + private int _head; + + // Use double to prevent overflow and improve precision of stats. + private double _sum; + private double _sumSq; + + /// + /// Sliding window of timing samples (stored as ticks). + /// Seeds the window with one sample = max(defaultMs, 0). + /// + public LatencyEmulator(int capacity, long defaultMs) + { + if (capacity <= 1) + throw new ArgumentOutOfRangeException(nameof(capacity), "Capacity must be > 1."); + + ArgumentOutOfRangeException.ThrowIfNegative(defaultMs); + + _buf = new long[capacity]; + + long ticks = defaultMs * TimeSpan.TicksPerMillisecond; + _buf[0] = ticks; + _count = 1; + _head = 1; + + _sum = ticks; + _sumSq = (double)ticks * ticks; + } + + /// + /// Record a timing sample in TICKS (not milliseconds). + /// + public void Record(long elapsedTicks) + { + ArgumentOutOfRangeException.ThrowIfNegative(elapsedTicks); + + lock (_gate) + { + if (_count < _buf.Length) + { + // growing phase: no evictions + _buf[_head] = elapsedTicks; + _count++; + } + else + { + // steady state: evict oldest at _head, then insert + long old = _buf[_head]; + _sum -= old; + _sumSq -= (double)old * old; + + _buf[_head] = elapsedTicks; + } + + _sum += elapsedTicks; + _sumSq += (double)elapsedTicks * elapsedTicks; + + _head = (_head + 1) % _buf.Length; + } + } + + /// + /// Return a simulated timing using current window mean ± Gaussian noise (as a TimeSpan). + /// Clamped to non-negative ticks. + /// + public TimeSpan GetFake() + { + lock (_gate) + { + var (mean, std) = MeanStdUnsafe(); + double noise = std > 0 ? NextGaussian(0, std) : 0; + double value = Math.Max(mean + noise, 0); // clamp at 0 + return TimeSpan.FromTicks((long)Math.Round(value)); + } + } + + /// + /// Returns (meanMs, stdDevMs) + /// + public (double mean, double std) GetStats() + { + double mean, std; + + lock (_gate) + { + (mean, std) = MeanStdUnsafe(); + } + + return (mean / TimeSpan.TicksPerMillisecond, std / TimeSpan.TicksPerMillisecond); + } + + // --- helpers --- + + // Uses maintained sums for O(1) stats + private (double mean, double std) MeanStdUnsafe() + { + switch (_count) + { + case 0: + return (0, 0); + case 1: + double v = _buf[0]; + return (v, 0); + } + + double n = _count; + double mean = _sum / n; + // Unbiased sample variance + double variance = (_sumSq - n * mean * mean) / (n - 1); + double std = Math.Sqrt(Math.Max(variance, 0)); + return (mean, std); + } + + private static double NextGaussian(double mean, double stdDev) + { + // Box–Muller with Random.Shared + double u1 = 1.0 - Random.Shared.NextDouble(); // (0,1] + double u2 = 1.0 - Random.Shared.NextDouble(); + double mag = Math.Sqrt(-2.0 * Math.Log(u1)); + double z0 = mag * Math.Cos(2.0 * Math.PI * u2); + return mean + z0 * stdDev; + } +}