|
| 1 | +using System.Runtime.InteropServices; |
| 2 | +using System.Runtime.Intrinsics; |
| 3 | +using SixLabors.ImageSharp.PixelFormats; |
| 4 | + |
| 5 | +namespace VerifyTests; |
| 6 | + |
| 7 | +public static class SsimComparer |
| 8 | +{ |
| 9 | + const double k1 = 0.01; |
| 10 | + const double k2 = 0.03; |
| 11 | + const double l = 255.0; |
| 12 | + const double c1 = k1 * l * k1 * l; |
| 13 | + const double c2 = k2 * l * k2 * l; |
| 14 | + |
| 15 | + public static double Calculate(Stream received, Stream verified) |
| 16 | + { |
| 17 | + using var img1 = Image.Load<Rgba32>(received); |
| 18 | + using var img2 = Image.Load<Rgba32>(verified); |
| 19 | + |
| 20 | + if (img1.Width != img2.Width || |
| 21 | + img1.Height != img2.Height) |
| 22 | + { |
| 23 | + return 0; |
| 24 | + } |
| 25 | + |
| 26 | + var width = img1.Width; |
| 27 | + var height = img1.Height; |
| 28 | + var pixelCount = (double) (width * height); |
| 29 | + |
| 30 | + double sumR1 = 0, sumG1 = 0, sumB1 = 0; |
| 31 | + double sumR2 = 0, sumG2 = 0, sumB2 = 0; |
| 32 | + double sumR1Sq = 0, sumG1Sq = 0, sumB1Sq = 0; |
| 33 | + double sumR2Sq = 0, sumG2Sq = 0, sumB2Sq = 0; |
| 34 | + double sumR12 = 0, sumG12 = 0, sumB12 = 0; |
| 35 | + |
| 36 | + img1.ProcessPixelRows(img2, (acc1, acc2) => |
| 37 | + { |
| 38 | + for (var y = 0; y < height; y++) |
| 39 | + { |
| 40 | + var row1 = MemoryMarshal.Cast<Rgba32, uint>(acc1.GetRowSpan(y)); |
| 41 | + var row2 = MemoryMarshal.Cast<Rgba32, uint>(acc2.GetRowSpan(y)); |
| 42 | + var x = 0; |
| 43 | + |
| 44 | + if (Vector256.IsHardwareAccelerated) |
| 45 | + { |
| 46 | + x = AccumulateVector256( |
| 47 | + row1, row2, width, |
| 48 | + ref sumR1, ref sumG1, ref sumB1, |
| 49 | + ref sumR2, ref sumG2, ref sumB2, |
| 50 | + ref sumR1Sq, ref sumG1Sq, ref sumB1Sq, |
| 51 | + ref sumR2Sq, ref sumG2Sq, ref sumB2Sq, |
| 52 | + ref sumR12, ref sumG12, ref sumB12); |
| 53 | + } |
| 54 | + else if (Vector128.IsHardwareAccelerated) |
| 55 | + { |
| 56 | + x = AccumulateVector128( |
| 57 | + row1, row2, width, |
| 58 | + ref sumR1, ref sumG1, ref sumB1, |
| 59 | + ref sumR2, ref sumG2, ref sumB2, |
| 60 | + ref sumR1Sq, ref sumG1Sq, ref sumB1Sq, |
| 61 | + ref sumR2Sq, ref sumG2Sq, ref sumB2Sq, |
| 62 | + ref sumR12, ref sumG12, ref sumB12); |
| 63 | + } |
| 64 | + |
| 65 | + for (; x < width; x++) |
| 66 | + { |
| 67 | + double r1 = (byte) row1[x], g1 = (byte) (row1[x] >> 8), b1 = (byte) (row1[x] >> 16); |
| 68 | + double r2 = (byte) row2[x], g2 = (byte) (row2[x] >> 8), b2 = (byte) (row2[x] >> 16); |
| 69 | + |
| 70 | + sumR1 += r1; sumG1 += g1; sumB1 += b1; |
| 71 | + sumR2 += r2; sumG2 += g2; sumB2 += b2; |
| 72 | + sumR1Sq += r1 * r1; sumG1Sq += g1 * g1; sumB1Sq += b1 * b1; |
| 73 | + sumR2Sq += r2 * r2; sumG2Sq += g2 * g2; sumB2Sq += b2 * b2; |
| 74 | + sumR12 += r1 * r2; sumG12 += g1 * g2; sumB12 += b1 * b2; |
| 75 | + } |
| 76 | + } |
| 77 | + }); |
| 78 | + |
| 79 | + var ssimR = CalculateChannel(pixelCount, sumR1, sumR2, sumR1Sq, sumR2Sq, sumR12); |
| 80 | + var ssimG = CalculateChannel(pixelCount, sumG1, sumG2, sumG1Sq, sumG2Sq, sumG12); |
| 81 | + var ssimB = CalculateChannel(pixelCount, sumB1, sumB2, sumB1Sq, sumB2Sq, sumB12); |
| 82 | + |
| 83 | + return (ssimR + ssimG + ssimB) / 3.0; |
| 84 | + } |
| 85 | + |
| 86 | + static int AccumulateVector256( |
| 87 | + ReadOnlySpan<uint> row1, ReadOnlySpan<uint> row2, int width, |
| 88 | + ref double sumR1, ref double sumG1, ref double sumB1, |
| 89 | + ref double sumR2, ref double sumG2, ref double sumB2, |
| 90 | + ref double sumR1Sq, ref double sumG1Sq, ref double sumB1Sq, |
| 91 | + ref double sumR2Sq, ref double sumG2Sq, ref double sumB2Sq, |
| 92 | + ref double sumR12, ref double sumG12, ref double sumB12) |
| 93 | + { |
| 94 | + var vR1 = Vector256<float>.Zero; var vG1 = Vector256<float>.Zero; var vB1 = Vector256<float>.Zero; |
| 95 | + var vR2 = Vector256<float>.Zero; var vG2 = Vector256<float>.Zero; var vB2 = Vector256<float>.Zero; |
| 96 | + var vR1Sq = Vector256<float>.Zero; var vG1Sq = Vector256<float>.Zero; var vB1Sq = Vector256<float>.Zero; |
| 97 | + var vR2Sq = Vector256<float>.Zero; var vG2Sq = Vector256<float>.Zero; var vB2Sq = Vector256<float>.Zero; |
| 98 | + var vR12 = Vector256<float>.Zero; var vG12 = Vector256<float>.Zero; var vB12 = Vector256<float>.Zero; |
| 99 | + var mask = Vector256.Create(0x000000FFu); |
| 100 | + ref var ref1 = ref MemoryMarshal.GetReference(row1); |
| 101 | + ref var ref2 = ref MemoryMarshal.GetReference(row2); |
| 102 | + var x = 0; |
| 103 | + |
| 104 | + for (; x <= width - Vector256<uint>.Count; x += Vector256<uint>.Count) |
| 105 | + { |
| 106 | + var p1 = Vector256.LoadUnsafe(ref ref1, (nuint) x); |
| 107 | + var p2 = Vector256.LoadUnsafe(ref ref2, (nuint) x); |
| 108 | + |
| 109 | + var r1 = Vector256.ConvertToSingle((p1 & mask).AsInt32()); |
| 110 | + var g1 = Vector256.ConvertToSingle((Vector256.ShiftRightLogical(p1, 8) & mask).AsInt32()); |
| 111 | + var b1 = Vector256.ConvertToSingle((Vector256.ShiftRightLogical(p1, 16) & mask).AsInt32()); |
| 112 | + var r2 = Vector256.ConvertToSingle((p2 & mask).AsInt32()); |
| 113 | + var g2 = Vector256.ConvertToSingle((Vector256.ShiftRightLogical(p2, 8) & mask).AsInt32()); |
| 114 | + var b2 = Vector256.ConvertToSingle((Vector256.ShiftRightLogical(p2, 16) & mask).AsInt32()); |
| 115 | + |
| 116 | + vR1 += r1; vG1 += g1; vB1 += b1; |
| 117 | + vR2 += r2; vG2 += g2; vB2 += b2; |
| 118 | + vR1Sq += r1 * r1; vG1Sq += g1 * g1; vB1Sq += b1 * b1; |
| 119 | + vR2Sq += r2 * r2; vG2Sq += g2 * g2; vB2Sq += b2 * b2; |
| 120 | + vR12 += r1 * r2; vG12 += g1 * g2; vB12 += b1 * b2; |
| 121 | + } |
| 122 | + |
| 123 | + sumR1 += Vector256.Sum(vR1); sumG1 += Vector256.Sum(vG1); sumB1 += Vector256.Sum(vB1); |
| 124 | + sumR2 += Vector256.Sum(vR2); sumG2 += Vector256.Sum(vG2); sumB2 += Vector256.Sum(vB2); |
| 125 | + sumR1Sq += Vector256.Sum(vR1Sq); sumG1Sq += Vector256.Sum(vG1Sq); sumB1Sq += Vector256.Sum(vB1Sq); |
| 126 | + sumR2Sq += Vector256.Sum(vR2Sq); sumG2Sq += Vector256.Sum(vG2Sq); sumB2Sq += Vector256.Sum(vB2Sq); |
| 127 | + sumR12 += Vector256.Sum(vR12); sumG12 += Vector256.Sum(vG12); sumB12 += Vector256.Sum(vB12); |
| 128 | + return x; |
| 129 | + } |
| 130 | + |
| 131 | + static int AccumulateVector128( |
| 132 | + ReadOnlySpan<uint> row1, ReadOnlySpan<uint> row2, int width, |
| 133 | + ref double sumR1, ref double sumG1, ref double sumB1, |
| 134 | + ref double sumR2, ref double sumG2, ref double sumB2, |
| 135 | + ref double sumR1Sq, ref double sumG1Sq, ref double sumB1Sq, |
| 136 | + ref double sumR2Sq, ref double sumG2Sq, ref double sumB2Sq, |
| 137 | + ref double sumR12, ref double sumG12, ref double sumB12) |
| 138 | + { |
| 139 | + var vR1 = Vector128<float>.Zero; var vG1 = Vector128<float>.Zero; var vB1 = Vector128<float>.Zero; |
| 140 | + var vR2 = Vector128<float>.Zero; var vG2 = Vector128<float>.Zero; var vB2 = Vector128<float>.Zero; |
| 141 | + var vR1Sq = Vector128<float>.Zero; var vG1Sq = Vector128<float>.Zero; var vB1Sq = Vector128<float>.Zero; |
| 142 | + var vR2Sq = Vector128<float>.Zero; var vG2Sq = Vector128<float>.Zero; var vB2Sq = Vector128<float>.Zero; |
| 143 | + var vR12 = Vector128<float>.Zero; var vG12 = Vector128<float>.Zero; var vB12 = Vector128<float>.Zero; |
| 144 | + var mask = Vector128.Create(0x000000FFu); |
| 145 | + ref var ref1 = ref MemoryMarshal.GetReference(row1); |
| 146 | + ref var ref2 = ref MemoryMarshal.GetReference(row2); |
| 147 | + var x = 0; |
| 148 | + |
| 149 | + for (; x <= width - Vector128<uint>.Count; x += Vector128<uint>.Count) |
| 150 | + { |
| 151 | + var p1 = Vector128.LoadUnsafe(ref ref1, (nuint) x); |
| 152 | + var p2 = Vector128.LoadUnsafe(ref ref2, (nuint) x); |
| 153 | + |
| 154 | + var r1 = Vector128.ConvertToSingle((p1 & mask).AsInt32()); |
| 155 | + var g1 = Vector128.ConvertToSingle((Vector128.ShiftRightLogical(p1, 8) & mask).AsInt32()); |
| 156 | + var b1 = Vector128.ConvertToSingle((Vector128.ShiftRightLogical(p1, 16) & mask).AsInt32()); |
| 157 | + var r2 = Vector128.ConvertToSingle((p2 & mask).AsInt32()); |
| 158 | + var g2 = Vector128.ConvertToSingle((Vector128.ShiftRightLogical(p2, 8) & mask).AsInt32()); |
| 159 | + var b2 = Vector128.ConvertToSingle((Vector128.ShiftRightLogical(p2, 16) & mask).AsInt32()); |
| 160 | + |
| 161 | + vR1 += r1; vG1 += g1; vB1 += b1; |
| 162 | + vR2 += r2; vG2 += g2; vB2 += b2; |
| 163 | + vR1Sq += r1 * r1; vG1Sq += g1 * g1; vB1Sq += b1 * b1; |
| 164 | + vR2Sq += r2 * r2; vG2Sq += g2 * g2; vB2Sq += b2 * b2; |
| 165 | + vR12 += r1 * r2; vG12 += g1 * g2; vB12 += b1 * b2; |
| 166 | + } |
| 167 | + |
| 168 | + sumR1 += Vector128.Sum(vR1); sumG1 += Vector128.Sum(vG1); sumB1 += Vector128.Sum(vB1); |
| 169 | + sumR2 += Vector128.Sum(vR2); sumG2 += Vector128.Sum(vG2); sumB2 += Vector128.Sum(vB2); |
| 170 | + sumR1Sq += Vector128.Sum(vR1Sq); sumG1Sq += Vector128.Sum(vG1Sq); sumB1Sq += Vector128.Sum(vB1Sq); |
| 171 | + sumR2Sq += Vector128.Sum(vR2Sq); sumG2Sq += Vector128.Sum(vG2Sq); sumB2Sq += Vector128.Sum(vB2Sq); |
| 172 | + sumR12 += Vector128.Sum(vR12); sumG12 += Vector128.Sum(vG12); sumB12 += Vector128.Sum(vB12); |
| 173 | + return x; |
| 174 | + } |
| 175 | + |
| 176 | + static double CalculateChannel(double n, double sum1, double sum2, double sum1Sq, double sum2Sq, double sum12) |
| 177 | + { |
| 178 | + var mu1 = sum1 / n; |
| 179 | + var mu2 = sum2 / n; |
| 180 | + var sigma1Sq = sum1Sq / n - mu1 * mu1; |
| 181 | + var sigma2Sq = sum2Sq / n - mu2 * mu2; |
| 182 | + var sigma12 = sum12 / n - mu1 * mu2; |
| 183 | + |
| 184 | + return (2 * mu1 * mu2 + c1) * (2 * sigma12 + c2) / |
| 185 | + ((mu1 * mu1 + mu2 * mu2 + c1) * (sigma1Sq + sigma2Sq + c2)); |
| 186 | + } |
| 187 | +} |
0 commit comments