mirror of
https://github.com/TheAlgorithms/Java.git
synced 2025-12-19 07:00:35 +08:00
Add matrix multiplication with double[][] and unit tests (#6417)
* MatrixMultiplication.java created and updated. * Add necessary comment to MatrixMultiplication.java * Create MatrixMultiplicationTest.java * method for 2 by 2 matrix multiplication is created * Use assertMatrixEquals(), otherwise there can be error due to floating point arithmetic errors * assertMatrixEquals method created and updated * method created for 3by2 matrix multiply with 2by1 matrix * method created for null matrix multiplication * method for test matrix dimension error * method for test empty matrix input * testMultiply3by2and2by1 test case updated * Check for empty matrices part updated * Updated Unit test coverage * files updated * clean the code * clean the code * Updated files with google-java-format * Updated files * Updated files * Updated files * Updated files * Add reference links and complexities * Add test cases for 1by1 matrix and non-rectangular matrix * Add reference links and complexities --------- Co-authored-by: Deniz Altunkapan <93663085+DenizAltunkapan@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
package com.thealgorithms.matrix;
|
||||
|
||||
/**
|
||||
* This class provides a method to perform matrix multiplication.
|
||||
*
|
||||
* <p>Matrix multiplication takes two 2D arrays (matrices) as input and
|
||||
* produces their product, following the mathematical definition of
|
||||
* matrix multiplication.
|
||||
*
|
||||
* <p>For more details:
|
||||
* https://www.geeksforgeeks.org/java/java-program-to-multiply-two-matrices-of-any-size/
|
||||
* https://en.wikipedia.org/wiki/Matrix_multiplication
|
||||
*
|
||||
* <p>Time Complexity: O(n^3) – where n is the dimension of the matrices
|
||||
* (assuming square matrices for simplicity).
|
||||
*
|
||||
* <p>Space Complexity: O(n^2) – for storing the result matrix.
|
||||
*
|
||||
*
|
||||
* @author Nishitha Wihala Pitigala
|
||||
*
|
||||
*/
|
||||
|
||||
public final class MatrixMultiplication {
|
||||
private MatrixMultiplication() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiplies two matrices.
|
||||
*
|
||||
* @param matrixA the first matrix rowsA x colsA
|
||||
* @param matrixB the second matrix rowsB x colsB
|
||||
* @return the product of the two matrices rowsA x colsB
|
||||
* @throws IllegalArgumentException if the matrices cannot be multiplied
|
||||
*/
|
||||
public static double[][] multiply(double[][] matrixA, double[][] matrixB) {
|
||||
// Check the input matrices are not null
|
||||
if (matrixA == null || matrixB == null) {
|
||||
throw new IllegalArgumentException("Input matrices cannot be null");
|
||||
}
|
||||
|
||||
// Check for empty matrices
|
||||
if (matrixA.length == 0 || matrixB.length == 0 || matrixA[0].length == 0 || matrixB[0].length == 0) {
|
||||
throw new IllegalArgumentException("Input matrices must not be empty");
|
||||
}
|
||||
|
||||
// Validate the matrix dimensions
|
||||
if (matrixA[0].length != matrixB.length) {
|
||||
throw new IllegalArgumentException("Matrices cannot be multiplied: incompatible dimensions.");
|
||||
}
|
||||
|
||||
int rowsA = matrixA.length;
|
||||
int colsA = matrixA[0].length;
|
||||
int colsB = matrixB[0].length;
|
||||
|
||||
// Initialize the result matrix with zeros
|
||||
double[][] result = new double[rowsA][colsB];
|
||||
|
||||
// Perform matrix multiplication
|
||||
for (int i = 0; i < rowsA; i++) {
|
||||
for (int j = 0; j < colsB; j++) {
|
||||
for (int k = 0; k < colsA; k++) {
|
||||
result[i][j] += matrixA[i][k] * matrixB[k][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package com.thealgorithms.matrix;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class MatrixMultiplicationTest {
|
||||
|
||||
private static final double EPSILON = 1e-9; // for floating point comparison
|
||||
|
||||
@Test
|
||||
void testMultiply1by1() {
|
||||
double[][] matrixA = {{1.0}};
|
||||
double[][] matrixB = {{2.0}};
|
||||
double[][] expected = {{2.0}};
|
||||
|
||||
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
|
||||
assertMatrixEquals(expected, result);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultiply2by2() {
|
||||
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
|
||||
double[][] matrixB = {{5.0, 6.0}, {7.0, 8.0}};
|
||||
double[][] expected = {{19.0, 22.0}, {43.0, 50.0}};
|
||||
|
||||
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
|
||||
assertMatrixEquals(expected, result); // Use custom method due to floating point issues
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultiply3by2and2by1() {
|
||||
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
|
||||
double[][] matrixB = {{7.0}, {8.0}};
|
||||
double[][] expected = {{23.0}, {53.0}, {83.0}};
|
||||
|
||||
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
|
||||
assertMatrixEquals(expected, result);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultiplyNonRectangularMatrices() {
|
||||
double[][] matrixA = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}};
|
||||
double[][] matrixB = {{7.0, 8.0}, {9.0, 10.0}, {11.0, 12.0}};
|
||||
double[][] expected = {{58.0, 64.0}, {139.0, 154.0}};
|
||||
|
||||
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
|
||||
assertMatrixEquals(expected, result);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNullMatrixA() {
|
||||
double[][] b = {{1, 2}, {3, 4}};
|
||||
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(null, b));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNullMatrixB() {
|
||||
double[][] a = {{1, 2}, {3, 4}};
|
||||
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, null));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultiplyNull() {
|
||||
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
|
||||
double[][] matrixB = null;
|
||||
|
||||
Exception exception = assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(matrixA, matrixB));
|
||||
|
||||
String expectedMessage = "Input matrices cannot be null";
|
||||
String actualMessage = exception.getMessage();
|
||||
|
||||
assertTrue(actualMessage.contains(expectedMessage));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIncompatibleDimensions() {
|
||||
double[][] a = {{1.0, 2.0}};
|
||||
double[][] b = {{1.0, 2.0}};
|
||||
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testEmptyMatrices() {
|
||||
double[][] a = new double[0][0];
|
||||
double[][] b = new double[0][0];
|
||||
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b));
|
||||
}
|
||||
|
||||
private void assertMatrixEquals(double[][] expected, double[][] actual) {
|
||||
assertEquals(expected.length, actual.length, "Row count mismatch");
|
||||
for (int i = 0; i < expected.length; i++) {
|
||||
assertEquals(expected[i].length, actual[i].length, "Column count mismatch at row " + i);
|
||||
for (int j = 0; j < expected[i].length; j++) {
|
||||
assertEquals(expected[i][j], actual[i][j], EPSILON, "Mismatch at (" + i + "," + j + ")");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user