算法之矩阵连乘问题

一.问题描述

    给定n个矩阵{A1,A2,……,An},其中Ai与Ai+1是可乘的,i=1,2,……,n-1。

   例如:

     计算三个矩阵连乘{A1,A2,A3};维数分别为10*100 , 100*5 , 5*50

     按此顺序计算需要的次数((A1*A2)*A3):10X100X5+10X5X50=7500次

     按此顺序计算需要的次数(A1*(A2*A3)):10X5X50+10X100X50=75000次

     所以要解决的问题是:如何确定矩阵连乘积A1A2,……An的计算次序,使得按此计算次序计算矩阵连乘积需要的数乘次数达到最小化。

二.问题分析

   由于矩阵乘法满足结合律,所以计算矩阵连乘的连乘积可以与许多不同的计算计算次序,这种计算次序可以用加括号的方式来确定。若一个矩阵连乘积的计算次序完全确定,也就是说连乘积已完全加括号,那么可以依此次序反复调用2个矩阵相乘的标准算法计算出矩阵连乘积。

   完全加括号的矩阵连乘积可递归地定义为:

   (1).单个矩阵是完全加括号的;

  (2).矩阵连乘积A是完全加括号的,则A可以表示为2个完全加括号的矩阵连乘积B和C的乘积并加括号,及A=(BC);

    举个例子,矩阵连乘积A1A2A3A4A5,可以有5种不同的完全加括号方式:

       (A1(A2(A3A4))),(A1((A2A3)A4)),((A1A2)(A3A4)),((A1(A2A3))A4),(((A1A2)A3)A4)

    每一种完全加括号的方式对应一种矩阵连乘积的计算次序,而矩阵连乘积的计算次序与其计算量有密切的关系,即与矩阵的行和列有关。

    补充一下数学知识,矩阵A与矩阵B可乘的条件为矩阵A的列数等于矩阵B的行数,例如,若A是一个p*q的矩阵,B是一个q*r的矩阵,则其乘积C=AB是一个p*r的矩阵。

三.问题求解

1.分析最优子结构:计算一组矩阵连乘,可以将该组矩阵从中间某一位置断开,分别计算左右两个小组矩阵连乘,以此类推,直到只剩下两个矩阵相乘,则只有一种次序。与分治法不同的是,计算其中某一组矩阵连乘,会用到其下方计算得到的解,而且分析可得,当底层小组连乘时是采用最佳次序,那么以此得到的更高层的解也是最佳次序,即问题的最优解可以由子问题的最优解来体现。

2.建立递归关系:如下图所示公式

  

3.计算最优值:自底向上进行计算,并将得到的子问题的解保存下来,之后用到的时候直接查找使用。

#include <iostream>
using namespace std;
#define N 6

//各矩阵维数数组P[n+1],数组长度n,子问题解数组m,最优断开位置数组s
void MatrixChain(int p[],int n,int m[N][N],int s[N][N])
{
    for(int i=0;i<n;i++) {m[i][i]=0;s[i][i]=0;} //单元素矩阵无需相乘
    for(int r=1;r<n;r++)
    {
        for(int i=0;i<n-r;i++)
        {
            int j=i+r;
            m[i][j]=m[i][i]+m[i+1][j]+p[i]*p[i+1]*p[j+1];
            s[i][j]=i;
            for(int k=i+1;k<j;k++)
            {
                int t=m[i][k]+m[k+1][j]+p[i]*p[k+1]*p[j+1];
                if(t<m[i][j]){
                    m[i][j]=t;
                    s[i][j]=k;
                }
            }
        }
    }
}

//计算最优解
void Traceback(int i,int j,int s[N][N])
{
    if(i==j)return;
    Traceback(i,s[i][j],s);
    Traceback(s[i][j]+1,j,s);
    cout<<"Multiply A "<<i<<","<<s[i][j];
    cout<<" and A "<<s[i][j]+1<<","<<j<<endl;
}


int main()
{
    int p[]={30,35,15,5,10,20,25};
    int m[N][N],s[N][N];
    MatrixChain(p,N,m,s);
    for(int i=0;i<N;i++)
    {
        for(int j=i;j<N;j++)
            cout<<m[i][j]<<" ";
        cout<<endl;
    }
    for(int i=0;i<N;i++)
    {
        for(int j=i;j<N;j++)
            cout<<s[i][j]<<" ";
        cout<<endl;
    }
    Traceback(0,5,s);
}

四.复杂度分析:最坏时间复杂度O(n^3);

猜你喜欢

转载自blog.csdn.net/qq_35503380/article/details/80380745