利用Tensor Core加速矩阵乘法的代码之理解

Tensor Core单元可以实现矩阵乘法的加速,也为CUDA核提供了调用接口。我最近学习了相关文档,也读到官方给出的矩阵乘法样例代码,在此记录自己的经验和理解。
常见的矩阵乘法有两种:D = A * B + C 与 C = A * B + C,两者之间的区别是后者原地计算。
这里预先约定好矩阵形状的代表符号,A[M, K], B[K, N],其中M是A的行数,K是A的列数,N是B的列数,于是容易推算出D[M, N], C[M, N],这些符号需要记住。
在tensor core之前,cuda kernel核函数的矩阵运算如下图所示,将A在y方向上分割成M / m块[m, K]大小的矩阵,同样将B在x方向上分割成多块N / n块[n, K]大小的矩阵,于是我们只需要申请 M / m * N / n个线程块即可对A*B并行计算,具体实现来说,可以申请dim3 blockSize(n, m)的线程块与dim3 gridSize(N / n, M / m)的网格。

而对于每个线程块,在K方向上也可以划分成K / k个小矩阵,A中每个小矩阵的形状为[m, k],于是矩阵乘法变成了更小的子矩阵乘法,1*1 + 2*2 + 3*3...即可得到我们想要的结果,可以在kernel核函数中用循环实现这种计算。

这种矩阵乘法的实现比较常见,在cuda基础教程中有代码实现。
说回Tensor Core,其加速矩阵乘法与上述的思路类似,但我们需要先了解一下其硬件特性。与FP32 Core类似,Tensor Core就是一个运算单元,前者输入两个浮点数,返回一个浮点数加法结果,后者输入两个矩阵,返回矩阵乘法结果。在cuda C的tensor core接口(wmma)中,kernel核函数中一次tensor core的运算需要占用一个warp的线程(32个)。由于tensor core的一次运算的矩阵大小是固定的,所需线程数也是固定的,所以我们多个tensor core并行运算只需要对矩阵、线程进行分割即可,下面讲讲怎么分割。
假设tensor core的一次矩阵运算的形状为[m, k] * [k, n] = [m, n],其中从A矩阵中分割出[m, k]的子矩阵,从B矩阵分割出[k, n]的子矩阵,得到一个[m, n]的子矩阵。通过简单的计算可得,A矩阵要求在y方向上需要M / m个warp的线程(每个warp负责[m, k]的矩阵),B矩阵要求在x方向上需要N / n个warp的线程,而在kernel内进行K / k次的循环累加即可得到C中[m, n]的子矩阵。如果你熟悉之前的矩阵乘法,这一定不难想明白。

剩下的就是编程了:
首先预定义__CUDACC__,其实不做预定义也能编译成功,但VS不会出现代码的输入补全提示,而且满屏波浪号。

然后是初始化矩阵,这里A的内容是1, 2, 3, 4....512的序列,B的元素全是1,C的元素全是0,由于tensor core不接受float的输入,所以使用半精度half作为输入,float作为输出。

最后是定义tensor core接收矩阵形状的大小核函数了

总之,使用tensor core的矩阵乘法与普通的矩阵乘法其实是类似的,只不过tensor core的运算粒度更大,吞吐量更高。
完整代码如下:
#include<device_launch_parameters.h>
#include<iostream>
#include<thrust/device_vector.h>
#include<thrust/sequence.h>
#ifndef __CUDACC__
#define __CUDACC__
#endif // !__CUDACC__
#include<mma.h>
using namespace nvcuda;
#define uint unsigned int
#define coreSizeM 16
#define coreSizeN 16
#define coreSizeK 16
__global__ void TensorCoreMM(half* a, half* b, float* c,
const int lm, const int ln, const int lk)
{
const uint x = (blockDim.x * blockIdx.x + threadIdx.x) / 32;
const uint y = blockDim.y * blockIdx.y + threadIdx.y;
const uint la = lk, lb = ln, lc = ln;
const uint aRow = x * coreSizeM; // 当前tile左上角在A上的行数
const uint bCol = y * coreSizeN; // 当前tile左上角在B上的列数
if (aRow >= lm || bCol >= ln) return;
// 声明fragment
wmma::fragment<wmma::matrix_a, coreSizeM, coreSizeN, coreSizeK, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, coreSizeM, coreSizeN, coreSizeK, half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, coreSizeM, coreSizeN, coreSizeK, float> c_frag;
// 清理c_frag
wmma::fill_fragment(c_frag, 0.f);
for (int i = 0; i < la; i += coreSizeK)
{
const uint aCol = i;
const uint bRow = i;
// load
wmma::load_matrix_sync(a_frag, a + aCol + aRow * la, la);
wmma::load_matrix_sync(b_frag, b + bCol + bRow * lb, lb);
// multiple and accumulate
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
// store
wmma::store_matrix_sync(c + bCol + aRow * lc, c_frag, lc, wmma::mem_row_major);
}
#define vectorPtr(x) thrust::raw_pointer_cast(x.data())
int main()
{
// C = A * B + C
size_t M = 32, N = 16, K = 16;
thrust::host_vector<float> A_float(M * K);
thrust::sequence(A_float.begin(), A_float.end());
thrust::device_vector<half> A(A_float.begin(), A_float.end());
thrust::device_vector<half> B(K * N, 1.f);
thrust::device_vector<float> C(M * N, 0.f);
dim3 blockSize(128, 4);
dim3 gridSize((M + blockSize.x - 1) / blockSize.x,
(N + blockSize.y - 1) / blockSize.y);
TensorCoreMM<<<gridSize, blockSize>>>(vectorPtr(A), vectorPtr(B), vectorPtr(C), M, N, K);
for (int i = 0; i < M; ++i)
{
thrust::copy(C.begin() + i * N, C.begin() + (i + 1) * N, std::ostream_iterator<float>(std::cout, ", "));
std::cout << std::endl;
}
return 0;
}