|
| 1 | +package com.rae.formicapi.fondation.math.solvers; |
| 2 | + |
| 3 | +import com.rae.formicapi.fondation.math.operators.Matrix; |
| 4 | + |
| 5 | +public class LeastSquare2 { |
| 6 | + |
| 7 | + /** |
| 8 | + * Convenience overload: zero initial guess, allocates working buffers internally. |
| 9 | + * Use only when the caller has no long-lived matrix to cache buffers on — prefer |
| 10 | + * the pre-allocated overload for any hot path. |
| 11 | + */ |
| 12 | + public static double[] solve(Matrix A, double[] b, int maxIter, float tol) { |
| 13 | + int m = A.cols(); |
| 14 | + int n = A.rows(); |
| 15 | + return solve(A, b, maxIter, tol, |
| 16 | + new double[m], new double[m], new double[m], new double[m], new double[m], new double[n]); |
| 17 | + } |
| 18 | + |
| 19 | + /** |
| 20 | + * Solve {@code Ax = b} in the least-squares sense using CG on the normal equations |
| 21 | + * {@code AᵀA x = Aᵀb}, with caller-supplied working buffers. |
| 22 | + * |
| 23 | + * <p>Pass pre-allocated arrays from a long-lived object (e.g. {@code PhysicsMatrix}) |
| 24 | + * to avoid allocating ~4 × n doubles on every call. At 376 832 voxels and 20 ticks/s |
| 25 | + * the naive version allocates ~240 MB/s; this overload allocates nothing after warmup. |
| 26 | + * |
| 27 | + * <p>The contents of all four working arrays are overwritten on every call. |
| 28 | + * Their values between calls are undefined and must not be read by the caller. |
| 29 | + * |
| 30 | + * @param A input matrix (square or rectangular) |
| 31 | + * @param b right-hand side, length {@code A.rows()} |
| 32 | + * @param maxIter maximum CG iterations |
| 33 | + * @param tol convergence tolerance on the residual norm |
| 34 | + * @param r pre-allocated residual buffer, length ≥ {@code A.cols()} |
| 35 | + * @param p pre-allocated search-direction buffer, length ≥ {@code A.cols()} |
| 36 | + * @param Ap pre-allocated AᵀA·p buffer, length ≥ {@code A.cols()} |
| 37 | + * @param temp pre-allocated A·p intermediate buffer, length ≥ {@code A.rows()} |
| 38 | + * @return solution vector x (a new array of length {@code A.cols()}) |
| 39 | + */ |
| 40 | + public static double[] solve(Matrix A, double[] b, int maxIter, float tol, |
| 41 | + double[] initialX, double[] r, double[] p, double[] Atb, double[] Ap, double[] temp) { |
| 42 | + int n = A.rows(); |
| 43 | + int m = A.cols(); |
| 44 | + |
| 45 | + if (b.length != n) |
| 46 | + throw new IllegalArgumentException( |
| 47 | + "RHS length (" + b.length + ") != matrix rows (" + n + ")"); |
| 48 | + if (r.length < m || p.length < m || Ap.length < m) |
| 49 | + throw new IllegalArgumentException( |
| 50 | + "Working buffers r/p/Ap must have length >= A.cols() = " + m); |
| 51 | + if (temp.length < n) |
| 52 | + throw new IllegalArgumentException( |
| 53 | + "Working buffer temp must have length >= A.rows() = " + n); |
| 54 | + if (initialX.length != m) |
| 55 | + throw new IllegalArgumentException( |
| 56 | + "Initial guess length (" + initialX.length + ") does not match matrix columns (" + m + ")" |
| 57 | + ); |
| 58 | + |
| 59 | + // Aᵀb — written into r temporarily, then copied to Atb slot |
| 60 | + A.transposeMultiply(b, Atb); |
| 61 | + |
| 62 | + return conjugateGradientNormalEq(A, initialX, Atb, maxIter, tol, r, p, Ap, temp); |
| 63 | + } |
| 64 | + |
| 65 | + /** |
| 66 | + * CG on AᵀA x = Aᵀb without forming AᵀA explicitly. |
| 67 | + * All working arrays are passed in and reused across calls. |
| 68 | + */ |
| 69 | + private static double[] conjugateGradientNormalEq( |
| 70 | + Matrix A, double[] x, double[] Atb, int maxIter, double tol, |
| 71 | + double[] r, double[] p, double[] Ap, double[] temp) { |
| 72 | + |
| 73 | + int n = A.rows(); |
| 74 | + int m = A.cols(); |
| 75 | + |
| 76 | + // r = Atb - AᵀA·x (x is zero, so r = Atb on first call) |
| 77 | + multiplyAtA(A, x, temp, Ap); |
| 78 | + for (int i = 0; i < m; i++) { |
| 79 | + r[i] = Atb[i] - Ap[i]; |
| 80 | + p[i] = r[i]; |
| 81 | + } |
| 82 | + |
| 83 | + double rsold = dot(r, r, m); |
| 84 | + |
| 85 | + for (int k = 0; k < maxIter; k++) { |
| 86 | + multiplyAtA(A, p, temp, Ap); |
| 87 | + |
| 88 | + double dotPAp = dot(p, Ap, m); |
| 89 | + if (dotPAp == 0) break; |
| 90 | + double alpha = rsold / dotPAp; |
| 91 | + |
| 92 | + for (int i = 0; i < m; i++) x[i] += alpha * p[i]; |
| 93 | + for (int i = 0; i < m; i++) r[i] -= alpha * Ap[i]; |
| 94 | + |
| 95 | + double rsnew = dot(r, r, m); |
| 96 | + if (Math.sqrt(rsnew) < tol) break; |
| 97 | + |
| 98 | + double beta = rsnew / rsold; |
| 99 | + for (int i = 0; i < m; i++) p[i] = r[i] + beta * p[i]; |
| 100 | + rsold = rsnew; |
| 101 | + } |
| 102 | + return x; |
| 103 | + } |
| 104 | + |
| 105 | + private static void multiplyAtA(Matrix A, double[] p, double[] temp, double[] result) { |
| 106 | + A.multiply(p, temp); |
| 107 | + A.transposeMultiply(temp, result); |
| 108 | + } |
| 109 | + |
| 110 | + private static double dot(double[] a, double[] b, int len) { |
| 111 | + double sum = 0; |
| 112 | + for (int i = 0; i < len; i++) sum += a[i] * b[i]; |
| 113 | + return sum; |
| 114 | + } |
| 115 | +} |
| 116 | + |
0 commit comments