Add matrix multiplication with double[][] and unit tests (#6417)

* MatrixMultiplication.java created and updated.

* Add necessary comment to MatrixMultiplication.java

* Create MatrixMultiplicationTest.java

* method for 2 by 2 matrix multiplication is created

* Use assertMatrixEquals(), otherwise there can be error due to floating point arithmetic errors

* assertMatrixEquals method created and updated

* method created for 3by2 matrix multiply with 2by1 matrix

* method created for null matrix multiplication

* method for test matrix dimension error

* method for test empty matrix input

* testMultiply3by2and2by1 test case updated

* Check for empty matrices part updated

* Updated Unit test coverage

* files updated

* clean the code

* clean the code

* Updated files with google-java-format

* Updated files

* Updated files

* Updated files

* Updated files

* Add reference links and complexities

* Add test cases for 1by1 matrix and non-rectangular matrix

* Add reference links and complexities

---------

Co-authored-by: Deniz Altunkapan <93663085+DenizAltunkapan@users.noreply.github.com>
This commit is contained in:
Nishitha Wihala
2025-07-21 21:37:13 +05:30
committed by GitHub
parent 2722b0ecc9
commit 2dfad7ef8f
2 changed files with 170 additions and 0 deletions

View File

@@ -0,0 +1,69 @@
package com.thealgorithms.matrix;
/**
* This class provides a method to perform matrix multiplication.
*
* <p>Matrix multiplication takes two 2D arrays (matrices) as input and
* produces their product, following the mathematical definition of
* matrix multiplication.
*
* <p>For more details:
* https://www.geeksforgeeks.org/java/java-program-to-multiply-two-matrices-of-any-size/
* https://en.wikipedia.org/wiki/Matrix_multiplication
*
* <p>Time Complexity: O(n^3) where n is the dimension of the matrices
* (assuming square matrices for simplicity).
*
* <p>Space Complexity: O(n^2) for storing the result matrix.
*
*
* @author Nishitha Wihala Pitigala
*
*/
public final class MatrixMultiplication {
private MatrixMultiplication() {
}
/**
* Multiplies two matrices.
*
* @param matrixA the first matrix rowsA x colsA
* @param matrixB the second matrix rowsB x colsB
* @return the product of the two matrices rowsA x colsB
* @throws IllegalArgumentException if the matrices cannot be multiplied
*/
public static double[][] multiply(double[][] matrixA, double[][] matrixB) {
// Check the input matrices are not null
if (matrixA == null || matrixB == null) {
throw new IllegalArgumentException("Input matrices cannot be null");
}
// Check for empty matrices
if (matrixA.length == 0 || matrixB.length == 0 || matrixA[0].length == 0 || matrixB[0].length == 0) {
throw new IllegalArgumentException("Input matrices must not be empty");
}
// Validate the matrix dimensions
if (matrixA[0].length != matrixB.length) {
throw new IllegalArgumentException("Matrices cannot be multiplied: incompatible dimensions.");
}
int rowsA = matrixA.length;
int colsA = matrixA[0].length;
int colsB = matrixB[0].length;
// Initialize the result matrix with zeros
double[][] result = new double[rowsA][colsB];
// Perform matrix multiplication
for (int i = 0; i < rowsA; i++) {
for (int j = 0; j < colsB; j++) {
for (int k = 0; k < colsA; k++) {
result[i][j] += matrixA[i][k] * matrixB[k][j];
}
}
}
return result;
}
}

View File

@@ -0,0 +1,101 @@
package com.thealgorithms.matrix;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.Test;
public class MatrixMultiplicationTest {
private static final double EPSILON = 1e-9; // for floating point comparison
@Test
void testMultiply1by1() {
double[][] matrixA = {{1.0}};
double[][] matrixB = {{2.0}};
double[][] expected = {{2.0}};
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
assertMatrixEquals(expected, result);
}
@Test
void testMultiply2by2() {
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
double[][] matrixB = {{5.0, 6.0}, {7.0, 8.0}};
double[][] expected = {{19.0, 22.0}, {43.0, 50.0}};
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
assertMatrixEquals(expected, result); // Use custom method due to floating point issues
}
@Test
void testMultiply3by2and2by1() {
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
double[][] matrixB = {{7.0}, {8.0}};
double[][] expected = {{23.0}, {53.0}, {83.0}};
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
assertMatrixEquals(expected, result);
}
@Test
void testMultiplyNonRectangularMatrices() {
double[][] matrixA = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}};
double[][] matrixB = {{7.0, 8.0}, {9.0, 10.0}, {11.0, 12.0}};
double[][] expected = {{58.0, 64.0}, {139.0, 154.0}};
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
assertMatrixEquals(expected, result);
}
@Test
void testNullMatrixA() {
double[][] b = {{1, 2}, {3, 4}};
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(null, b));
}
@Test
void testNullMatrixB() {
double[][] a = {{1, 2}, {3, 4}};
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, null));
}
@Test
void testMultiplyNull() {
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
double[][] matrixB = null;
Exception exception = assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(matrixA, matrixB));
String expectedMessage = "Input matrices cannot be null";
String actualMessage = exception.getMessage();
assertTrue(actualMessage.contains(expectedMessage));
}
@Test
void testIncompatibleDimensions() {
double[][] a = {{1.0, 2.0}};
double[][] b = {{1.0, 2.0}};
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b));
}
@Test
void testEmptyMatrices() {
double[][] a = new double[0][0];
double[][] b = new double[0][0];
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b));
}
private void assertMatrixEquals(double[][] expected, double[][] actual) {
assertEquals(expected.length, actual.length, "Row count mismatch");
for (int i = 0; i < expected.length; i++) {
assertEquals(expected[i].length, actual[i].length, "Column count mismatch at row " + i);
for (int j = 0; j < expected[i].length; j++) {
assertEquals(expected[i][j], actual[i][j], EPSILON, "Mismatch at (" + i + "," + j + ")");
}
}
}
}