mirror of
https://github.com/TheAlgorithms/Java.git
synced 2025-12-19 07:00:35 +08:00
feat: Add Chebyshev Iteration algorithm (#6963)
* feat: Add Chebyshev Iteration algorithm * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * update * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java
This commit is contained in:
181
src/main/java/com/thealgorithms/maths/ChebyshevIteration.java
Normal file
181
src/main/java/com/thealgorithms/maths/ChebyshevIteration.java
Normal file
@@ -0,0 +1,181 @@
|
||||
package com.thealgorithms.maths;
|
||||
|
||||
/**
|
||||
* In numerical analysis, Chebyshev iteration is an iterative method for solving
|
||||
* systems of linear equations Ax = b. It is designed for systems where the
|
||||
* matrix A is symmetric positive-definite (SPD).
|
||||
*
|
||||
* <p>
|
||||
* This method is a "polynomial acceleration" method, meaning it finds the
|
||||
* optimal polynomial to apply to the residual to accelerate convergence.
|
||||
*
|
||||
* <p>
|
||||
* It requires knowledge of the bounds of the eigenvalues of the matrix A:
|
||||
* m(A) (smallest eigenvalue) and M(A) (largest eigenvalue).
|
||||
*
|
||||
* <p>
|
||||
* Wikipedia: https://en.wikipedia.org/wiki/Chebyshev_iteration
|
||||
*
|
||||
* @author Mitrajit Ghorui(KeyKyrios)
|
||||
*/
|
||||
public final class ChebyshevIteration {
|
||||
|
||||
private ChebyshevIteration() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Solves the linear system Ax = b using the Chebyshev iteration method.
|
||||
*
|
||||
* <p>
|
||||
* NOTE: The matrix A *must* be symmetric positive-definite (SPD) for this
|
||||
* algorithm to converge.
|
||||
*
|
||||
* @param a The matrix A (must be square, SPD).
|
||||
* @param b The vector b.
|
||||
* @param x0 The initial guess vector.
|
||||
* @param minEigenvalue The smallest eigenvalue of A (m(A)).
|
||||
* @param maxEigenvalue The largest eigenvalue of A (M(A)).
|
||||
* @param maxIterations The maximum number of iterations to perform.
|
||||
* @param tolerance The desired tolerance for the residual norm.
|
||||
* @return The solution vector x.
|
||||
* @throws IllegalArgumentException if matrix/vector dimensions are
|
||||
* incompatible,
|
||||
* if maxIterations <= 0, or if eigenvalues are invalid (e.g., minEigenvalue
|
||||
* <= 0, maxEigenvalue <= minEigenvalue).
|
||||
*/
|
||||
public static double[] solve(double[][] a, double[] b, double[] x0, double minEigenvalue, double maxEigenvalue, int maxIterations, double tolerance) {
|
||||
validateInputs(a, b, x0, minEigenvalue, maxEigenvalue, maxIterations, tolerance);
|
||||
|
||||
int n = b.length;
|
||||
double[] x = x0.clone();
|
||||
double[] r = vectorSubtract(b, matrixVectorMultiply(a, x));
|
||||
double[] p = new double[n];
|
||||
|
||||
double d = (maxEigenvalue + minEigenvalue) / 2.0;
|
||||
double c = (maxEigenvalue - minEigenvalue) / 2.0;
|
||||
|
||||
double alpha = 0.0;
|
||||
double alphaPrev = 0.0;
|
||||
|
||||
for (int k = 0; k < maxIterations; k++) {
|
||||
double residualNorm = vectorNorm(r);
|
||||
if (residualNorm < tolerance) {
|
||||
return x; // Solution converged
|
||||
}
|
||||
|
||||
if (k == 0) {
|
||||
alpha = 1.0 / d;
|
||||
System.arraycopy(r, 0, p, 0, n); // p = r
|
||||
} else {
|
||||
double beta = c * alphaPrev / 2.0 * (c * alphaPrev / 2.0);
|
||||
alpha = 1.0 / (d - beta / alphaPrev);
|
||||
double[] pUpdate = scalarMultiply(beta / alphaPrev, p);
|
||||
p = vectorAdd(r, pUpdate); // p = r + (beta / alphaPrev) * p
|
||||
}
|
||||
|
||||
double[] xUpdate = scalarMultiply(alpha, p);
|
||||
x = vectorAdd(x, xUpdate); // x = x + alpha * p
|
||||
|
||||
// Recompute residual for accuracy
|
||||
r = vectorSubtract(b, matrixVectorMultiply(a, x));
|
||||
alphaPrev = alpha;
|
||||
}
|
||||
|
||||
return x; // Return best guess after maxIterations
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the inputs for the Chebyshev solver.
|
||||
*/
|
||||
private static void validateInputs(double[][] a, double[] b, double[] x0, double minEigenvalue, double maxEigenvalue, int maxIterations, double tolerance) {
|
||||
int n = a.length;
|
||||
if (n == 0) {
|
||||
throw new IllegalArgumentException("Matrix A cannot be empty.");
|
||||
}
|
||||
if (n != a[0].length) {
|
||||
throw new IllegalArgumentException("Matrix A must be square.");
|
||||
}
|
||||
if (n != b.length) {
|
||||
throw new IllegalArgumentException("Matrix A and vector b dimensions do not match.");
|
||||
}
|
||||
if (n != x0.length) {
|
||||
throw new IllegalArgumentException("Matrix A and vector x0 dimensions do not match.");
|
||||
}
|
||||
if (minEigenvalue <= 0) {
|
||||
throw new IllegalArgumentException("Smallest eigenvalue must be positive (matrix must be positive-definite).");
|
||||
}
|
||||
if (maxEigenvalue <= minEigenvalue) {
|
||||
throw new IllegalArgumentException("Max eigenvalue must be strictly greater than min eigenvalue.");
|
||||
}
|
||||
if (maxIterations <= 0) {
|
||||
throw new IllegalArgumentException("Max iterations must be positive.");
|
||||
}
|
||||
if (tolerance <= 0) {
|
||||
throw new IllegalArgumentException("Tolerance must be positive.");
|
||||
}
|
||||
}
|
||||
|
||||
// --- Vector/Matrix Helper Methods ---
|
||||
/**
|
||||
* Computes the product of a matrix A and a vector v (Av).
|
||||
*/
|
||||
private static double[] matrixVectorMultiply(double[][] a, double[] v) {
|
||||
int n = a.length;
|
||||
double[] result = new double[n];
|
||||
for (int i = 0; i < n; i++) {
|
||||
double sum = 0;
|
||||
for (int j = 0; j < n; j++) {
|
||||
sum += a[i][j] * v[j];
|
||||
}
|
||||
result[i] = sum;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the subtraction of two vectors (v1 - v2).
|
||||
*/
|
||||
private static double[] vectorSubtract(double[] v1, double[] v2) {
|
||||
int n = v1.length;
|
||||
double[] result = new double[n];
|
||||
for (int i = 0; i < n; i++) {
|
||||
result[i] = v1[i] - v2[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the addition of two vectors (v1 + v2).
|
||||
*/
|
||||
private static double[] vectorAdd(double[] v1, double[] v2) {
|
||||
int n = v1.length;
|
||||
double[] result = new double[n];
|
||||
for (int i = 0; i < n; i++) {
|
||||
result[i] = v1[i] + v2[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the product of a scalar and a vector (s * v).
|
||||
*/
|
||||
private static double[] scalarMultiply(double scalar, double[] v) {
|
||||
int n = v.length;
|
||||
double[] result = new double[n];
|
||||
for (int i = 0; i < n; i++) {
|
||||
result[i] = scalar * v[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the L2 norm (Euclidean norm) of a vector.
|
||||
*/
|
||||
private static double vectorNorm(double[] v) {
|
||||
double sumOfSquares = 0;
|
||||
for (double val : v) {
|
||||
sumOfSquares += val * val;
|
||||
}
|
||||
return Math.sqrt(sumOfSquares);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package com.thealgorithms.maths;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class ChebyshevIterationTest {
|
||||
|
||||
@Test
|
||||
public void testSolveSimple2x2Diagonal() {
|
||||
double[][] a = {{2, 0}, {0, 1}};
|
||||
double[] b = {2, 2};
|
||||
double[] x0 = {0, 0};
|
||||
double minEig = 1.0;
|
||||
double maxEig = 2.0;
|
||||
int maxIter = 50;
|
||||
double tol = 1e-9;
|
||||
double[] expected = {1.0, 2.0};
|
||||
|
||||
double[] result = ChebyshevIteration.solve(a, b, x0, minEig, maxEig, maxIter, tol);
|
||||
assertArrayEquals(expected, result, 1e-9);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSolve2x2Symmetric() {
|
||||
double[][] a = {{4, 1}, {1, 3}};
|
||||
double[] b = {1, 2};
|
||||
double[] x0 = {0, 0};
|
||||
double minEig = (7.0 - Math.sqrt(5.0)) / 2.0;
|
||||
double maxEig = (7.0 + Math.sqrt(5.0)) / 2.0;
|
||||
int maxIter = 100;
|
||||
double tol = 1e-10;
|
||||
double[] expected = {1.0 / 11.0, 7.0 / 11.0};
|
||||
|
||||
double[] result = ChebyshevIteration.solve(a, b, x0, minEig, maxEig, maxIter, tol);
|
||||
assertArrayEquals(expected, result, 1e-9);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAlreadyAtSolution() {
|
||||
double[][] a = {{2, 0}, {0, 1}};
|
||||
double[] b = {2, 2};
|
||||
double[] x0 = {1, 2};
|
||||
double minEig = 1.0;
|
||||
double maxEig = 2.0;
|
||||
int maxIter = 10;
|
||||
double tol = 1e-5;
|
||||
double[] expected = {1.0, 2.0};
|
||||
|
||||
double[] result = ChebyshevIteration.solve(a, b, x0, minEig, maxEig, maxIter, tol);
|
||||
assertArrayEquals(expected, result, 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMismatchedDimensionsAB() {
|
||||
double[][] a = {{1, 0}, {0, 1}};
|
||||
double[] b = {1};
|
||||
double[] x0 = {0, 0};
|
||||
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, 10, 1e-5));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMismatchedDimensionsAX() {
|
||||
double[][] a = {{1, 0}, {0, 1}};
|
||||
double[] b = {1, 1};
|
||||
double[] x0 = {0};
|
||||
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, 10, 1e-5));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNonSquareMatrix() {
|
||||
double[][] a = {{1, 0, 0}, {0, 1, 0}};
|
||||
double[] b = {1, 1};
|
||||
double[] x0 = {0, 0};
|
||||
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, 10, 1e-5));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidEigenvalues() {
|
||||
double[][] a = {{1, 0}, {0, 1}};
|
||||
double[] b = {1, 1};
|
||||
double[] x0 = {0, 0};
|
||||
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 2, 1, 10, 1e-5));
|
||||
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 1, 10, 1e-5));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNonPositiveDefinite() {
|
||||
double[][] a = {{1, 0}, {0, 1}};
|
||||
double[] b = {1, 1};
|
||||
double[] x0 = {0, 0};
|
||||
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 0, 1, 10, 1e-5));
|
||||
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, -1, 1, 10, 1e-5));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidIterationCount() {
|
||||
double[][] a = {{1, 0}, {0, 1}};
|
||||
double[] b = {1, 1};
|
||||
double[] x0 = {0, 0};
|
||||
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, 0, 1e-5));
|
||||
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, -1, 1e-5));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user