Add tests, remove main, enhance docs in MatrixChainMultiplication (#5658)

This commit is contained in:
Hardik Pawar
2024-10-10 22:46:04 +05:30
committed by GitHub
parent bd3b754eda
commit e728aa7d6f
3 changed files with 169 additions and 92 deletions

View File

@ -2,38 +2,32 @@ package com.thealgorithms.dynamicprogramming;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Scanner;
/**
* The MatrixChainMultiplication class provides functionality to compute the
* optimal way to multiply a sequence of matrices. The optimal multiplication
* order is determined using dynamic programming, which minimizes the total
* number of scalar multiplications required.
*/
public final class MatrixChainMultiplication {
private MatrixChainMultiplication() {
}
private static final Scanner SCANNER = new Scanner(System.in);
private static final ArrayList<Matrix> MATRICES = new ArrayList<>();
private static int size;
// Matrices to store minimum multiplication costs and split points
private static int[][] m;
private static int[][] s;
private static int[] p;
public static void main(String[] args) {
int count = 1;
while (true) {
String[] mSize = input("input size of matrix A(" + count + ") ( ex. 10 20 ) : ");
int col = Integer.parseInt(mSize[0]);
if (col == 0) {
break;
}
int row = Integer.parseInt(mSize[1]);
Matrix matrix = new Matrix(count, col, row);
MATRICES.add(matrix);
count++;
}
for (Matrix m : MATRICES) {
System.out.format("A(%d) = %2d x %2d%n", m.count(), m.col(), m.row());
}
size = MATRICES.size();
/**
* Calculates the optimal order for multiplying a given list of matrices.
*
* @param matrices an ArrayList of Matrix objects representing the matrices
* to be multiplied.
* @return a Result object containing the matrices of minimum costs and
* optimal splits.
*/
public static Result calculateMatrixChainOrder(ArrayList<Matrix> matrices) {
int size = matrices.size();
m = new int[size + 1][size + 1];
s = new int[size + 1][size + 1];
p = new int[size + 1];
@ -44,51 +38,20 @@ public final class MatrixChainMultiplication {
}
for (int i = 0; i < p.length; i++) {
p[i] = i == 0 ? MATRICES.get(i).col() : MATRICES.get(i - 1).row();
p[i] = i == 0 ? matrices.get(i).col() : matrices.get(i - 1).row();
}
matrixChainOrder();
for (int i = 0; i < size; i++) {
System.out.print("-------");
}
System.out.println();
printArray(m);
for (int i = 0; i < size; i++) {
System.out.print("-------");
}
System.out.println();
printArray(s);
for (int i = 0; i < size; i++) {
System.out.print("-------");
}
System.out.println();
System.out.println("Optimal solution : " + m[1][size]);
System.out.print("Optimal parens : ");
printOptimalParens(1, size);
matrixChainOrder(size);
return new Result(m, s);
}
private static void printOptimalParens(int i, int j) {
if (i == j) {
System.out.print("A" + i);
} else {
System.out.print("(");
printOptimalParens(i, s[i][j]);
printOptimalParens(s[i][j] + 1, j);
System.out.print(")");
}
}
private static void printArray(int[][] array) {
for (int i = 1; i < size + 1; i++) {
for (int j = 1; j < size + 1; j++) {
System.out.printf("%7d", array[i][j]);
}
System.out.println();
}
}
private static void matrixChainOrder() {
/**
* A helper method that computes the minimum cost of multiplying
* the matrices using dynamic programming.
*
* @param size the number of matrices in the multiplication sequence.
*/
private static void matrixChainOrder(int size) {
for (int i = 1; i < size + 1; i++) {
m[i][i] = 0;
}
@ -109,33 +72,92 @@ public final class MatrixChainMultiplication {
}
}
private static String[] input(String string) {
System.out.print(string);
return (SCANNER.nextLine().split(" "));
}
}
class Matrix {
private final int count;
private final int col;
private final int row;
Matrix(int count, int col, int row) {
this.count = count;
this.col = col;
this.row = row;
}
int count() {
return count;
}
int col() {
return col;
}
int row() {
return row;
/**
* The Result class holds the results of the matrix chain multiplication
* calculation, including the matrix of minimum costs and split points.
*/
public static class Result {
private final int[][] m;
private final int[][] s;
/**
* Constructs a Result object with the specified matrices of minimum
* costs and split points.
*
* @param m the matrix of minimum multiplication costs.
* @param s the matrix of optimal split points.
*/
public Result(int[][] m, int[][] s) {
this.m = m;
this.s = s;
}
/**
* Returns the matrix of minimum multiplication costs.
*
* @return the matrix of minimum multiplication costs.
*/
public int[][] getM() {
return m;
}
/**
* Returns the matrix of optimal split points.
*
* @return the matrix of optimal split points.
*/
public int[][] getS() {
return s;
}
}
/**
* The Matrix class represents a matrix with its dimensions and count.
*/
public static class Matrix {
private final int count;
private final int col;
private final int row;
/**
* Constructs a Matrix object with the specified count, number of columns,
* and number of rows.
*
* @param count the identifier for the matrix.
* @param col the number of columns in the matrix.
* @param row the number of rows in the matrix.
*/
public Matrix(int count, int col, int row) {
this.count = count;
this.col = col;
this.row = row;
}
/**
* Returns the identifier of the matrix.
*
* @return the identifier of the matrix.
*/
public int count() {
return count;
}
/**
* Returns the number of columns in the matrix.
*
* @return the number of columns in the matrix.
*/
public int col() {
return col;
}
/**
* Returns the number of rows in the matrix.
*
* @return the number of rows in the matrix.
*/
public int row() {
return row;
}
}
}