C++20 模板元编程实战:编译期实现通用矩阵乘法

在现代 C++(尤其是 C++20)中,模板元编程(TMP)与概念(concepts)结合,为我们提供了强大的编译期计算能力。本文将展示如何利用 TMP 在编译期完成矩阵乘法的实现,并通过概念确保类型安全与维度一致性。目标是让你在不使用运行时循环的前提下,得到一段只在编译期完成的、可直接使用的矩阵乘法代码。

1. 先决条件

  • C++20 编译器(支持 std::spanconstevalconcept 等特性)
  • 了解基础矩阵存储方式:行优先(row-major)或列优先(column-major)
  • 对模板递归和 constexpr 机制有一定了解

2. 设计思路

  1. 矩阵类型:使用 std::array 作为底层容器,配合 std::size_t 的编译期常量来标识行列数。
  2. 概念:通过 concept 约束矩阵的行列数以及乘法前后维度一致性。
  3. 递归乘法:利用模板递归,逐行逐列计算矩阵乘积。
  4. 编译期求值:所有运算均在 constexpr 环境下完成,编译器将生成展开后的常量表达式。

3. 代码实现

#include <array>
#include <concepts>
#include <cstddef>
#include <iostream>

/* ==================== 1. 矩阵定义 ==================== */
template <typename T, std::size_t R, std::size_t C>
using Matrix = std::array<std::array<T, C>, R>;

/* ==================== 2. 概念约束 ==================== */
template <typename T, std::size_t R1, std::size_t C1, std::size_t R2, std::size_t C2>
concept MatrixMulable =
    requires(Matrix<T, R1, C1> a, Matrix<T, R2, C2> b) {
        { C1 == R2 }; // 内积维度相等
    };

/* ==================== 3. 编译期乘法实现 ==================== */
namespace detail {
    // 单个元素的乘积
    template <typename T, std::size_t R, std::size_t C>
    constexpr T dot(const Matrix<T, R, C>& a, const Matrix<T, C, R>& b,
                    std::size_t row, std::size_t col) {
        T sum{};
        for (std::size_t k = 0; k < C; ++k)
            sum += a[row][k] * b[k][col];
        return sum;
    }

    // 计算结果矩阵的每一行
    template <typename T, std::size_t R1, std::size_t C1, std::size_t R2, std::size_t C2,
              std::size_t Row>
    constexpr std::array<T, C2> row_mul(const Matrix<T, R1, C1>& a,
                                        const Matrix<T, R2, C2>& b) {
        std::array<T, C2> result{};
        for (std::size_t col = 0; col < C2; ++col)
            result[col] = dot(a, b, Row, col);
        return result;
    }
}

/* ==================== 4. 整体乘法 ==================== */
template <typename T, std::size_t R1, std::size_t C1,
          std::size_t R2, std::size_t C2>
requires MatrixMulable<T, R1, C1, R2, C2>
constexpr Matrix<T, R1, C2> matmul(const Matrix<T, R1, C1>& a,
                                   const Matrix<T, R2, C2>& b) {
    Matrix<T, R1, C2> result{};
    for (std::size_t r = 0; r < R1; ++r)
        result[r] = detail::row_mul<T, R1, C1, R2, C2, r>(a, b);
    return result;
}

/* ==================== 5. 示例 ==================== */
int main() {
    constexpr Matrix<int, 2, 3> A{ std::array<int, 3>{1, 2, 3},
                                   std::array<int, 3>{4, 5, 6} };
    constexpr Matrix<int, 3, 2> B{ std::array<int, 2>{7, 8},
                                   std::array<int, 2>{9, 10},
                                   std::array<int, 2>{11, 12} };

    constexpr auto C = matmul(A, B); // C 的类型为 Matrix<int, 2, 2>

    // 在运行时输出结果,验证正确性
    std::cout << "C = [";
    for (std::size_t i = 0; i < 2; ++i) {
        std::cout << "[";
        for (std::size_t j = 0; j < 2; ++j) {
            std::cout << C[i][j];
            if (j < 1) std::cout << ", ";
        }
        std::cout << "]";
        if (i < 1) std::cout << ", ";
    }
    std::cout << "]\n";
}

代码说明

  1. Matrix:以二维 std::array 表示矩阵,大小在编译期确定。
  2. MatrixMulable:确保乘法时左矩阵列数等于右矩阵行数。
  3. dot:计算单个元素乘积的和,完全在编译期完成。
  4. row_mul:生成结果矩阵的每一行。
  5. matmul:主函数,循环行数调用 row_mul,返回最终矩阵。

4. 编译期求值的优势

  • 性能:所有乘法在编译期展开,无运行时循环,最终代码仅为常量加载。
  • 类型安全:概念保证维度兼容,编译器立即报错。
  • 可读性:与传统运行时实现保持相似接口,易于迁移。

5. 扩展思路

  • 对称矩阵:添加 is_square 概念,优化对角线计算。
  • 转置:利用 TMP 生成转置矩阵的编译期版本。
  • 稀疏矩阵:结合 std::vector<std::pair<std::size_t, T>> 进行稀疏乘法。

6. 结语

本文演示了如何在 C++20 中使用模板元编程与概念,实现纯编译期的矩阵乘法。虽然示例仅演示了二维矩阵,但思路可以推广到更高维张量的运算。借助编译期计算,你可以构建高性能、类型安全的数学库,让编译器帮你完成一部分工作,释放运行时的资源。祝你编码愉快!

发表评论