如何在C++中使用constexpr实现编译期矩阵乘法?

在 C++20 及更高版本中,constexpr 已经足够强大,能够在编译期完成几乎所有需要的计算。下面展示一种利用 constexprstd::array 以及模板元编程来实现矩阵乘法的完整示例。该实现可以在编译期完成大小已知的矩阵相乘,生成的结果直接作为常量可用于其他编译期计算或嵌入数据。

1. 设计思路

  • 矩阵表示:使用 std::array<std::array<T, N>, M> 来表示 M×N 矩阵,保证大小在编译期固定。
  • 行列提取:通过模板递归或 constexpr 函数提取指定行或列的 std::array
  • 乘法实现:利用行列的点积实现单个元素计算,随后构造整个结果矩阵。
  • 可变模板参数:通过 std::size_t... Is 生成编译期索引序列,简化元素访问。

2. 代码实现

#include <array>
#include <cstddef>
#include <utility>
#include <type_traits>

template <typename T, std::size_t R, std::size_t C>
using Matrix = std::array<std::array<T, C>, R>;

// 生成编译期索引序列
template<std::size_t... Is>
constexpr auto make_index_sequence(std::index_sequence<Is...>) {
    return std::array<std::size_t, sizeof...(Is)>{Is...};
}

// 取矩阵第 r 行
template<typename T, std::size_t R, std::size_t C, std::size_t r>
constexpr std::array<T, C> get_row(const Matrix<T, R, C>& m) {
    static_assert(r < R, "Row index out of bounds");
    return m[r];
}

// 取矩阵第 c 列
template<typename T, std::size_t R, std::size_t C, std::size_t c>
constexpr std::array<T, R> get_col(const Matrix<T, R, C>& m) {
    static_assert(c < C, "Column index out of bounds");
    std::array<T, R> col{};
    for (std::size_t i = 0; i < R; ++i)
        col[i] = m[i][c];
    return col;
}

// 计算两个向量的点积
template<typename T, std::size_t N, std::size_t... Is>
constexpr T dot_impl(const std::array<T, N>& a, const std::array<T, N>& b, std::index_sequence<Is...>) {
    return (a[Is] * b[Is] + ...);
}

template<typename T, std::size_t N>
constexpr T dot(const std::array<T, N>& a, const std::array<T, N>& b) {
    return dot_impl(a, b, std::make_index_sequence <N>{});
}

// 乘法
template<typename T, std::size_t R1, std::size_t C1, std::size_t R2, std::size_t C2>
constexpr Matrix<T, R1, C2> matmul(const Matrix<T, R1, C1>& A,
                                   const Matrix<T, R2, C2>& B) {
    static_assert(C1 == R2, "Inner dimensions must agree");
    Matrix<T, R1, C2> res{};
    for (std::size_t i = 0; i < R1; ++i) {
        for (std::size_t j = 0; j < C2; ++j) {
            auto row = get_row<T, R1, C1, i>(A);
            auto col = get_col<T, R2, C2, j>(B);
            res[i][j] = dot(row, col);
        }
    }
    return res;
}

3. 使用示例

constexpr Matrix<int, 2, 3> A{
    std::array<int, 3>{1, 2, 3},
    std::array<int, 3>{4, 5, 6}
};

constexpr Matrix<int, 3, 2> B{
    std::array<int, 2>{7, 8},
    std::array<int, 2>{9, 10},
    std::array<int, 2>{11, 12}
};

constexpr auto C = matmul(A, B);

// C 现在是编译期求得的 2x2 矩阵
static_assert(C[0][0] == 58, "C[0][0] should be 58");
static_assert(C[0][1] == 64, "C[0][1] should be 64");
static_assert(C[1][0] == 139, "C[1][0] should be 139");
static_assert(C[1][1] == 154, "C[1][1] should be 154");

4. 进一步扩展

  • 通用性:将 Matrix 改为 std::array<T, R*C> 并通过行列索引计算,实现更紧凑的存储。
  • 高阶矩阵运算:利用同样的 constexpr 思路实现转置、求逆、特征值分解等。
  • std::span 结合:在运行时处理不固定大小矩阵时,可以先在 constexpr 阶段生成常量,然后用 std::span 访问子矩阵。
  • 编译期调试:在需要验证矩阵乘法逻辑时,使用 static_assert 检查结果,避免运行时错误。

通过上述实现,C++20 的 constexpr 能够在编译期完成复杂的矩阵运算,为高性能嵌入式系统、编译期生成的数据表等场景提供了强有力的工具。

发表评论