C++ Strassen算法代码的实现
什么是Strassen算法?
Strassen算法是一种矩阵乘法的优化算法,它将两个矩阵的乘法分解为若干个小矩阵的乘法,从而减少了矩阵乘法的计算次数。
具体来说,将两个$n\times n$的矩阵$A$和$B$分别划分成四个$\dfrac{n}{2}\times\dfrac{n}{2}$的矩阵:
$$
A = \begin{bmatrix}A_{11} &A_{12}\A_{21} &A_{22}\end{bmatrix},\quad B = \begin{bmatrix}B_{11} &B_{12}\B_{21} &B_{22}\end{bmatrix}
$$
然后可以通过以下公式计算矩阵$C=A\times B$:
$$
\begin{aligned}
C_{11}&=P_5+P_4-P_2+P_6\
C_{12}&=P_1+P_2\
C_{21}&=P_3+P_4\
C_{22}&=P_5+P_1-P_3-P_7
\end{aligned}
$$
其中,$P_1=A_{11}(B_{12}-B_{22})$,$P_2=(A_{11}+A_{12})B_{22}$,$P_3=(A_{21}+A_{22})B_{11}$,$P_4=A_{22}(B_{21}-B_{11})$,$P_5=(A_{11}+A_{22})(B_{11}+B_{22})$,$P_6=(A_{12}-A_{22})(B_{21}+B_{22})$,$P_7=(A_{11}-A_{21})(B_{11}+B_{12})$。
实现Strassen算法
为了实现Strassen算法,我们需要先实现一个能够执行矩阵加法和矩阵减法的函数,以及一个能够执行普通矩阵乘法的函数。具体代码如下:
#include <vector>
using std::vector;
// 矩阵加法
vector<vector<int>> add(const vector<vector<int>>& A, const vector<vector<int>>& B) {
int n = A.size();
int m = A[0].size();
vector<vector<int>> C(n, vector<int>(m));
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
C[i][j] = A[i][j] + B[i][j];
}
}
return C;
}
// 矩阵减法
vector<vector<int>> sub(const vector<vector<int>>& A, const vector<vector<int>>& B) {
int n = A.size();
int m = A[0].size();
vector<vector<int>> C(n, vector<int>(m));
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
C[i][j] = A[i][j] - B[i][j];
}
}
return C;
}
// 普通矩阵乘法
vector<vector<int>> multiply(const vector<vector<int>>& A, const vector<vector<int>>& B) {
int n = A.size();
int m = A[0].size();
int l = B[0].size();
vector<vector<int>> C(n, vector<int>(l));
for (int i = 0; i < n; i++) {
for (int j = 0; j < l; j++) {
for (int k = 0; k < m; k++) {
C[i][j] += A[i][k] * B[k][j];
}
}
}
return C;
}
然后可以根据上面的公式实现Strassen算法,代码如下:
// Strassen矩阵乘法
vector<vector<int>> strassen(const vector<vector<int>>& A, const vector<vector<int>>& B) {
int n = A.size();
vector<vector<int>> C(n, vector<int>(n));
if (n == 1) {
C[0][0] = A[0][0] * B[0][0];
} else {
vector<vector<int>> A11(n/2, vector<int>(n/2)), A12(n/2, vector<int>(n/2)), A21(n/2, vector<int>(n/2)), A22(n/2, vector<int>(n/2));
vector<vector<int>> B11(n/2, vector<int>(n/2)), B12(n/2, vector<int>(n/2)), B21(n/2, vector<int>(n/2)), B22(n/2, vector<int>(n/2));
vector<vector<int>> P1(n/2, vector<int>(n/2)), P2(n/2, vector<int>(n/2)), P3(n/2, vector<int>(n/2)), P4(n/2, vector<int>(n/2)), P5(n/2, vector<int>(n/2)), P6(n/2, vector<int>(n/2)), P7(n/2, vector<int>(n/2));
// 将矩阵A和B分别划分成四个n/2*n/2的矩阵
for (int i = 0; i < n/2; i++) {
for (int j = 0; j < n/2; j++) {
A11[i][j] = A[i][j];
A12[i][j] = A[i][j+n/2];
A21[i][j] = A[i+n/2][j];
A22[i][j] = A[i+n/2][j+n/2];
B11[i][j] = B[i][j];
B12[i][j] = B[i][j+n/2];
B21[i][j] = B[i+n/2][j];
B22[i][j] = B[i+n/2][j+n/2];
}
}
// 计算七个子问题
P1 = strassen(A11, sub(B12, B22));
P2 = strassen(add(A11, A12), B22);
P3 = strassen(add(A21, A22), B11);
P4 = strassen(A22, sub(B21, B11));
P5 = strassen(add(A11, A22), add(B11, B22));
P6 = strassen(sub(A12, A22), add(B21, B22));
P7 = strassen(sub(A11, A21), add(B11, B12));
// 计算结果矩阵C
vector<vector<int>> C11(n/2, vector<int>(n/2)), C12(n/2, vector<int>(n/2)), C21(n/2, vector<int>(n/2)), C22(n/2, vector<int>(n/2));
C11 = add(sub(add(P5, P4), P2), P6);
C12 = add(P1, P2);
C21 = add(P3, P4);
C22 = add(sub(add(P5, P1), P3), P7);
// 将四个矩阵合并成一个n*n的矩阵
for (int i = 0; i < n/2; i++) {
for (int j = 0; j < n/2; j++) {
C[i][j] = C11[i][j];
C[i][j+n/2] = C12[i][j];
C[i+n/2][j] = C21[i][j];
C[i+n/2][j+n/2] = C22[i][j];
}
}
}
return C;
}
示例说明
我们可以使用如下的代码对Strassen算法进行测试:
#include <iostream>
using std::cout;
using std::endl;
#include "strassen.h"
int main() {
vector<vector<int>> A = {{1, 2}, {3, 4}};
vector<vector<int>> B = {{5, 6}, {7, 8}};
vector<vector<int>> C = multiply(A, B); // 普通矩阵乘法
vector<vector<int>> D = strassen(A, B); // Strassen矩阵乘法
cout << "A * B =", endl;
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
cout << C[i][j] << " ";
}
cout << endl;
}
cout << "Strassen(A, B) =", endl;
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
cout << D[i][j] << " ";
}
cout << endl;
}
return 0;
}
输出结果为:
A * B =
19 22
43 50
Strassen(A, B) =
19 22
43 50
可以看出,Strassen算法得到的结果与普通矩阵乘法得到的结果相同。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:C++ Strassen算法代码的实现 - Python技术站