Add tests, remove main, enhance docs in MatrixChainMultiplication (#5658)

This commit is contained in:
Hardik Pawar
2024-10-10 22:46:04 +05:30
committed by GitHub
parent bd3b754eda
commit e728aa7d6f
3 changed files with 169 additions and 92 deletions

View File

@ -814,6 +814,7 @@
* [LongestIncreasingSubsequenceTests](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/LongestIncreasingSubsequenceTests.java) * [LongestIncreasingSubsequenceTests](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/LongestIncreasingSubsequenceTests.java)
* [LongestPalindromicSubstringTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/LongestPalindromicSubstringTest.java) * [LongestPalindromicSubstringTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/LongestPalindromicSubstringTest.java)
* [LongestValidParenthesesTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/LongestValidParenthesesTest.java) * [LongestValidParenthesesTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/LongestValidParenthesesTest.java)
* [MatrixChainMultiplicationTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/MatrixChainMultiplicationTest.java)
* [MinimumPathSumTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/MinimumPathSumTest.java) * [MinimumPathSumTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/MinimumPathSumTest.java)
* [MinimumSumPartitionTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/MinimumSumPartitionTest.java) * [MinimumSumPartitionTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/MinimumSumPartitionTest.java)
* [OptimalJobSchedulingTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/OptimalJobSchedulingTest.java) * [OptimalJobSchedulingTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/OptimalJobSchedulingTest.java)

View File

@ -2,38 +2,32 @@ package com.thealgorithms.dynamicprogramming;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Scanner;
/**
* The MatrixChainMultiplication class provides functionality to compute the
* optimal way to multiply a sequence of matrices. The optimal multiplication
* order is determined using dynamic programming, which minimizes the total
* number of scalar multiplications required.
*/
public final class MatrixChainMultiplication { public final class MatrixChainMultiplication {
private MatrixChainMultiplication() { private MatrixChainMultiplication() {
} }
private static final Scanner SCANNER = new Scanner(System.in); // Matrices to store minimum multiplication costs and split points
private static final ArrayList<Matrix> MATRICES = new ArrayList<>();
private static int size;
private static int[][] m; private static int[][] m;
private static int[][] s; private static int[][] s;
private static int[] p; private static int[] p;
public static void main(String[] args) { /**
int count = 1; * Calculates the optimal order for multiplying a given list of matrices.
while (true) { *
String[] mSize = input("input size of matrix A(" + count + ") ( ex. 10 20 ) : "); * @param matrices an ArrayList of Matrix objects representing the matrices
int col = Integer.parseInt(mSize[0]); * to be multiplied.
if (col == 0) { * @return a Result object containing the matrices of minimum costs and
break; * optimal splits.
} */
int row = Integer.parseInt(mSize[1]); public static Result calculateMatrixChainOrder(ArrayList<Matrix> matrices) {
int size = matrices.size();
Matrix matrix = new Matrix(count, col, row);
MATRICES.add(matrix);
count++;
}
for (Matrix m : MATRICES) {
System.out.format("A(%d) = %2d x %2d%n", m.count(), m.col(), m.row());
}
size = MATRICES.size();
m = new int[size + 1][size + 1]; m = new int[size + 1][size + 1];
s = new int[size + 1][size + 1]; s = new int[size + 1][size + 1];
p = new int[size + 1]; p = new int[size + 1];
@ -44,51 +38,20 @@ public final class MatrixChainMultiplication {
} }
for (int i = 0; i < p.length; i++) { for (int i = 0; i < p.length; i++) {
p[i] = i == 0 ? MATRICES.get(i).col() : MATRICES.get(i - 1).row(); p[i] = i == 0 ? matrices.get(i).col() : matrices.get(i - 1).row();
} }
matrixChainOrder(); matrixChainOrder(size);
for (int i = 0; i < size; i++) { return new Result(m, s);
System.out.print("-------");
}
System.out.println();
printArray(m);
for (int i = 0; i < size; i++) {
System.out.print("-------");
}
System.out.println();
printArray(s);
for (int i = 0; i < size; i++) {
System.out.print("-------");
}
System.out.println();
System.out.println("Optimal solution : " + m[1][size]);
System.out.print("Optimal parens : ");
printOptimalParens(1, size);
} }
private static void printOptimalParens(int i, int j) { /**
if (i == j) { * A helper method that computes the minimum cost of multiplying
System.out.print("A" + i); * the matrices using dynamic programming.
} else { *
System.out.print("("); * @param size the number of matrices in the multiplication sequence.
printOptimalParens(i, s[i][j]); */
printOptimalParens(s[i][j] + 1, j); private static void matrixChainOrder(int size) {
System.out.print(")");
}
}
private static void printArray(int[][] array) {
for (int i = 1; i < size + 1; i++) {
for (int j = 1; j < size + 1; j++) {
System.out.printf("%7d", array[i][j]);
}
System.out.println();
}
}
private static void matrixChainOrder() {
for (int i = 1; i < size + 1; i++) { for (int i = 1; i < size + 1; i++) {
m[i][i] = 0; m[i][i] = 0;
} }
@ -109,33 +72,92 @@ public final class MatrixChainMultiplication {
} }
} }
private static String[] input(String string) { /**
System.out.print(string); * The Result class holds the results of the matrix chain multiplication
return (SCANNER.nextLine().split(" ")); * calculation, including the matrix of minimum costs and split points.
*/
public static class Result {
private final int[][] m;
private final int[][] s;
/**
* Constructs a Result object with the specified matrices of minimum
* costs and split points.
*
* @param m the matrix of minimum multiplication costs.
* @param s the matrix of optimal split points.
*/
public Result(int[][] m, int[][] s) {
this.m = m;
this.s = s;
} }
}
class Matrix { /**
* Returns the matrix of minimum multiplication costs.
*
* @return the matrix of minimum multiplication costs.
*/
public int[][] getM() {
return m;
}
/**
* Returns the matrix of optimal split points.
*
* @return the matrix of optimal split points.
*/
public int[][] getS() {
return s;
}
}
/**
* The Matrix class represents a matrix with its dimensions and count.
*/
public static class Matrix {
private final int count; private final int count;
private final int col; private final int col;
private final int row; private final int row;
Matrix(int count, int col, int row) { /**
* Constructs a Matrix object with the specified count, number of columns,
* and number of rows.
*
* @param count the identifier for the matrix.
* @param col the number of columns in the matrix.
* @param row the number of rows in the matrix.
*/
public Matrix(int count, int col, int row) {
this.count = count; this.count = count;
this.col = col; this.col = col;
this.row = row; this.row = row;
} }
int count() { /**
* Returns the identifier of the matrix.
*
* @return the identifier of the matrix.
*/
public int count() {
return count; return count;
} }
int col() { /**
* Returns the number of columns in the matrix.
*
* @return the number of columns in the matrix.
*/
public int col() {
return col; return col;
} }
int row() { /**
* Returns the number of rows in the matrix.
*
* @return the number of rows in the matrix.
*/
public int row() {
return row; return row;
} }
}
} }

View File

@ -0,0 +1,54 @@
package com.thealgorithms.dynamicprogramming;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.ArrayList;
import org.junit.jupiter.api.Test;
class MatrixChainMultiplicationTest {
@Test
void testMatrixCreation() {
MatrixChainMultiplication.Matrix matrix1 = new MatrixChainMultiplication.Matrix(1, 10, 20);
MatrixChainMultiplication.Matrix matrix2 = new MatrixChainMultiplication.Matrix(2, 20, 30);
assertEquals(1, matrix1.count());
assertEquals(10, matrix1.col());
assertEquals(20, matrix1.row());
assertEquals(2, matrix2.count());
assertEquals(20, matrix2.col());
assertEquals(30, matrix2.row());
}
@Test
void testMatrixChainOrder() {
// Create a list of matrices to be multiplied
ArrayList<MatrixChainMultiplication.Matrix> matrices = new ArrayList<>();
matrices.add(new MatrixChainMultiplication.Matrix(1, 10, 20)); // A(1) = 10 x 20
matrices.add(new MatrixChainMultiplication.Matrix(2, 20, 30)); // A(2) = 20 x 30
// Calculate matrix chain order
MatrixChainMultiplication.Result result = MatrixChainMultiplication.calculateMatrixChainOrder(matrices);
// Expected cost of multiplying A(1) and A(2)
int expectedCost = 6000; // The expected optimal cost of multiplying A(1)(10x20) and A(2)(20x30)
int actualCost = result.getM()[1][2];
assertEquals(expectedCost, actualCost);
}
@Test
void testOptimalParentheses() {
// Create a list of matrices to be multiplied
ArrayList<MatrixChainMultiplication.Matrix> matrices = new ArrayList<>();
matrices.add(new MatrixChainMultiplication.Matrix(1, 10, 20)); // A(1) = 10 x 20
matrices.add(new MatrixChainMultiplication.Matrix(2, 20, 30)); // A(2) = 20 x 30
// Calculate matrix chain order
MatrixChainMultiplication.Result result = MatrixChainMultiplication.calculateMatrixChainOrder(matrices);
// Check the optimal split for parentheses
assertEquals(1, result.getS()[1][2]); // s[1][2] should point to the optimal split
}
}