在现代 C++(尤其是 C++20)中,模板元编程(TMP)与概念(concepts)结合,为我们提供了强大的编译期计算能力。本文将展示如何利用 TMP 在编译期完成矩阵乘法的实现,并通过概念确保类型安全与维度一致性。目标是让你在不使用运行时循环的前提下,得到一段只在编译期完成的、可直接使用的矩阵乘法代码。
1. 先决条件
- C++20 编译器(支持
std::span、consteval、concept等特性) - 了解基础矩阵存储方式:行优先(row-major)或列优先(column-major)
- 对模板递归和 constexpr 机制有一定了解
2. 设计思路
- 矩阵类型:使用
std::array作为底层容器,配合std::size_t的编译期常量来标识行列数。 - 概念:通过
concept约束矩阵的行列数以及乘法前后维度一致性。 - 递归乘法:利用模板递归,逐行逐列计算矩阵乘积。
- 编译期求值:所有运算均在
constexpr环境下完成,编译器将生成展开后的常量表达式。
3. 代码实现
#include <array>
#include <concepts>
#include <cstddef>
#include <iostream>
/* ==================== 1. 矩阵定义 ==================== */
template <typename T, std::size_t R, std::size_t C>
using Matrix = std::array<std::array<T, C>, R>;
/* ==================== 2. 概念约束 ==================== */
template <typename T, std::size_t R1, std::size_t C1, std::size_t R2, std::size_t C2>
concept MatrixMulable =
requires(Matrix<T, R1, C1> a, Matrix<T, R2, C2> b) {
{ C1 == R2 }; // 内积维度相等
};
/* ==================== 3. 编译期乘法实现 ==================== */
namespace detail {
// 单个元素的乘积
template <typename T, std::size_t R, std::size_t C>
constexpr T dot(const Matrix<T, R, C>& a, const Matrix<T, C, R>& b,
std::size_t row, std::size_t col) {
T sum{};
for (std::size_t k = 0; k < C; ++k)
sum += a[row][k] * b[k][col];
return sum;
}
// 计算结果矩阵的每一行
template <typename T, std::size_t R1, std::size_t C1, std::size_t R2, std::size_t C2,
std::size_t Row>
constexpr std::array<T, C2> row_mul(const Matrix<T, R1, C1>& a,
const Matrix<T, R2, C2>& b) {
std::array<T, C2> result{};
for (std::size_t col = 0; col < C2; ++col)
result[col] = dot(a, b, Row, col);
return result;
}
}
/* ==================== 4. 整体乘法 ==================== */
template <typename T, std::size_t R1, std::size_t C1,
std::size_t R2, std::size_t C2>
requires MatrixMulable<T, R1, C1, R2, C2>
constexpr Matrix<T, R1, C2> matmul(const Matrix<T, R1, C1>& a,
const Matrix<T, R2, C2>& b) {
Matrix<T, R1, C2> result{};
for (std::size_t r = 0; r < R1; ++r)
result[r] = detail::row_mul<T, R1, C1, R2, C2, r>(a, b);
return result;
}
/* ==================== 5. 示例 ==================== */
int main() {
constexpr Matrix<int, 2, 3> A{ std::array<int, 3>{1, 2, 3},
std::array<int, 3>{4, 5, 6} };
constexpr Matrix<int, 3, 2> B{ std::array<int, 2>{7, 8},
std::array<int, 2>{9, 10},
std::array<int, 2>{11, 12} };
constexpr auto C = matmul(A, B); // C 的类型为 Matrix<int, 2, 2>
// 在运行时输出结果,验证正确性
std::cout << "C = [";
for (std::size_t i = 0; i < 2; ++i) {
std::cout << "[";
for (std::size_t j = 0; j < 2; ++j) {
std::cout << C[i][j];
if (j < 1) std::cout << ", ";
}
std::cout << "]";
if (i < 1) std::cout << ", ";
}
std::cout << "]\n";
}
代码说明
Matrix:以二维std::array表示矩阵,大小在编译期确定。MatrixMulable:确保乘法时左矩阵列数等于右矩阵行数。dot:计算单个元素乘积的和,完全在编译期完成。row_mul:生成结果矩阵的每一行。matmul:主函数,循环行数调用row_mul,返回最终矩阵。
4. 编译期求值的优势
- 性能:所有乘法在编译期展开,无运行时循环,最终代码仅为常量加载。
- 类型安全:概念保证维度兼容,编译器立即报错。
- 可读性:与传统运行时实现保持相似接口,易于迁移。
5. 扩展思路
- 对称矩阵:添加
is_square概念,优化对角线计算。 - 转置:利用 TMP 生成转置矩阵的编译期版本。
- 稀疏矩阵:结合
std::vector<std::pair<std::size_t, T>>进行稀疏乘法。
6. 结语
本文演示了如何在 C++20 中使用模板元编程与概念,实现纯编译期的矩阵乘法。虽然示例仅演示了二维矩阵,但思路可以推广到更高维张量的运算。借助编译期计算,你可以构建高性能、类型安全的数学库,让编译器帮你完成一部分工作,释放运行时的资源。祝你编码愉快!