C++ Strassen算法代码的实现

C++ Strassen算法代码的实现

什么是Strassen算法?

Strassen算法是一种矩阵乘法的优化算法,它将两个矩阵的乘法分解为若干个小矩阵的乘法,从而减少了矩阵乘法的计算次数。

具体来说,将两个$n\times n$的矩阵$A$和$B$分别划分成四个$\dfrac{n}{2}\times\dfrac{n}{2}$的矩阵:

$$
A = \begin{bmatrix}A_{11} &A_{12}\A_{21} &A_{22}\end{bmatrix},\quad B = \begin{bmatrix}B_{11} &B_{12}\B_{21} &B_{22}\end{bmatrix}
$$

然后可以通过以下公式计算矩阵$C=A\times B$:

$$
\begin{aligned}
C_{11}&=P_5+P_4-P_2+P_6\
C_{12}&=P_1+P_2\
C_{21}&=P_3+P_4\
C_{22}&=P_5+P_1-P_3-P_7
\end{aligned}
$$

其中,$P_1=A_{11}(B_{12}-B_{22})$,$P_2=(A_{11}+A_{12})B_{22}$,$P_3=(A_{21}+A_{22})B_{11}$,$P_4=A_{22}(B_{21}-B_{11})$,$P_5=(A_{11}+A_{22})(B_{11}+B_{22})$,$P_6=(A_{12}-A_{22})(B_{21}+B_{22})$,$P_7=(A_{11}-A_{21})(B_{11}+B_{12})$。

实现Strassen算法

为了实现Strassen算法,我们需要先实现一个能够执行矩阵加法和矩阵减法的函数,以及一个能够执行普通矩阵乘法的函数。具体代码如下:

#include <vector>
using std::vector;

// 矩阵加法
vector<vector<int>> add(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();
    int m = A[0].size();
    vector<vector<int>> C(n, vector<int>(m));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            C[i][j] = A[i][j] + B[i][j];
        }
    }
    return C;
}

// 矩阵减法
vector<vector<int>> sub(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();
    int m = A[0].size();
    vector<vector<int>> C(n, vector<int>(m));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            C[i][j] = A[i][j] - B[i][j];
        }
    }
    return C;
}

// 普通矩阵乘法
vector<vector<int>> multiply(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();
    int m = A[0].size();
    int l = B[0].size();
    vector<vector<int>> C(n, vector<int>(l));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < l; j++) {
            for (int k = 0; k < m; k++) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
    return C;
}

然后可以根据上面的公式实现Strassen算法,代码如下:

// Strassen矩阵乘法
vector<vector<int>> strassen(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();
    vector<vector<int>> C(n, vector<int>(n));
    if (n == 1) {
        C[0][0] = A[0][0] * B[0][0];
    } else {
        vector<vector<int>> A11(n/2, vector<int>(n/2)), A12(n/2, vector<int>(n/2)), A21(n/2, vector<int>(n/2)), A22(n/2, vector<int>(n/2));
        vector<vector<int>> B11(n/2, vector<int>(n/2)), B12(n/2, vector<int>(n/2)), B21(n/2, vector<int>(n/2)), B22(n/2, vector<int>(n/2));
        vector<vector<int>> P1(n/2, vector<int>(n/2)), P2(n/2, vector<int>(n/2)), P3(n/2, vector<int>(n/2)), P4(n/2, vector<int>(n/2)), P5(n/2, vector<int>(n/2)), P6(n/2, vector<int>(n/2)), P7(n/2, vector<int>(n/2));
        // 将矩阵A和B分别划分成四个n/2*n/2的矩阵
        for (int i = 0; i < n/2; i++) {
            for (int j = 0; j < n/2; j++) {
                A11[i][j] = A[i][j];
                A12[i][j] = A[i][j+n/2];
                A21[i][j] = A[i+n/2][j];
                A22[i][j] = A[i+n/2][j+n/2];
                B11[i][j] = B[i][j];
                B12[i][j] = B[i][j+n/2];
                B21[i][j] = B[i+n/2][j];
                B22[i][j] = B[i+n/2][j+n/2];
            }
        }
        // 计算七个子问题
        P1 = strassen(A11, sub(B12, B22));
        P2 = strassen(add(A11, A12), B22);
        P3 = strassen(add(A21, A22), B11);
        P4 = strassen(A22, sub(B21, B11));
        P5 = strassen(add(A11, A22), add(B11, B22));
        P6 = strassen(sub(A12, A22), add(B21, B22));
        P7 = strassen(sub(A11, A21), add(B11, B12));
        // 计算结果矩阵C
        vector<vector<int>> C11(n/2, vector<int>(n/2)), C12(n/2, vector<int>(n/2)), C21(n/2, vector<int>(n/2)), C22(n/2, vector<int>(n/2));
        C11 = add(sub(add(P5, P4), P2), P6);
        C12 = add(P1, P2);
        C21 = add(P3, P4);
        C22 = add(sub(add(P5, P1), P3), P7);
        // 将四个矩阵合并成一个n*n的矩阵
        for (int i = 0; i < n/2; i++) {
            for (int j = 0; j < n/2; j++) {
                C[i][j] = C11[i][j];
                C[i][j+n/2] = C12[i][j];
                C[i+n/2][j] = C21[i][j];
                C[i+n/2][j+n/2] = C22[i][j];
            }
        }
    }
    return C;
}

示例说明

我们可以使用如下的代码对Strassen算法进行测试:

#include <iostream>
using std::cout;
using std::endl;
#include "strassen.h"

int main() {
    vector<vector<int>> A = {{1, 2}, {3, 4}};
    vector<vector<int>> B = {{5, 6}, {7, 8}};
    vector<vector<int>> C = multiply(A, B); // 普通矩阵乘法
    vector<vector<int>> D = strassen(A, B);  // Strassen矩阵乘法
    cout << "A * B =", endl;
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            cout << C[i][j] << " ";
        }
        cout << endl;
    }
    cout << "Strassen(A, B) =", endl;
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            cout << D[i][j] << " ";
        }
        cout << endl;
    }
    return 0;
}

输出结果为:

A * B =
19 22
43 50
Strassen(A, B) =
19 22
43 50

可以看出,Strassen算法得到的结果与普通矩阵乘法得到的结果相同。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:C++ Strassen算法代码的实现 - Python技术站

(0)
上一篇 2023年5月23日
下一篇 2023年5月23日

相关文章

  • C语言中设置用户识别码的相关函数的简单讲解

    下面是关于C语言中设置用户识别码相关函数的简要讲解: 什么是用户识别码? 用户识别码是一种数字标识符,用于标识和区分不同的用户。在操作系统中,每个用户都有一个独特的用户识别码(UID),操作系统根据用户识别码来识别用户,以控制对资源的访问权限。 C语言中设置用户识别码的函数 在C语言中,可以使用以下函数设置当前进程的用户识别码(UID)。这些函数定义在 &l…

    C 2023年5月23日
    00
  • 关于C语言程序的内存分配的入门知识学习

    关于C语言程序的内存分配的入门知识,要了解到以下内容: 1. 内存的基本概念 计算机是由中央处理器(CPU)、内存和硬盘等电子装置组成的。内存是程序运行时存储数据和代码的临时存储器,程序每次运行都需要占用内存,当程序结束后就会释放相应的内存。 2. 栈与堆的比较 在程序中,常见的内存分配方式有栈和堆两种,它们都是存储数据的区域,但其具体的使用方式有所不同。-…

    C 2023年5月23日
    00
  • Spring Boot全局异常处理解析

    下面是关于Spring Boot全局异常处理解析的完整攻略,包括了详细的讲解和示例说明。 什么是全局异常处理 在 Spring Boot 中,我们可以使用 @ControllerAdvice 注解来定义一些全局的异常处理方法,这些方法可以捕获到应用程序中可能出现的异常,并进行特定的处理。全局异常处理能够提供更友好的错误信息,方便开发人员和用户进行错误排查和解…

    C 2023年5月23日
    00
  • 升级Win8.1后传统start开始菜单不见了如何找回

    针对“升级Win8.1后传统start开始菜单不见了如何找回”的问题,我来给出完整的攻略: 问题描述 在升级Windows 8.1之后,原本存在的传统start开始菜单不见了,这该如何找回? 解决步骤 1. 检查任务栏设置 有时传统start开始菜单的隐藏可能是由于任务栏设置所导致的。可以按照以下步骤进行设置: 鼠标右键点击任务栏,并选择“属性”选项; 在弹…

    C 2023年5月24日
    00
  • C语言自定义类型详解(结构体、枚举、联合体和位段)

    C语言自定义类型详解 C语言中自定义类型是构建代码结构的关键组成部分。一个程序中定义的自定义类型,可以用来描述程序中的状态和数据,使程序更加清晰和易于维护。C语言中的自定义类型有结构体、枚举、联合体和位段等。本文将为大家详细讲解C语言中这四种自定义类型的使用和应用场景。 结构体 定义结构体 结构体是用于存储多个不同数据类型的变量的自定义类型。例如,一个保存学…

    C 2023年5月23日
    00
  • C语言实现中国象棋

    题目:C语言实现中国象棋 这是一个将中国象棋的游戏规则用C语言实现的项目。下面是实现该项目的完整攻略: 1. 确定需要的数据结构 在编写代码之前,需要确定需要的数据结构。对于中国象棋,我们可以使用以下数据结构: 棋子(soldier): 数字编号 棋子颜色(红色或黑色) 棋子类型(如马、象、帅等) 棋子当前所在位置 棋子是否被吃掉 棋盘(board): 二维…

    C 2023年5月23日
    00
  • 整型数据在内存中存储方式的讲解

    当我们声明一个整型变量时,计算机会在内存中分配一段连续的存储空间来存储该变量的值。在C语言中,整型数据的存储空间占用长度是根据数据类型决定的,在32位系统中一般为4字节(32位),在64位系统中一般为8字节(64位)。 整型数据在内存中存储方式是使用二进制补码表示。 二进制补码是一种表示有符号整数的方法,它对一个数的正负没有区别,而且在计算机中操作速度更快,…

    C 2023年5月23日
    00
  • C语言栈顺序结构实现代码

    下面我将详细讲解如何用 C 语言实现栈的顺序结构并提供两个示例。 什么是栈? 栈是一种数据结构,特点是 Last In First Out (LIFO) 后进先出。栈具有两个基本操作:压入(push)和弹出(pop)。在栈的顺序结构中,栈被定义为一个固定大小的数组,其中有一个指针(top)指向栈的顶部元素。 栈的顺序结构实现 首先,我们需要定义栈的数据结构,…

    C 2023年5月23日
    00
合作推广
合作推广
分享本页
返回顶部