C++ 中如何使用 constexpr 函数实现编译期矩阵乘法

在 C++20 及以后,constexpr 允许在编译期执行几乎所有可以在运行时执行的计算。利用这一特性,我们可以在编译期完成矩阵乘法,从而在程序运行时获得常量表达式,减少运行时开销。下面给出一个完整的实现示例,并解释关键点。

1. 设计矩阵类型

我们使用一个基于 std::array 的固定大小矩阵类型。为了保持灵活性,矩阵维度作为模板参数:

#include <array>
#include <cstddef>
#include <stdexcept>

template<std::size_t Rows, std::size_t Cols>
struct Matrix {
    std::array<std::array<double, Cols>, Rows> data{};

    constexpr double& operator()(std::size_t r, std::size_t c) {
        return data[r][c];
    }

    constexpr const double& operator()(std::size_t r, std::size_t c) const {
        return data[r][c];
    }
};
  • operator() 提供矩阵元素访问。
  • constexpr 让我们可以在编译期访问。

2. 生成编译期矩阵

我们需要一个函数,用于在编译期创建矩阵并填充值。常用方式是利用 std::initializer_list 或者递归模板。

template<std::size_t Rows, std::size_t Cols>
constexpr Matrix<Rows, Cols> make_matrix(const std::initializer_list<std::initializer_list<double>>& init) {
    Matrix<Rows, Cols> m{};
    std::size_t r = 0;
    for (auto& row : init) {
        std::size_t c = 0;
        for (auto& val : row) {
            if (r >= Rows || c >= Cols) throw std::out_of_range("Initializer size mismatch");
            m(r, c) = val;
            ++c;
        }
        ++r;
    }
    if (r != Rows) throw std::out_of_range("Initializer size mismatch");
    return m;
}

此函数在编译期执行,前提是传入的 init 也是 constexpr

3. constexpr 矩阵乘法

下面是核心:在编译期实现矩阵乘法。

template<std::size_t R, std::size_t K, std::size_t C>
constexpr Matrix<R, C> matmul(const Matrix<R, K>& A, const Matrix<K, C>& B) {
    Matrix<R, C> result{};

    for (std::size_t i = 0; i < R; ++i) {
        for (std::size_t j = 0; j < C; ++j) {
            double sum = 0.0;
            for (std::size_t k = 0; k < K; ++k) {
                sum += A(i, k) * B(k, j);
            }
            result(i, j) = sum;
        }
    }
    return result;
}
  • 所有循环都使用常量索引 size_t,可以在编译期展开。
  • sum 变量在编译期累加,符合 constexpr 规则。

4. 示例:编译期计算

constexpr auto A = make_matrix<2, 3>({
    {1, 2, 3},
    {4, 5, 6}
});

constexpr auto B = make_matrix<3, 2>({
    {7, 8},
    {9, 10},
    {11, 12}
});

constexpr auto C = matmul(A, B);  // 结果也是 constexpr

int main() {
    // C 已经在编译期计算完成
    for (std::size_t i = 0; i < 2; ++i) {
        for (std::size_t j = 0; j < 2; ++j) {
            std::cout << C(i, j) << ' ';
        }
        std::cout << '\n';
    }
}

编译时 C 已经是常量表达式,程序运行时只需打印预先计算好的结果。

5. 进一步优化

  1. 使用 constexpr 友好的算法

    • 如果矩阵较大,考虑改用矩阵块乘法或 Strassen 算法,以减少编译期时间。
  2. 利用模板元编程

    • 可以把矩阵大小作为类型参数,让编译器在类型层面完成计算,避免运行时循环。
  3. 使用 std::applyconstexpr lambda

    • 对于更复杂的初始化方式,可以用 constexpr lambda 生成矩阵。

6. 小结

  • constexpr 允许在编译期完成矩阵乘法,减少运行时开销。
  • 关键是保持所有操作(数组访问、循环、算术运算)都符合 constexpr 规则。
  • 通过 make_matrix 生成编译期矩阵,并使用 matmul 进行乘法,整个过程在编译阶段完成。

这套方案适用于需要大量矩阵运算且维度固定的嵌入式系统、游戏图形渲染以及其他对性能有苛刻要求的 C++ 项目。

发表评论