欢迎光临散文网 会员登陆 & 注册

利用Tensor Core加速矩阵乘法的代码之理解

2022-08-26 21:38 作者:不会跑路的小向晚  | 我要投稿


Tensor Core的官方文档名字叫Programming Guide

Tensor Core单元可以实现矩阵乘法的加速,也为CUDA核提供了调用接口。我最近学习了相关文档,也读到官方给出的矩阵乘法样例代码,在此记录自己的经验和理解。

常见的矩阵乘法有两种:D = A * B + C 与 C = A * B + C,两者之间的区别是后者原地计算。

这里预先约定好矩阵形状的代表符号,A[M, K], B[K, N],其中M是A的行数,K是A的列数,N是B的列数,于是容易推算出D[M, N], C[M, N],这些符号需要记住。

在tensor core之前,cuda kernel核函数的矩阵运算如下图所示,将A在y方向上分割成M / m块[m, K]大小的矩阵,同样将B在x方向上分割成多块N / n块[n, K]大小的矩阵,于是我们只需要申请 M / m * N / n个线程块即可对A*B并行计算,具体实现来说,可以申请dim3 blockSize(n, m)的线程块与dim3 gridSize(N / n, M / m)的网格。

而对于每个线程块,在K方向上也可以划分成K / k个小矩阵,A中每个小矩阵的形状为[m, k],于是矩阵乘法变成了更小的子矩阵乘法,1*1 + 2*2 + 3*3...即可得到我们想要的结果,可以在kernel核函数中用循环实现这种计算。

这种矩阵乘法的实现比较常见,在cuda基础教程中有代码实现。

说回Tensor Core,其加速矩阵乘法与上述的思路类似,但我们需要先了解一下其硬件特性。与FP32 Core类似,Tensor Core就是一个运算单元,前者输入两个浮点数,返回一个浮点数加法结果,后者输入两个矩阵,返回矩阵乘法结果。在cuda C的tensor core接口(wmma)中,kernel核函数中一次tensor core的运算需要占用一个warp的线程(32个)。由于tensor core的一次运算的矩阵大小是固定的,所需线程数也是固定的,所以我们多个tensor core并行运算只需要对矩阵、线程进行分割即可,下面讲讲怎么分割。

假设tensor core的一次矩阵运算的形状为[m, k] * [k, n] = [m, n],其中从A矩阵中分割出[m, k]的子矩阵,从B矩阵分割出[k, n]的子矩阵,得到一个[m, n]的子矩阵。通过简单的计算可得,A矩阵要求在y方向上需要M / m个warp的线程(每个warp负责[m, k]的矩阵),B矩阵要求在x方向上需要N / n个warp的线程,而在kernel内进行K / k次的循环累加即可得到C中[m, n]的子矩阵。如果你熟悉之前的矩阵乘法,这一定不难想明白。

剩下的就是编程了:

首先预定义__CUDACC__,其实不做预定义也能编译成功,但VS不会出现代码的输入补全提示,而且满屏波浪号。

然后是初始化矩阵,这里A的内容是1, 2, 3, 4....512的序列,B的元素全是1,C的元素全是0,由于tensor core不接受float的输入,所以使用半精度half作为输入,float作为输出。

最后是定义tensor core接收矩阵形状的大小核函数了

总之,使用tensor core的矩阵乘法与普通的矩阵乘法其实是类似的,只不过tensor core的运算粒度更大,吞吐量更高。

完整代码如下:

#include<device_launch_parameters.h>

#include<iostream>

#include<thrust/device_vector.h>

#include<thrust/sequence.h>


#ifndef __CUDACC__

#define __CUDACC__

#endif // !__CUDACC__


#include<mma.h>


using namespace nvcuda;


#define uint unsigned int


#define coreSizeM 16

#define coreSizeN 16

#define coreSizeK 16

__global__ void TensorCoreMM(half* a, half* b, float* c,

const int lm, const int ln, const int lk)

{

const uint x = (blockDim.x * blockIdx.x + threadIdx.x) / 32;

const uint y = blockDim.y * blockIdx.y + threadIdx.y;


const uint la = lk, lb = ln, lc = ln;


const uint aRow = x * coreSizeM; // 当前tile左上角在A上的行数

const uint bCol = y * coreSizeN; // 当前tile左上角在B上的列数


if (aRow >= lm || bCol >= ln) return;


// 声明fragment

wmma::fragment<wmma::matrix_a, coreSizeM, coreSizeN, coreSizeK, half, wmma::row_major> a_frag;

wmma::fragment<wmma::matrix_b, coreSizeM, coreSizeN, coreSizeK, half, wmma::row_major> b_frag;

wmma::fragment<wmma::accumulator, coreSizeM, coreSizeN, coreSizeK, float> c_frag;


// 清理c_frag

wmma::fill_fragment(c_frag, 0.f);


for (int i = 0; i < la; i += coreSizeK)

{

const uint aCol = i;

const uint bRow = i;

// load

wmma::load_matrix_sync(a_frag, a + aCol + aRow * la, la);

wmma::load_matrix_sync(b_frag, b + bCol + bRow * lb, lb);


// multiple and accumulate

wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

}


// store

wmma::store_matrix_sync(c + bCol + aRow * lc, c_frag, lc, wmma::mem_row_major);

}


#define vectorPtr(x) thrust::raw_pointer_cast(x.data())


int main()

{

// C = A * B + C

size_t M = 32, N = 16, K = 16;

thrust::host_vector<float> A_float(M * K);

thrust::sequence(A_float.begin(), A_float.end());

thrust::device_vector<half> A(A_float.begin(), A_float.end());

thrust::device_vector<half> B(K * N, 1.f);

thrust::device_vector<float> C(M * N, 0.f);


dim3 blockSize(128, 4);

dim3 gridSize((M + blockSize.x - 1) / blockSize.x, 

(N + blockSize.y - 1) / blockSize.y);

TensorCoreMM<<<gridSize, blockSize>>>(vectorPtr(A), vectorPtr(B), vectorPtr(C), M, N, K);


for (int i = 0; i < M; ++i)

{

thrust::copy(C.begin() + i * N, C.begin() + (i + 1) * N, std::ostream_iterator<float>(std::cout, ", "));

std::cout << std::endl;

}

return 0;

}


利用Tensor Core加速矩阵乘法的代码之理解的评论 (共 条)

分享到微博请遵守国家法律