diff --git a/DivideAndConquer/StrassenMatrixMultiplication.java b/DivideAndConquer/StrassenMatrixMultiplication.java new file mode 100644 index 000000000..2608d5a3b --- /dev/null +++ b/DivideAndConquer/StrassenMatrixMultiplication.java @@ -0,0 +1,181 @@ +// Java Program to Implement Strassen Algorithm + +// Class Strassen matrix multiplication +public class StrassenMatrixMultiplication { + + // Method 1 + // Function to multiply matrices + public int[][] multiply(int[][] A, int[][] B) + { + int n = A.length; + + int[][] R = new int[n][n]; + + if (n == 1) + + R[0][0] = A[0][0] * B[0][0]; + + else { + // Dividing Matrix into parts + // by storing sub-parts to variables + int[][] A11 = new int[n / 2][n / 2]; + int[][] A12 = new int[n / 2][n / 2]; + int[][] A21 = new int[n / 2][n / 2]; + int[][] A22 = new int[n / 2][n / 2]; + int[][] B11 = new int[n / 2][n / 2]; + int[][] B12 = new int[n / 2][n / 2]; + int[][] B21 = new int[n / 2][n / 2]; + int[][] B22 = new int[n / 2][n / 2]; + + // Dividing matrix A into 4 parts + split(A, A11, 0, 0); + split(A, A12, 0, n / 2); + split(A, A21, n / 2, 0); + split(A, A22, n / 2, n / 2); + + // Dividing matrix B into 4 parts + split(B, B11, 0, 0); + split(B, B12, 0, n / 2); + split(B, B21, n / 2, 0); + split(B, B22, n / 2, n / 2); + + // Using Formulas as described in algorithm + + // M1:=(A1+A3)×(B1+B2) + int[][] M1 + = multiply(add(A11, A22), add(B11, B22)); + + // M2:=(A2+A4)×(B3+B4) + int[][] M2 = multiply(add(A21, A22), B11); + + // M3:=(A1−A4)×(B1+A4) + int[][] M3 = multiply(A11, sub(B12, B22)); + + // M4:=A1×(B2−B4) + int[][] M4 = multiply(A22, sub(B21, B11)); + + // M5:=(A3+A4)×(B1) + int[][] M5 = multiply(add(A11, A12), B22); + + // M6:=(A1+A2)×(B4) + int[][] M6 + = multiply(sub(A21, A11), add(B11, B12)); + + // M7:=A4×(B3−B1) + int[][] M7 + = multiply(sub(A12, A22), add(B21, B22)); + + // P:=M2+M3−M6−M7 + int[][] C11 = add(sub(add(M1, M4), M5), M7); + + // Q:=M4+M6 + int[][] C12 = add(M3, M5); + + // R:=M5+M7 + int[][] C21 = add(M2, M4); + + // S:=M1−M3−M4−M5 + int[][] C22 = add(sub(add(M1, M3), M2), M6); + + join(C11, R, 0, 0); + join(C12, R, 0, n / 2); + join(C21, R, n / 2, 0); + join(C22, R, n / 2, n / 2); + } + + return R; + } + + // Method 2 + // Function to subtract two matrices + public int[][] sub(int[][] A, int[][] B) + { + int n = A.length; + + int[][] C = new int[n][n]; + + for (int i = 0; i < n; i++) + for (int j = 0; j < n; j++) + C[i][j] = A[i][j] - B[i][j]; + + return C; + } + + // Method 3 + // Function to add two matrices + public int[][] add(int[][] A, int[][] B) + { + + int n = A.length; + + int[][] C = new int[n][n]; + + for (int i = 0; i < n; i++) + for (int j = 0; j < n; j++) + C[i][j] = A[i][j] + B[i][j]; + + return C; + } + + // Method 4 + // Function to split parent matrix + // into child matrices + public void split(int[][] P, int[][] C, int iB, int jB) + { + for (int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++) + for (int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++) + C[i1][j1] = P[i2][j2]; + } + + // Method 5 + // Function to join child matrices + // into (to) parent matrix + public void join(int[][] C, int[][] P, int iB, int jB) + + { + for (int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++) + for (int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++) + P[i2][j2] = C[i1][j1]; + } + + // Method 5 + // Main driver method + public static void main(String[] args) + { + System.out.println("Strassen Multiplication Algorithm Implementation For Matrix Multiplication :\n"); + + StrassenMatrixMultiplication s = new StrassenMatrixMultiplication(); + + // Size of matrix + // Considering size as 4 in order to illustrate + int N = 4; + + // Matrix A + // Custom input to matrix + int[][] A = { { 1, 2, 5, 4 }, + { 9, 3, 0, 6 }, + { 4, 6, 3, 1 }, + { 0, 2, 0, 6 } }; + + // Matrix B + // Custom input to matrix + int[][] B = { { 1, 0, 4, 1 }, + { 1, 2, 0, 2 }, + { 0, 3, 1, 3 }, + { 1, 8, 1, 2 } }; + + // Matrix C computations + + // Matrix C calling method to get Result + int[][] C = s.multiply(A, B); + + System.out.println("\nProduct of matrices A and B : "); + + // Print the output + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) + System.out.print(C[i][j] + " "); + System.out.println(); + } + } +}