Skip to content

Commit 3f8bb42

Browse files
feat(maths): add LU decomposition algorithm using Doolittle method
1 parent 4b8099c commit 3f8bb42

2 files changed

Lines changed: 263 additions & 0 deletions

File tree

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package com.thealgorithms.maths;
2+
3+
/**
4+
* LU Decomposition (Doolittle Algorithm)
5+
*
6+
* Decomposes a square matrix A into a lower triangular matrix L and an upper
7+
* triangular matrix U such that A = L * U. The diagonal of L contains all ones.
8+
*
9+
* This decomposition is useful for solving systems of linear equations, computing
10+
* matrix determinants, and finding matrix inverses efficiently.
11+
*
12+
* The algorithm proceeds by computing elements of L and U iteratively:
13+
* - For U: U[k][j] = A[k][j] - sum(L[k][s] * U[s][j]) for s < k
14+
* - For L: L[i][k] = (A[i][k] - sum(L[i][s] * U[s][k])) / U[k][k] for s < k
15+
*
16+
* Time Complexity: O(n^3) for decomposition, O(n^2) for solving a system.
17+
* Space Complexity: O(n^2) for storing L and U matrices.
18+
*
19+
* @see <a href="https://en.wikipedia.org/wiki/LU_decomposition">LU Decomposition</a>
20+
*/
21+
public final class LUDecomposition {
22+
23+
private LUDecomposition() {
24+
}
25+
26+
/**
27+
* Performs LU decomposition on a square matrix using the Doolittle algorithm.
28+
*
29+
* @param matrix a square matrix represented as a 2D array
30+
* @return a 2D array where the lower triangle (excluding diagonal) contains L
31+
* elements (with implicit 1s on the diagonal) and the upper triangle
32+
* (including diagonal) contains U elements
33+
* @throws IllegalArgumentException if the matrix is not square
34+
* @throws ArithmeticException if a zero pivot is encountered
35+
*/
36+
public static double[][] decompose(double[][] matrix) {
37+
int n = matrix.length;
38+
for (double[] row : matrix) {
39+
if (row.length != n) {
40+
throw new IllegalArgumentException("Matrix must be square.");
41+
}
42+
}
43+
44+
double[][] lu = new double[n][n];
45+
for (int i = 0; i < n; i++) {
46+
for (int j = 0; j < n; j++) {
47+
lu[i][j] = matrix[i][j];
48+
}
49+
}
50+
51+
for (int k = 0; k < n; k++) {
52+
// Compute U elements for row k
53+
for (int j = k; j < n; j++) {
54+
double sum = 0;
55+
for (int s = 0; s < k; s++) {
56+
sum += lu[k][s] * lu[s][j];
57+
}
58+
lu[k][j] -= sum;
59+
}
60+
61+
// Check for zero pivot
62+
if (lu[k][k] == 0) {
63+
throw new ArithmeticException("Zero pivot encountered. Matrix may be singular.");
64+
}
65+
66+
// Compute L elements for column k
67+
for (int i = k + 1; i < n; i++) {
68+
double sum = 0;
69+
for (int s = 0; s < k; s++) {
70+
sum += lu[i][s] * lu[s][k];
71+
}
72+
lu[i][k] = (lu[i][k] - sum) / lu[k][k];
73+
}
74+
}
75+
76+
return lu;
77+
}
78+
79+
/**
80+
* Extracts the lower triangular matrix L from the combined LU matrix.
81+
* The diagonal of L is set to 1 (Doolittle convention).
82+
*
83+
* @param lu the combined LU matrix from {@link #decompose(double[][])}
84+
* @return the lower triangular matrix L with 1s on the diagonal
85+
*/
86+
public static double[][] getLowerMatrix(double[][] lu) {
87+
int n = lu.length;
88+
double[][] lower = new double[n][n];
89+
for (int i = 0; i < n; i++) {
90+
lower[i][i] = 1.0;
91+
for (int j = 0; j < i; j++) {
92+
lower[i][j] = lu[i][j];
93+
}
94+
}
95+
return lower;
96+
}
97+
98+
/**
99+
* Extracts the upper triangular matrix U from the combined LU matrix.
100+
*
101+
* @param lu the combined LU matrix from {@link #decompose(double[][])}
102+
* @return the upper triangular matrix U
103+
*/
104+
public static double[][] getUpperMatrix(double[][] lu) {
105+
int n = lu.length;
106+
double[][] upper = new double[n][n];
107+
for (int i = 0; i < n; i++) {
108+
for (int j = i; j < n; j++) {
109+
upper[i][j] = lu[i][j];
110+
}
111+
}
112+
return upper;
113+
}
114+
115+
/**
116+
* Solves a system of linear equations Ax = b using LU decomposition.
117+
*
118+
* @param lu the combined LU matrix from {@link #decompose(double[][])}
119+
* @param b the right-hand side vector
120+
* @return the solution vector x
121+
*/
122+
public static double[] solve(double[][] lu, double[] b) {
123+
int n = lu.length;
124+
double[] y = new double[n];
125+
double[] x = new double[n];
126+
127+
// Forward substitution: solve Ly = b
128+
for (int i = 0; i < n; i++) {
129+
double sum = 0;
130+
for (int j = 0; j < i; j++) {
131+
sum += lu[i][j] * y[j];
132+
}
133+
y[i] = b[i] - sum;
134+
}
135+
136+
// Back substitution: solve Ux = y
137+
for (int i = n - 1; i >= 0; i--) {
138+
double sum = 0;
139+
for (int j = i + 1; j < n; j++) {
140+
sum += lu[i][j] * x[j];
141+
}
142+
x[i] = (y[i] - sum) / lu[i][i];
143+
}
144+
145+
return x;
146+
}
147+
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package com.thealgorithms.maths;
2+
3+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
5+
6+
import org.junit.jupiter.api.Test;
7+
8+
public class LUDecompositionTest {
9+
10+
private static final double DELTA = 1e-9;
11+
12+
@Test
13+
void testDecomposeSimpleMatrix() {
14+
double[][] matrix = {{2, 1, 1}, {4, 3, 3}, {8, 7, 9}};
15+
double[][] lu = LUDecomposition.decompose(matrix);
16+
17+
double[][] lower = LUDecomposition.getLowerMatrix(lu);
18+
double[][] upper = LUDecomposition.getUpperMatrix(lu);
19+
20+
// Verify L has 1s on diagonal
21+
assertArrayEquals(new double[]{1, 1, 1}, new double[]{lower[0][0], lower[1][1], lower[2][2]}, DELTA);
22+
23+
// Verify L * U = A
24+
double[][] product = multiply(lower, upper);
25+
assertArrayEquals(new double[]{2, 1, 1, 4, 3, 3, 8, 7, 9}, flatten(product), DELTA);
26+
}
27+
28+
@Test
29+
void testDecomposeTwoByTwo() {
30+
double[][] matrix = {{1, 2}, {3, 4}};
31+
double[][] lu = LUDecomposition.decompose(matrix);
32+
33+
double[][] lower = LUDecomposition.getLowerMatrix(lu);
34+
double[][] upper = LUDecomposition.getUpperMatrix(lu);
35+
36+
double[][] product = multiply(lower, upper);
37+
assertArrayEquals(new double[]{1, 2, 3, 4}, flatten(product), DELTA);
38+
}
39+
40+
@Test
41+
void testDecomposeIdentityMatrix() {
42+
double[][] matrix = {{1, 0}, {0, 1}};
43+
double[][] lu = LUDecomposition.decompose(matrix);
44+
45+
double[][] lower = LUDecomposition.getLowerMatrix(lu);
46+
double[][] upper = LUDecomposition.getUpperMatrix(lu);
47+
48+
// For identity matrix, L = I and U = I
49+
assertArrayEquals(new double[]{1, 0, 0, 1}, flatten(lower), DELTA);
50+
assertArrayEquals(new double[]{1, 0, 0, 1}, flatten(upper), DELTA);
51+
}
52+
53+
@Test
54+
void testDecomposeNonSquareMatrixThrows() {
55+
double[][] matrix = {{1, 2, 3}, {4, 5, 6}};
56+
assertThrows(IllegalArgumentException.class, () -> LUDecomposition.decompose(matrix));
57+
}
58+
59+
@Test
60+
void testDecomposeSingularMatrixThrows() {
61+
double[][] matrix = {{0, 1}, {1, 0}};
62+
assertThrows(ArithmeticException.class, () -> LUDecomposition.decompose(matrix));
63+
}
64+
65+
@Test
66+
void testSolveLinearSystem() {
67+
// 2x + y + z = 8
68+
// 4x + 3y + 3z = 20
69+
// 8x + 7y + 9z = 46
70+
// Solution: x=1, y=3, z=3
71+
double[][] matrix = {{2, 1, 1}, {4, 3, 3}, {8, 7, 9}};
72+
double[] b = {8, 20, 46};
73+
double[][] lu = LUDecomposition.decompose(matrix);
74+
double[] solution = LUDecomposition.solve(lu, b);
75+
76+
assertArrayEquals(new double[]{1, 3, 3}, solution, DELTA);
77+
}
78+
79+
@Test
80+
void testSolveTwoByTwoSystem() {
81+
// 2x + y = 5
82+
// x + 3y = 7
83+
// Solution: x=1.6, y=1.8
84+
double[][] matrix = {{2, 1}, {1, 3}};
85+
double[] b = {5, 7};
86+
double[][] lu = LUDecomposition.decompose(matrix);
87+
double[] solution = LUDecomposition.solve(lu, b);
88+
89+
assertArrayEquals(new double[]{1.6, 1.8}, solution, DELTA);
90+
}
91+
92+
private static double[][] multiply(double[][] a, double[][] b) {
93+
int n = a.length;
94+
double[][] result = new double[n][n];
95+
for (int i = 0; i < n; i++) {
96+
for (int j = 0; j < n; j++) {
97+
for (int k = 0; k < n; k++) {
98+
result[i][j] += a[i][k] * b[k][j];
99+
}
100+
}
101+
}
102+
return result;
103+
}
104+
105+
private static double[] flatten(double[][] matrix) {
106+
int n = matrix.length;
107+
double[] result = new double[n * n];
108+
int idx = 0;
109+
for (double[] row : matrix) {
110+
for (double val : row) {
111+
result[idx++] = val;
112+
}
113+
}
114+
return result;
115+
}
116+
}

0 commit comments

Comments
 (0)