Warning: Undefined array key "HTTP_ACCEPT_LANGUAGE" in /www/wwwroot/blog/wp-content/plugins/UEditor-KityFormula-for-wordpress/main.php on line 13
【算法设计与分析】Strassen矩阵乘法(分治、剪枝)[附Python源码] – Machine World

【背景】

        矩阵乘法是线性代数中最常见的问题之一,它不仅在数值计算中具有广泛的应用,还是现代机器学习技术中必不可少的基石。

【定义】

\( 设A、B是两个n \times n 矩阵, 他们的乘积AB同样是一个n\times n 矩阵\)

即:

\(A_{n\times n} B_{n \times n} = C_{n \times n}\)

\( A和B的乘积矩阵C中各元素C_{ij}定义为:\)

\(\begin{align}C_{ij} = \sum_{k=1}^n A_{ik}B_{kj}\end{align}\)

【分析】

        若按照上述提及的公式一次对矩阵A、B进行乘积运算。计算C中每一个元素\(C_{ij}\) 需做n次乘法和n-1次加法运算,因此,欲计算出C中每一个元素的时间复杂度为\(O(n^3)\)

其源码如下:

def traditional(matrix1, matrix2):
    matrix3 = []
    for i in range(0, len(matrix1)):
        temp = []
        for j in range(0, len(matrix2)):
            t = 0
            for k in range(0, len(matrix1)):
                t += matrix1[i][k] * matrix2[k][j]
            temp.append(t)
        matrix3.append(temp)
    return matrix3

【算法引出-分治法】

        根据n阶(此处为了方便叙述,我们假设n是2的幂)矩阵的相关特性,我们可以将每一块矩阵都分为4个大小相等的子矩阵,每一个子矩阵都是n/2×n/2的方阵。于是我们将方程C= AB重写为下述形式:

\(\begin{align}\begin{bmatrix}C_{11} & C_{12} \\ C_{21} & C_{22}\end{bmatrix} = \begin{bmatrix}A_{11} & A_{12} \\ A_{21} & A_{22}\end{bmatrix} \begin{bmatrix}B_{11} & B_{12} \\ B_{21} & B_{22}\end{bmatrix}\end{align}\)

由此可得:

\( \begin{align}&C_{11} = A_{11}B_{11} + A_{12}B_{21} \\&C_{12} = A_{11}B_{12} + A_{12}B_{22}\\&C_{21} = A_{21}B_{11} + A_{22}B_{21}\\&C_{22} = A_{21}B_{12} + A_{22}B_{22}\end{align}\)

其分治递推式如下:

\(\begin{align}T(n) = \begin{cases} O(1) & n =2 \\ 8T(n/2) + O(n^2) & n >2\end{cases}\end{align}\)

利用扩展递归求解得出:

\( T(n) = O(n^3)\)

初次分治,得出的结果与传统公式求解的时间复杂度并没有改变,即这样的分治是徒劳的。

【Strassen矩阵乘法-分治、剪枝】

        Strassen算法的核心思想是令递归树稍微不那么茂盛一点儿, 即只递归进行7次而不是8次n/2×n/2 矩阵的乘法。减少一次矩阵乘法带来的代价可能是额外几次n/2×n/2矩阵的加法,但只是常数次 。

算法描述如下:

先按照先前的分治思想中矩阵分解的方法将A,B,C进行分解

创建如下7个\( n/2 \times n/2\)的矩阵\(M_1, M_2,M_3,\dots, M_7\):

\(\begin{align}&M_1 = A_{11}(B_12 – B_{22}) \\ &M_2 = (A_{11} + A_{12})B_{22} \\& M_3 = (A_{21} + A_{22})B_{11} \\& M_4 = A_{22}(B_{21} – B_{11})\\ &M_5 = (A_{11} + A_{22})(B_{11}+B_{22}) \\ & M_6=(A_{12} – A_{22})(B_{21}+B_{22})\\&M_7 = (A_{11} -A_{21})(B_{11} + B_{12})\end{align}\)

做完这7次乘法后,再做若干次加减法就可以得到\(C_{11},C_{12}, C_{21}, C_{22}\),他们的计算公式如下:

\(\begin{align}&C_{11} = M_5 + M_4 – M_2 + M_6 \\ &C_{12} = M_1 + M2 \\ &C_{21} = M3+M4\\&C_{22} = M_5+M_1-M_3-M_7\end{align}\)

其分治递推式如下:

\(\begin{align} T(n) = \begin{cases} O(1) & n =2 \\ 7T(n/2) + O(n^2) & n > 2\end{cases}\end{align}\)

求出:

\( T(n) = O(n^{log7} )\approx O(n^{2.81})\)

其Python源代码如下:

# -*- coding:utf-8 -*-

def mergeMatrix(A11, A12, A13, A14):
    n = len(A11)
    for i in range(0, n):
        A11[i].extend(A12[i])
        A13[i].extend(A14[i])
    for i in range(0,n):
        A11.append(A13[i])
    return A11
def division(matrix):
    A11 = []
    A12 = []
    A21 = []
    A22 = []
    half_size = int(len(matrix) / 2)
    for i in range(0, half_size):
        A11.append(matrix[i][:half_size])
        A12.append(matrix[i][half_size:])
    for j in range(half_size, len(matrix)):
        A21.append(matrix[j][:half_size])
        A22.append(matrix[j][half_size:])
    return A11, A12, A21, A22


class Strassen:
    def add(self, m1, m2, size):
        matrix = []
        for i in range(0, size):
            temp = []
            for j in range(0, size):
                temp.append(m1[i][j] + m2[i][j])
            matrix.append(temp)
        return matrix

    def sub(self, m1, m2, size):
        matrix = []
        for i in range(0, size):
            temp = []
            for j in range(0, size):
                temp.append(m1[i][j] - m2[i][j])
            matrix.append(temp)
        return matrix

    def multiply(self, m1, m2, size):
        if(size == 1):
            return [[m1[0][0] * m2[0][0]]]
        A11, A12, A21, A22 = division(m1)
        B11, B12, B21, B22 = division(m2)
        size = int(size / 2)


        #calculate M1 = A11(B12 - B22)
        m1 = self.multiply(A11, self.sub(B12, B22, size), size)

        #calculate M2 = (A11 + A12)B22
        m2 = self.multiply(self.add(A11, A12, size), B22, size)

        #calculate M3 = (A21 + A22)B11
        m3 = self.multiply(self.add(A21, A22, size), B11, size)

        #calculate M4 = A22(B21 - B11)
        m4 = self.multiply(A22, self.sub(B21, B11,size), size)

        #calculate M5 = (A11 + A22)(B11 + B22)
        m5 = self.multiply(self.add(A11, A22, size), self.add(B11, B22, size), size)

        #calculate M6 = (A12 - A22)(B21 + B22)
        m6 = self.multiply(self.sub(A12, A22, size), self.add(B21, B22, size), size)

        #calculate M7 = (A11 - A21)(B11 + B12)
        m7 = self.multiply(self.sub(A11, A21, size), self.add(B11, B12, size), size)

        #calculate C11 = M5 + M4 - M2 + M6
        C11 = self.add(self.sub(self.add(m5, m4, size), m2, size), m6, size)

        #calculate C12 = M1 + M2
        C12 = self.add(m1, m2, size)

        #calculate C21 = M3 + M4
        C21 = self.add(m3, m4, size)

        #calculate C22 = M5 + M1 - M3 -M7
        C22 = self.sub(self.sub(self.add(m5,m1,size), m3,size), m7,size)

        return mergeMatrix(C11, C12, C21, C22)


s = Strassen()
#print(s.sub([[1,2],[3,4]],[[1,2],[3,4]]))
matrixA = [
    [1, 1, 1, 1],
    [1, 2, 3, 4],
    [1, 2, 3, 4],
    [1, 2, 3, 4]
]
matrixB = [
    [1, 2, 3, 4],
    [1, 2, 3, 4],
    [1, 2, 3, 4],
    [1, 2, 3, 4]
]
print(s.multiply(matrixA, matrixB, 4))

【总结】

  • 在此问题中,相对于传统算法的复杂度\(O(n^3)\),使用分治+剪枝的Strassen算法的表现\(O(n^{2.81})\)更胜一筹。因此我们知道,在具有某些可以分治处理性质的问题中,也可以多利用分治法思想对问题进行求解。

  • 对矩阵乘法问题的研究中,Hopcroft和Kerr已证明,要计算2个2×2 矩阵的乘积,7次乘法是必要的。因此,要想进一步改进矩阵乘法的时间复杂性,就不能再基于计算2×2矩阵的7次乘法这样的方法了。或许应当研究3×3 或5×5矩阵的更好算法。

  • 在Strassen之后又有许多算法改进了矩阵乘法的计算时间 复杂性。目前最好的计算时间上界是 \(O(n^{2.376})\)。

  • 到目前为止仍无法确切知道矩阵乘法的时间复杂性,关于这一研究课题还有许多工作可做。

【参考文献】

  • 王红梅,胡明.《算法设计与分析》[M].清华大学出版社

  • 王晓东.《算法设计与分析(第五版)》[M].电子工业出版社

作者 WellLee

《【算法设计与分析】Strassen矩阵乘法(分治、剪枝)[附Python源码]》有2条评论

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注