C++ Strassen算法代码的实现

C++ Strassen算法代码的实现

什么是Strassen算法?

Strassen算法是一种矩阵乘法的优化算法,它将两个矩阵的乘法分解为若干个小矩阵的乘法,从而减少了矩阵乘法的计算次数。

具体来说,将两个$n\times n$的矩阵$A$和$B$分别划分成四个$\dfrac{n}{2}\times\dfrac{n}{2}$的矩阵:

$$
A = \begin{bmatrix}A_{11} &A_{12}\A_{21} &A_{22}\end{bmatrix},\quad B = \begin{bmatrix}B_{11} &B_{12}\B_{21} &B_{22}\end{bmatrix}
$$

然后可以通过以下公式计算矩阵$C=A\times B$:

$$
\begin{aligned}
C_{11}&=P_5+P_4-P_2+P_6\
C_{12}&=P_1+P_2\
C_{21}&=P_3+P_4\
C_{22}&=P_5+P_1-P_3-P_7
\end{aligned}
$$

其中,$P_1=A_{11}(B_{12}-B_{22})$,$P_2=(A_{11}+A_{12})B_{22}$,$P_3=(A_{21}+A_{22})B_{11}$,$P_4=A_{22}(B_{21}-B_{11})$,$P_5=(A_{11}+A_{22})(B_{11}+B_{22})$,$P_6=(A_{12}-A_{22})(B_{21}+B_{22})$,$P_7=(A_{11}-A_{21})(B_{11}+B_{12})$。

实现Strassen算法

为了实现Strassen算法,我们需要先实现一个能够执行矩阵加法和矩阵减法的函数,以及一个能够执行普通矩阵乘法的函数。具体代码如下:

#include <vector>
using std::vector;

// 矩阵加法
vector<vector<int>> add(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();
    int m = A[0].size();
    vector<vector<int>> C(n, vector<int>(m));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            C[i][j] = A[i][j] + B[i][j];
        }
    }
    return C;
}

// 矩阵减法
vector<vector<int>> sub(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();
    int m = A[0].size();
    vector<vector<int>> C(n, vector<int>(m));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            C[i][j] = A[i][j] - B[i][j];
        }
    }
    return C;
}

// 普通矩阵乘法
vector<vector<int>> multiply(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();
    int m = A[0].size();
    int l = B[0].size();
    vector<vector<int>> C(n, vector<int>(l));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < l; j++) {
            for (int k = 0; k < m; k++) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
    return C;
}

然后可以根据上面的公式实现Strassen算法,代码如下:

// Strassen矩阵乘法
vector<vector<int>> strassen(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();
    vector<vector<int>> C(n, vector<int>(n));
    if (n == 1) {
        C[0][0] = A[0][0] * B[0][0];
    } else {
        vector<vector<int>> A11(n/2, vector<int>(n/2)), A12(n/2, vector<int>(n/2)), A21(n/2, vector<int>(n/2)), A22(n/2, vector<int>(n/2));
        vector<vector<int>> B11(n/2, vector<int>(n/2)), B12(n/2, vector<int>(n/2)), B21(n/2, vector<int>(n/2)), B22(n/2, vector<int>(n/2));
        vector<vector<int>> P1(n/2, vector<int>(n/2)), P2(n/2, vector<int>(n/2)), P3(n/2, vector<int>(n/2)), P4(n/2, vector<int>(n/2)), P5(n/2, vector<int>(n/2)), P6(n/2, vector<int>(n/2)), P7(n/2, vector<int>(n/2));
        // 将矩阵A和B分别划分成四个n/2*n/2的矩阵
        for (int i = 0; i < n/2; i++) {
            for (int j = 0; j < n/2; j++) {
                A11[i][j] = A[i][j];
                A12[i][j] = A[i][j+n/2];
                A21[i][j] = A[i+n/2][j];
                A22[i][j] = A[i+n/2][j+n/2];
                B11[i][j] = B[i][j];
                B12[i][j] = B[i][j+n/2];
                B21[i][j] = B[i+n/2][j];
                B22[i][j] = B[i+n/2][j+n/2];
            }
        }
        // 计算七个子问题
        P1 = strassen(A11, sub(B12, B22));
        P2 = strassen(add(A11, A12), B22);
        P3 = strassen(add(A21, A22), B11);
        P4 = strassen(A22, sub(B21, B11));
        P5 = strassen(add(A11, A22), add(B11, B22));
        P6 = strassen(sub(A12, A22), add(B21, B22));
        P7 = strassen(sub(A11, A21), add(B11, B12));
        // 计算结果矩阵C
        vector<vector<int>> C11(n/2, vector<int>(n/2)), C12(n/2, vector<int>(n/2)), C21(n/2, vector<int>(n/2)), C22(n/2, vector<int>(n/2));
        C11 = add(sub(add(P5, P4), P2), P6);
        C12 = add(P1, P2);
        C21 = add(P3, P4);
        C22 = add(sub(add(P5, P1), P3), P7);
        // 将四个矩阵合并成一个n*n的矩阵
        for (int i = 0; i < n/2; i++) {
            for (int j = 0; j < n/2; j++) {
                C[i][j] = C11[i][j];
                C[i][j+n/2] = C12[i][j];
                C[i+n/2][j] = C21[i][j];
                C[i+n/2][j+n/2] = C22[i][j];
            }
        }
    }
    return C;
}

示例说明

我们可以使用如下的代码对Strassen算法进行测试:

#include <iostream>
using std::cout;
using std::endl;
#include "strassen.h"

int main() {
    vector<vector<int>> A = {{1, 2}, {3, 4}};
    vector<vector<int>> B = {{5, 6}, {7, 8}};
    vector<vector<int>> C = multiply(A, B); // 普通矩阵乘法
    vector<vector<int>> D = strassen(A, B);  // Strassen矩阵乘法
    cout << "A * B =", endl;
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            cout << C[i][j] << " ";
        }
        cout << endl;
    }
    cout << "Strassen(A, B) =", endl;
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            cout << D[i][j] << " ";
        }
        cout << endl;
    }
    return 0;
}

输出结果为:

A * B =
19 22
43 50
Strassen(A, B) =
19 22
43 50

可以看出,Strassen算法得到的结果与普通矩阵乘法得到的结果相同。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:C++ Strassen算法代码的实现 - Python技术站

(0)
上一篇 2023年5月23日
下一篇 2023年5月23日

相关文章

  • C语言实现学生宿舍管理系统

    C语言实现学生宿舍管理系统攻略 1. 系统概述 学生宿舍管理系统是基于C语言实现的一个简单的管理系统。系统主要包括学生信息管理、宿舍信息管理和维修处理等模块。在学生信息管理模块中,学生可以登录系统并进行个人信息的修改、查看宿舍信息等操作。在宿舍信息管理模块中,管理员可以进行宿舍信息的添加、删除和修改等操作。在维修处理模块中,学生可以提交宿舍维修请求,并由管理…

    C 2023年5月23日
    00
  • 全民小镇2014万圣节活动介绍 全民小镇万圣节特殊海域和兑换券一览

    全民小镇2014万圣节活动介绍 活动时间 2014年10月25日-11月2日 活动内容 全民小镇万圣节活动分为两部分:特殊海域和兑换券。 特殊海域 特殊海域是活动期间新增的一些地图。在这些地图中,您将会遇到一些特殊的怪物和道具,同时还有不同于平常的地图场景,非常适合体验万圣节气氛。 兑换券 兑换券是您在活动中可以获得的奖励之一。在特定的NPC处,您可以用兑换…

    C 2023年5月22日
    00
  • 一篇文章带你了解C语言–数据的储存

    一篇文章带你了解C语言–数据的储存 在C语言中,数据的储存有三种方式:变量、数组和指针。 变量 变量是程序运行过程中储存数据的基本单位,它代表着一个内存地址,程序可以通过该地址访问该变量。 声明变量 在C语言中,变量的声明需要给出变量名和类型,如下: int a; float b; char c; 变量的赋值和读取 赋值使用等号“=”来实现,比如: a =…

    C 2023年5月23日
    00
  • Jmeter 使用Json提取请求数据的方法

    以下是详细讲解JMeter使用JSON提取请求数据的方法的完整攻略。 什么是JSON Extractor? JSON Extractor是JMeter插件之一,其主要功能是从HTTP响应中的JSON数据中提取出所需数据。 JSON Extractor配置 JSON Extractor是基于JMeter的post-processor,它可以获取JSON数据并在…

    C 2023年5月23日
    00
  • C语言中如何进行面向对象编程?

    在C语言中进行面向对象编程(Object-Oriented Programming)可以采用结构体(Struct)和指针(Pointer)的方式来实现。 首先,我们需要定义一个结构体,包含对象的属性和方法。属性可以使用变量来定义,方法可以使用函数指针来定义。例如: typedef struct { int x; int y; void (*draw)(voi…

    C 2023年4月27日
    00
  • 教你用Python为二年级的学生批量生成数学题

    我会提供一份完整的教程,教读者用Python批量生成数学题的过程。 1. 概述 在本次教程中,我们将使用Python编写程序来批量生成数学题。通过阅读本文,您将学会以下技能: 使用python实现数学运算 生成随机数 生成word文档并写入数据 2. 开始 如果你没有Python开发环境,你需要首先安装Python和需要的依赖包。我们在本教程中使用pytho…

    C 2023年5月22日
    00
  • C语言经典例程100例(经典c程序100例)

    简介 C语言经典例程100例是一本经典的C语言入门教材,在C语言的学习过程中,它是一本必不可少的参考书。本书由100个经典的C语言程序组成,涵盖了C语言程序的各个方面,不仅能帮助读者掌握C语言的基础知识,还能够提高读者的编程思维和实战能力。 攻略 (1)首先,阅读本书需要一定的基础知识,建议读者至少掌握C语言的基本语法、变量、运算符、控制语句和函数的使用方法…

    C 2023年5月23日
    00
  • 详解C/C++如何获取路径下所有文件及其子目录的文件名

    获取一个文件夹下的所有文件及其子目录的文件名可以通过递归遍历文件夹来完成。以下是几个示例代码,演示如何实现这个功能。 方法一:使用C++17中的std::filesystem 基于C++17标准,可以使用std::filesystem库来遍历目录。下面是示例代码: #include <iostream> #include <filesyst…

    C 2023年5月23日
    00
合作推广
合作推广
分享本页
返回顶部