在 C++20 及更高版本中,constexpr 已经足够强大,能够在编译期完成几乎所有需要的计算。下面展示一种利用 constexpr、std::array 以及模板元编程来实现矩阵乘法的完整示例。该实现可以在编译期完成大小已知的矩阵相乘,生成的结果直接作为常量可用于其他编译期计算或嵌入数据。
1. 设计思路
- 矩阵表示:使用
std::array<std::array<T, N>, M>来表示M×N矩阵,保证大小在编译期固定。 - 行列提取:通过模板递归或
constexpr函数提取指定行或列的std::array。 - 乘法实现:利用行列的点积实现单个元素计算,随后构造整个结果矩阵。
- 可变模板参数:通过
std::size_t... Is生成编译期索引序列,简化元素访问。
2. 代码实现
#include <array>
#include <cstddef>
#include <utility>
#include <type_traits>
template <typename T, std::size_t R, std::size_t C>
using Matrix = std::array<std::array<T, C>, R>;
// 生成编译期索引序列
template<std::size_t... Is>
constexpr auto make_index_sequence(std::index_sequence<Is...>) {
return std::array<std::size_t, sizeof...(Is)>{Is...};
}
// 取矩阵第 r 行
template<typename T, std::size_t R, std::size_t C, std::size_t r>
constexpr std::array<T, C> get_row(const Matrix<T, R, C>& m) {
static_assert(r < R, "Row index out of bounds");
return m[r];
}
// 取矩阵第 c 列
template<typename T, std::size_t R, std::size_t C, std::size_t c>
constexpr std::array<T, R> get_col(const Matrix<T, R, C>& m) {
static_assert(c < C, "Column index out of bounds");
std::array<T, R> col{};
for (std::size_t i = 0; i < R; ++i)
col[i] = m[i][c];
return col;
}
// 计算两个向量的点积
template<typename T, std::size_t N, std::size_t... Is>
constexpr T dot_impl(const std::array<T, N>& a, const std::array<T, N>& b, std::index_sequence<Is...>) {
return (a[Is] * b[Is] + ...);
}
template<typename T, std::size_t N>
constexpr T dot(const std::array<T, N>& a, const std::array<T, N>& b) {
return dot_impl(a, b, std::make_index_sequence <N>{});
}
// 乘法
template<typename T, std::size_t R1, std::size_t C1, std::size_t R2, std::size_t C2>
constexpr Matrix<T, R1, C2> matmul(const Matrix<T, R1, C1>& A,
const Matrix<T, R2, C2>& B) {
static_assert(C1 == R2, "Inner dimensions must agree");
Matrix<T, R1, C2> res{};
for (std::size_t i = 0; i < R1; ++i) {
for (std::size_t j = 0; j < C2; ++j) {
auto row = get_row<T, R1, C1, i>(A);
auto col = get_col<T, R2, C2, j>(B);
res[i][j] = dot(row, col);
}
}
return res;
}
3. 使用示例
constexpr Matrix<int, 2, 3> A{
std::array<int, 3>{1, 2, 3},
std::array<int, 3>{4, 5, 6}
};
constexpr Matrix<int, 3, 2> B{
std::array<int, 2>{7, 8},
std::array<int, 2>{9, 10},
std::array<int, 2>{11, 12}
};
constexpr auto C = matmul(A, B);
// C 现在是编译期求得的 2x2 矩阵
static_assert(C[0][0] == 58, "C[0][0] should be 58");
static_assert(C[0][1] == 64, "C[0][1] should be 64");
static_assert(C[1][0] == 139, "C[1][0] should be 139");
static_assert(C[1][1] == 154, "C[1][1] should be 154");
4. 进一步扩展
- 通用性:将
Matrix改为std::array<T, R*C>并通过行列索引计算,实现更紧凑的存储。 - 高阶矩阵运算:利用同样的
constexpr思路实现转置、求逆、特征值分解等。 - 与
std::span结合:在运行时处理不固定大小矩阵时,可以先在constexpr阶段生成常量,然后用std::span访问子矩阵。 - 编译期调试:在需要验证矩阵乘法逻辑时,使用
static_assert检查结果,避免运行时错误。
通过上述实现,C++20 的 constexpr 能够在编译期完成复杂的矩阵运算,为高性能嵌入式系统、编译期生成的数据表等场景提供了强有力的工具。