Summary: In this tutorial, we will learn what the Matrix Chain Multiplication problem is and how to solve Matrix Chain Multiplication using Dynamic Programming in Java.

What is Matrix Chain Multiplication?

Given a sequence of matrices, the goal is to find the most efficient way to multiply these matrices. The problem is not actually to perform the multiplications, but merely to decide the sequence of the matrix multiplications involved.

Wikipedia

In the Matrix Chain Multiplication problem, we are given a chain of matrices and asked to find the order in which to multiply them so that the total number of scalar multiplications is minimal.

Let’s break down the problem.

Two matrices can only be multiplied if the number of columns of the first matrix is equal to the number of rows of the second one.

Figure: Matrix multiplication

The number of scalar multiplications required to multiply two matrices is the product of their dimensions: multiplying a p × q matrix by a q × r matrix takes p × q × r multiplications.

Figure: Number of multiplications when multiplying two matrices
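As a quick check of the p × q × r rule, the sketch below multiplies two matrices with the textbook triple loop and counts every scalar multiplication it performs. The class name `TwoMatrixCost`, the `multiply` method and the counter are hypothetical names chosen for illustration.

```java
public class TwoMatrixCost {
    // Running count of scalar multiplications performed by multiply()
    static long scalarMultiplications = 0;

    // Multiplies a (p x q) by b (q x r) using the standard triple loop
    static int[][] multiply(int[][] a, int[][] b) {
        int p = a.length, q = a[0].length, r = b[0].length;
        // Columns of the first matrix must equal rows of the second
        if (b.length != q) {
            throw new IllegalArgumentException("Incompatible dimensions");
        }
        int[][] c = new int[p][r];
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < r; j++) {
                for (int k = 0; k < q; k++) {
                    c[i][j] += a[i][k] * b[k][j];
                    scalarMultiplications++; // one multiplication per innermost step
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        multiply(new int[5][4], new int[4][6]); // 5 x 4 times 4 x 6
        System.out.println(scalarMultiplications); // 5 * 4 * 6 = 120
    }
}
```

The innermost loop runs q times for each of the p × r output cells, which is exactly where the p × q × r count comes from.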

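With more than two matrices, the total number of multiplications depends on the order in which the matrices are multiplied. Matrix multiplication is associative, so every order produces the same result matrix, but not at the same cost. For example, see the diagram below.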

Figure: Matrix chain multiplication

Therefore, we need to find the order of multiplication that results in the minimum number of multiplications.
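To make the difference concrete, the sketch below compares the two possible orders for three matrices. The sizes (A1 is 5 × 4, A2 is 4 × 6, A3 is 6 × 2) are borrowed from the first three matrices of the Java program later in this tutorial; `OrderComparison` and `cost` are hypothetical names.

```java
public class OrderComparison {
    // Scalar multiplications for a (p x q) matrix times a (q x r) matrix
    static int cost(int p, int q, int r) {
        return p * q * r;
    }

    public static void main(String[] args) {
        // (A1 A2) A3: A1 A2 gives a 5 x 6 matrix, which is then multiplied by A3
        int leftFirst = cost(5, 4, 6) + cost(5, 6, 2);
        // A1 (A2 A3): A2 A3 gives a 4 x 2 matrix, which A1 then multiplies
        int rightFirst = cost(4, 6, 2) + cost(5, 4, 2);
        System.out.println("(A1 A2) A3 needs " + leftFirst + " multiplications");  // 180
        System.out.println("A1 (A2 A3) needs " + rightFirst + " multiplications"); // 88
    }
}
```

Both orders produce the same 5 × 2 result, but the second needs less than half the work.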

Matrix Chain Multiplication Solution using Dynamic Programming

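The matrix chain multiplication problem is well suited to dynamic programming. It is an optimization problem (find the cheapest order of multiplication) with optimal substructure, since an optimal order for a chain is built from optimal orders for its sub-chains, and with overlapping subproblems, since the same sub-chains appear in many different orders.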


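A brute-force approach has to examine every possible order, i.e. every parenthesization of the chain. For a chain of n matrices, the number of parenthesizations is the Catalan number $C_{n-1} = \frac{1}{n}\binom{2(n-1)}{n-1}$, which grows exponentially with n: 4 matrices can be parenthesized in 5 ways, 10 matrices already in 4862 ways.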
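The count can be checked with a short sketch. It uses the recurrence P(1) = 1 and P(n) = P(1)P(n−1) + P(2)P(n−2) + … + P(n−1)P(1), which follows from choosing where the last multiplication splits the chain, and compares it with the closed-form Catalan number. `CountParenthesizations` and its methods are hypothetical names for illustration.

```java
public class CountParenthesizations {
    // Number of ways to parenthesize a chain of n >= 1 matrices.
    // The last multiplication splits the chain into k matrices on the left
    // and n - k on the right, so P(n) = sum over k of P(k) * P(n - k).
    static long count(int n) {
        long[] p = new long[n + 1];
        p[1] = 1;
        for (int m = 2; m <= n; m++) {
            for (int k = 1; k < m; k++) {
                p[m] += p[k] * p[m - k];
            }
        }
        return p[n];
    }

    // Closed form C(n-1) = binom(2(n-1), n-1) / n
    static long catalanClosedForm(int n) {
        int m = n - 1;
        long binom = 1;
        for (int i = 1; i <= m; i++) {
            binom = binom * (m + i) / i; // now equals binom(m + i, i), exact at every step
        }
        return binom / n;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 10; n++) {
            System.out.println(n + " matrices: " + count(n)
                    + " parenthesizations (closed form: " + catalanClosedForm(n) + ")");
        }
    }
}
```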

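Dynamic programming makes this far more efficient: each sub-chain is solved only once and its result is reused, which brings the running time down to O(n³) using an n × n table.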

For example, consider the following sequences for a set of matrices.

Figure: Multiplication sequences for a set of matrices

Notice that the product of matrix A and matrix B, i.e. (A.B), appears in two of these sequences.

If we compute the cost of A.B once and reuse it in the next sequence instead of recomputing it, the algorithm becomes faster.
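To see the repetition concretely, here is a plain recursive sketch with no table: it tries every split point and counts how often it evaluates the sub-chain A1 × A2 for the four matrices used later in this tutorial. The class `NaiveRecursion`, the `dims` encoding (matrix A(i+1) is dims[i] × dims[i+1]) and the counter are assumptions made for illustration.

```java
public class NaiveRecursion {
    // A1: 5x4, A2: 4x6, A3: 6x2, A4: 2x7; matrix A(i+1) is dims[i] x dims[i+1]
    static final int[] dims = {5, 4, 6, 2, 7};
    // How many times the sub-chain A1 x A2 (indices 0..1) is evaluated
    static int a1a2Evaluations = 0;

    // Minimum cost of multiplying A(i+1) .. A(j+1), 0-based i and j, without caching
    static int cost(int i, int j) {
        if (i == 0 && j == 1) a1a2Evaluations++;
        if (i == j) return 0; // a single matrix needs no multiplication
        int best = Integer.MAX_VALUE;
        for (int s = i; s < j; s++) {
            // left part is dims[i] x dims[s+1], right part is dims[s+1] x dims[j+1]
            int c = cost(i, s) + cost(s + 1, j) + dims[i] * dims[s + 1] * dims[j + 1];
            best = Math.min(best, c);
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println("Minimum multiplications: " + cost(0, 3));
        System.out.println("A1 x A2 evaluated " + a1a2Evaluations + " times");
    }
}
```

For four matrices the sub-chain A1 × A2 is solved twice; for longer chains the duplicated work grows exponentially.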

To do this, we store the solution of each subproblem in a 2D array (memoization), so that it can be looked up later instead of being recomputed.
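A minimal top-down sketch of this idea adds a 2D cache to the plain recursion, so every sub-chain is solved at most once. The Java program later in the tutorial fills the same kind of table bottom-up instead; this variant is shown only as an illustration. `MemoizedChain`, `memo`, the `dims` encoding (matrix A(i+1) is dims[i] × dims[i+1]) and the counter are hypothetical.

```java
import java.util.Arrays;

public class MemoizedChain {
    // A1: 5x4, A2: 4x6, A3: 6x2, A4: 2x7; matrix A(i+1) is dims[i] x dims[i+1]
    static final int[] dims = {5, 4, 6, 2, 7};
    // memo[i][j] caches the minimum cost of A(i+1)..A(j+1); -1 means "not computed yet"
    static int[][] memo;
    static int subproblemsSolved = 0;

    static int cost(int i, int j) {
        if (i == j) return 0; // a single matrix needs no multiplication
        if (memo[i][j] != -1) return memo[i][j]; // reuse the stored answer
        subproblemsSolved++;
        int best = Integer.MAX_VALUE;
        for (int s = i; s < j; s++) {
            best = Math.min(best, cost(i, s) + cost(s + 1, j) + dims[i] * dims[s + 1] * dims[j + 1]);
        }
        memo[i][j] = best;
        return best;
    }

    static int solve() {
        int n = dims.length - 1; // number of matrices
        memo = new int[n][n];
        for (int[] row : memo) Arrays.fill(row, -1);
        subproblemsSolved = 0;
        return cost(0, n - 1);
    }

    public static void main(String[] args) {
        System.out.println("Minimum multiplications: " + solve()); // 158
        System.out.println("Sub-chains solved: " + subproblemsSolved); // 6
    }
}
```

With four matrices there are exactly six sub-chains of length two or more, and each is computed once.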

The following pictures illustrate the dynamic programming solution for matrix chain multiplication:

Figures: Matrix chain multiplication using dynamic programming (three steps)

Here each cell (i, j) of the table represents a sub-chain of the matrices, and its value holds the minimum number of multiplications required to multiply that sub-chain.

For example, cell (1,3) represents the sub-chain from A2 to A4, i.e. A2 x A3 x A4.

When computing the cost of a longer sub-chain, we use the costs of the shorter sub-chains already stored in the 2D array.
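The rule for filling the table can be written as a recurrence. The notation is ours: m[i][j] is the table cell for the sub-chain A(i+1) to A(j+1) (0-based i and j, as in the (1,3) example), and the matrix at 0-based position i has r_i rows and c_i columns, so the left part of a split at s is r_i × c_s and the right part is c_s × c_j. The last line works one cell out for the matrices 5 × 4, 4 × 6, 6 × 2 used in the program below.

```latex
m[i][i] = 0, \qquad
m[i][j] = \min_{i \le s < j} \bigl( m[i][s] + m[s+1][j] + r_i \, c_s \, c_j \bigr) \quad \text{for } i < j

m[0][2] = \min\bigl( 0 + 48 + 5 \cdot 4 \cdot 2,\; 120 + 0 + 5 \cdot 6 \cdot 2 \bigr) = \min(88, 180) = 88
```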

Let’s implement this in Java.

Matrix Chain Multiplication in Java

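package MatrixChainMultiplication;

// Save this file as App.java: a Java source file may contain only one
// public top-level class, so Matrix and MatrixChain are package-private.

// A matrix is represented only by its dimensions (row x col);
// its entries do not matter when counting multiplications.
class Matrix {
    final int row;
    final int col;

    Matrix(int row, int col) {
        this.row = row;
        this.col = col;
    }
}

//Algorithm class
class MatrixChain {
    final int numberOfMatrices;
    final Matrix[] matrices;
    // matrixMulCount[i][j] = minimum number of scalar multiplications needed
    // to compute matrices[i] x ... x matrices[j] (0-based, inclusive)
    final int[][] matrixMulCount;

    MatrixChain(Matrix[] matrices) {
        if (matrices.length == 0) {
            throw new IllegalArgumentException("At least one matrix is required");
        }
        // Adjacent matrices must be compatible: columns of one = rows of the next
        for (int i = 0; i + 1 < matrices.length; i++) {
            if (matrices[i].col != matrices[i + 1].row) {
                throw new IllegalArgumentException(
                        "Matrices " + i + " and " + (i + 1) + " cannot be multiplied");
            }
        }
        this.matrices = matrices;
        this.numberOfMatrices = matrices.length;
        // Java zero-initializes int arrays, so matrixMulCount[i][i] = 0:
        // a single matrix needs no multiplication
        this.matrixMulCount = new int[numberOfMatrices][numberOfMatrices];
    }

    //Solving matrix chain multiplication using dynamic programming (bottom-up)
    public int solve() {
        // k = j - i, so sub-chains are processed in order of increasing length
        for (int k = 1; k < numberOfMatrices; k++) {
            for (int i = 0; i + k < numberOfMatrices; i++) {
                int j = i + k;
                matrixMulCount[i][j] = Integer.MAX_VALUE;

                // Try every split point s: (matrices[i..s]) x (matrices[s+1..j]).
                // The left product is matrices[i].row x matrices[s].col and the
                // right product is matrices[s].col x matrices[j].col.
                for (int s = i; s < j; s++) {
                    int cost = matrixMulCount[i][s] + matrixMulCount[s + 1][j]
                            + matrices[i].row * matrices[s].col * matrices[j].col;
                    matrixMulCount[i][j] = Math.min(matrixMulCount[i][j], cost);
                }
            }
        }
        return matrixMulCount[0][numberOfMatrices - 1];
    }
}

//Class containing 'main' method
public class App {

    public static void main(String[] args) {

        //Creating A1, A2, A3 & A4
        Matrix[] matrices = {new Matrix(5, 4),
                new Matrix(4, 6),
                new Matrix(6, 2),
                new Matrix(2, 7)};

        MatrixChain matrixChain = new MatrixChain(matrices);
        int multiplications = matrixChain.solve();

        //Printing the 2D array (DP table)
        for (int i = 0; i < matrices.length; i++) {
            for (int j = 0; j < matrices.length; j++) {
                System.out.print(matrixChain.matrixMulCount[i][j] + "\t");
            }
            System.out.println();
        }

        System.out.println("Minimum multiplications required: " + multiplications);
    }
}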

Output:

0       120     88      158
0       0       48      104
0       0       0       84
0       0       0       0
Minimum multiplications required: 158
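The program above reports only the minimum count. Since the problem, as quoted at the start, is to decide the sequence of multiplications, a common extension is to record the best split point for every table cell and rebuild the parenthesization from it. The sketch below repeats the same bottom-up loop on a compact `dims` array (matrix A(i+1) is dims[i] × dims[i+1]) instead of the `Matrix` class; `PrintParenthesization`, `split` and `build` are hypothetical names.

```java
public class PrintParenthesization {
    // Returns the cheapest parenthesization of the chain described by dims
    public static String solve(int[] dims) {
        int n = dims.length - 1; // number of matrices
        int[][] cost = new int[n][n];
        int[][] split = new int[n][n]; // best split point s for each sub-chain
        for (int k = 1; k < n; k++) {
            for (int i = 0; i + k < n; i++) {
                int j = i + k;
                cost[i][j] = Integer.MAX_VALUE;
                for (int s = i; s < j; s++) {
                    int c = cost[i][s] + cost[s + 1][j] + dims[i] * dims[s + 1] * dims[j + 1];
                    if (c < cost[i][j]) { // strict '<' keeps the first best split on ties
                        cost[i][j] = c;
                        split[i][j] = s;
                    }
                }
            }
        }
        return build(split, 0, n - 1);
    }

    // Rebuilds the parenthesization of A(i+1)..A(j+1) from the recorded split points
    static String build(int[][] split, int i, int j) {
        if (i == j) return "A" + (i + 1);
        int s = split[i][j];
        return "(" + build(split, i, s) + " x " + build(split, s + 1, j) + ")";
    }

    public static void main(String[] args) {
        // Same matrices as above: A1 5x4, A2 4x6, A3 6x2, A4 2x7
        System.out.println(solve(new int[]{5, 4, 6, 2, 7}));
    }
}
```

For the example chain it prints `((A1 x (A2 x A3)) x A4)`, whose cost 48 + 40 + 70 = 158 matches the minimum in the table.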

In this tutorial, we learned what the matrix chain multiplication problem is and how to solve it using dynamic programming.
