cuda_learning_03

I. The GPU memory hierarchy

II. Understanding the GPU memory hierarchy through matrix multiplication (SGEMM)

1. The problem

  $C = αAB + βC$

2. Basic characteristics

  A has shape M × K, B has shape K × N, and C has shape M × N. Computing C takes roughly 2KMN floating-point operations, so the metric we care about is FLOPS (floating-point operations per second). To keep things simple, set α = 1, β = 0, and use FP32 throughout.
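
  The code below relies on two helper macros whose definitions are not included in this excerpt. A minimal sketch of what they presumably look like in the full source (the exact definitions are an assumption):

// Assumed helper macros (not shown in the original excerpt).
// OFFSET maps (row, col) to a linear index for a row-major matrix with leading dimension ld.
// FLOAT4 reinterprets 4 consecutive floats as a float4 so that a single 128-bit load/store is issued.
#define OFFSET(row, col, ld) ((row) * (ld) + (col))
#define FLOAT4(pointer) (reinterpret_cast<float4 *>(&(pointer))[0])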

// CPU reference implementation
void cpuSgemm(float *a, float *b, float *c, const int M, const int N, const int K)
{
    for (int m = 0; m < M; ++m)
    {
        for (int n = 0; n < N; ++n)
        {
            float psum = 0.0;
            for (int k = 0; k < K; ++k) psum += a[OFFSET(m, k, K)] * b[OFFSET(k, n, N)];
            c[OFFSET(m, n, N)] = psum;
        }
    }
}
// Naive GPU implementation.
// __restrict__: the programmer promises the compiler that, within this pointer's scope, the memory
// reached through it is never modified through any other pointer or reference. With no aliasing to
// worry about, the compiler can generate more efficient code.
__global__ void naiveSgemm(float* __restrict__ a, float* __restrict__ b, float* __restrict__ c, const int M, const int K, const int N)
{
    int tidx = blockIdx.x * blockDim.x + threadIdx.x;   // column index n of C
    int tidy = blockIdx.y * blockDim.y + threadIdx.y;   // row index m of C

    if (tidx < N && tidy < M)
    {
        float sum = 0.0;
        #pragma unroll
        for (int k = 0; k < K; ++k) sum += a[OFFSET(tidy, k, K)] * b[OFFSET(k, tidx, N)];
        c[OFFSET(tidy, tidx, N)] = sum;   // C[m][n]
    }
}
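
  For reference, a plausible host-side launch of this kernel: one thread per element of C with a 2-D block. The 32×32 block size and the device-pointer names (d_a, d_b, d_c) are assumptions, not taken from the original code.

// Hypothetical launch configuration for naiveSgemm (block size and pointer names assumed).
const int BLOCK_X = 32, BLOCK_Y = 32;
dim3 blockDim(BLOCK_X, BLOCK_Y);
dim3 gridDim((N + BLOCK_X - 1) / BLOCK_X, (M + BLOCK_Y - 1) / BLOCK_Y);
naiveSgemm<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, K, N);   // matches the (M, K, N) parameter order above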

  The results on an NVIDIA GeForce RTX 3090 are shown below. Its theoretical throughput is $\text{FP32 (TFLOPS)} = \text{CUDA cores} \times \text{boost clock} \times 2$, i.e. 10,496 × 1.70 GHz × 2 = 35.7 TFLOPS. The naive kernel therefore uses only about 4.5% of the available compute.

M N K =    128    128   1024, Time =   0.00008499   0.00008908   0.00009715 s, AVG Performance =   376.6978 Gflops
M N K = 192 192 1024, Time = 0.00008909 0.00009093 0.00009216 s, AVG Performance = 830.2995 Gflops
M N K = 256 256 1024, Time = 0.00008909 0.00008970 0.00009114 s, AVG Performance = 1496.2557 Gflops
M N K = 384 384 1024, Time = 0.00016486 0.00017285 0.00017510 s, AVG Performance = 1747.1090 Gflops
M N K = 512 512 1024, Time = 0.00031949 0.00032122 0.00032973 s, AVG Performance = 1671.3371 Gflops
M N K = 768 768 1024, Time = 0.00063898 0.00065208 0.00065434 s, AVG Performance = 1852.4714 Gflops
M N K = 1024 1024 1024, Time = 0.00106189 0.00106394 0.00106598 s, AVG Performance = 2018.4210 Gflops
M N K = 1536 1536 1024, Time = 0.00281805 0.00282910 0.00284774 s, AVG Performance = 1707.9060 Gflops
M N K = 2048 2048 1024, Time = 0.00560026 0.00560722 0.00561152 s, AVG Performance = 1531.9420 Gflops
M N K = 3072 3072 1024, Time = 0.01216819 0.01321758 0.01334579 s, AVG Performance = 1462.2455 Gflops
M N K = 4096 4096 1024, Time = 0.02087629 0.02089912 0.02097050 s, AVG Performance = 1644.0758 Gflops
M N K = 6144 6144 1024, Time = 0.04926976 0.04949125 0.04952781 s, AVG Performance = 1562.0824 Gflops
M N K = 8192 8192 1024, Time = 0.08597197 0.08599675 0.08600986 s, AVG Performance = 1598.1878 Gflops
M N K = 12288 12288 1024, Time = 0.19243827 0.19409776 0.19477504 s, AVG Performance = 1593.2056 Gflops
M N K = 16384 16384 1024, Time = 0.34301132 0.34313369 0.34316391 s, AVG Performance = 1602.1622 Gflops

  Accesses to Global Memory are issued as memory transactions, typically 32 bytes to 128 bytes each; in favourable cases (contiguous, aligned accesses) a warp's requests are coalesced into larger 128-byte transactions for efficiency.

  For the 32 threads of one warp, each iteration of the k loop reads the same element of A (1 transaction) plus 32 consecutive elements of B (assuming ideal coalescing in 32-byte units, at least 4 transactions), i.e. 5 transactions per iteration. Over the K iterations that is K × 5 transactions per warp. With M × N threads in total there are M × N / 32 warps, so the total number of Global Memory load transactions is M × N / 32 × K × 5 (note that this is not the 2 × K × M × N figure from before).

  The compute-to-memory ratio is $\frac{2KMN}{\frac{KMN}{32} \times 5 \times 4\,\text{B}} = 3.2$ FLOP/byte. Given the measured bandwidth of 763 GB/s (the official figure is 900 GB/s), the attainable throughput of this access pattern is at most 3.2 × 763 ≈ 2442 GFLOPS.

$\text{attainable FLOPS} = \text{bandwidth} \times \text{compute-to-memory ratio}$

3. Optimization

3.1 Tiling the matrices and using Shared Memory and registers

  [Figure: tiling of A, B, and C into block tiles and thread tiles (image-20250507210611016)]

  The tiling scheme is shown in the figure above (image source). Let us now look in detail at how the sizes of the thread block and the tiles are chosen.

  For each block tile:

  Compute: $BM \times BN \times K \times 2$ FLOPs

  Memory traffic: $(BM + BN) \times K \times 4$ bytes

  Compute-to-memory ratio: $\frac{BM \times BN}{2(BM + BN)} = \frac{1}{2(\frac{1}{BN} + \frac{1}{BM})}$

  From this expression, the larger BM and BN are, the higher the compute-to-memory ratio and the better the performance. They are bounded, however, by the Shared Memory capacity (a single V100 SM has only 96 KB), since one block needs $BK \times (BM + BN) \times 4$ bytes of it.

  TM and TN are constrained from two sides as well. First, the thread count: a block contains (BM / TM) × (BN / TN) threads, which must not exceed 1024 and should not be too large either, or it limits how many blocks can run concurrently on an SM. Second, the register count: each thread needs at least TM × TN registers to hold its partial tile of C, plus some extra registers; the total per thread must not exceed 256, and again should not be too high, or it reduces the number of threads that can be resident on an SM at the same time.

  The final choice is BM = BN = 128, BK = 8, TM = TN = 8, which gives a compute-to-memory ratio of 32. With the V100's theoretical 15.7 TFLOPS this requires 15.7 TFLOPS / 32 = 490 GB/s, and the measured HBM bandwidth is 763 GB/s, so at this ratio bandwidth no longer limits compute.
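
  Plugging the chosen constants in as a quick sanity check (all numbers follow directly from the values above):

  $\frac{128 \times 128}{2(128 + 128)} = 32$, Shared Memory per block $= BK(BM + BN) \times 4\,\text{B} = 8 \times 256 \times 4\,\text{B} = 8\,\text{KB}$, threads per block $= \frac{BM}{TM} \times \frac{BN}{TN} = 16 \times 16 = 256$, accumulator registers per thread $= TM \times TN = 64$.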

  Why is the K dimension sliced into smaller chunks of BK, rather than slicing A along m and B along n?

  Because slicing along K guarantees that every A sub-tile and every B sub-tile is loaded from Global Memory exactly once.

  Based on this analysis, the kernel proceeds in the following steps:

  • Load the corresponding tiles of A and B from Global Memory into Shared Memory.

      Each thread is responsible for exactly 4 elements of the A tile and 4 elements of the B tile, so a single FLOAT4 access per matrix suffices; we only need to assign each thread its elements (see the index sketch after this list).

  • Compute this block's tile of C.

  • Write the results back to Global Memory.
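
  Worked out for these constants (BM = BN = 128, BK = 8, a 16×16 block of 256 threads), the load-phase index arithmetic used in the kernel below looks like this; the helper function is only an illustration, not part of the original code:

// Index arithmetic for the tile-loading phase (illustration only).
// The A tile is 128 x 8 = 1024 floats, so with 256 threads each thread loads 4 floats
// and two threads cover one row; the B tile is 8 x 128, so 32 threads cover one row.
__device__ void tileLoadIndices(int tid, int &a_m, int &a_k, int &b_k, int &b_n)
{
    a_m = tid >> 1;          // tid / 2        : which of the 128 rows of s_a
    a_k = (tid & 1) << 2;    // (tid % 2) * 4  : column 0 or 4 within that row
    b_k = tid >> 5;          // tid / 32       : which of the 8 rows of s_b
    b_n = (tid & 31) << 2;   // (tid % 32) * 4 : column 0, 4, ..., 124 within that row
}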

// Tiling with Shared Memory and registers
__global__ void sgemm_V1(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c, const int M, const int K, const int N)
{
    const int BM = 128;
    const int BN = 128;
    const int BK = 8;
    const int TM = 8;
    const int TN = 8;

    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int tid = ty * blockDim.x + tx;

    __shared__ float s_a[BM][BK];
    __shared__ float s_b[BK][BN];

    float r_c[TM][TN] = {0.0};

    // which row/column of the shared tiles this thread loads
    int load_a_smem_m = tid >> 1;
    int load_a_smem_k = (tid & 1) << 2;
    int load_b_smem_k = tid >> 5;
    int load_b_smem_n = (tid & 31) << 2;

    // the corresponding global row of a and global column of b
    int load_a_gmem_m = by * BM + load_a_smem_m;
    int load_b_gmem_n = bx * BN + load_b_smem_n;

    int cnt = (K + BK - 1) / BK;
    for (int bk = 0; bk < cnt; ++bk)
    {
        int load_a_gmem_k = bk * BK + load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        FLOAT4(s_a[load_a_smem_m][load_a_smem_k]) = FLOAT4(a[load_a_gmem_addr]);
        int load_b_gmem_k = bk * BK + load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr]);

        __syncthreads();

        #pragma unroll
        for (int k = 0; k < BK; ++k)
        {
            #pragma unroll
            for (int m = 0; m < TM; ++m)
            {
                #pragma unroll
                for (int n = 0; n < TN; ++n)
                {
                    int comp_a_smem_m = ty * TM + m;
                    int comp_b_smem_n = tx * TN + n;
                    r_c[m][n] += s_a[comp_a_smem_m][k] * s_b[k][comp_b_smem_n];
                }
            }
        }

        __syncthreads();
    }

    #pragma unroll
    for (int i = 0; i < TM; ++i)
    {
        int store_c_gmem_m = by * BM + ty * TM + i;
        #pragma unroll
        for (int j = 0; j < TN; j += 4)
        {
            int store_c_gmem_n = bx * BN + tx * TN + j;
            int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
            FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][j]);
        }
    }
}
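
  A plausible launch configuration for sgemm_V1: one block per 128×128 tile of C and a 16×16 thread block, i.e. blockDim = (BN / TN, BM / TM). The device-pointer names are assumptions.

// Hypothetical launch of sgemm_V1 (pointer names assumed; sizes follow from BM, BN, TM, TN).
const int BM = 128, BN = 128, TM = 8, TN = 8;
dim3 blockDim(BN / TN, BM / TM);                       // 16 x 16 = 256 threads
dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);    // one block per C tile
sgemm_V1<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, K, N);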

  The results are as follows; performance reaches 54.5% of the theoretical peak:

M N K =    128    128   1024, Time =   0.00019149   0.00019769   0.00020950 s, AVG Performance =   169.7357 Gflops
M N K = 192 192 1024, Time = 0.00019354 0.00019711 0.00019968 s, AVG Performance = 383.0212 Gflops
M N K = 256 256 1024, Time = 0.00019866 0.00019927 0.00019968 s, AVG Performance = 673.5457 Gflops
M N K = 384 384 1024, Time = 0.00019763 0.00020429 0.00020685 s, AVG Performance = 1478.2557 Gflops
M N K = 512 512 1024, Time = 0.00020480 0.00020531 0.00020582 s, AVG Performance = 2614.9028 Gflops
M N K = 768 768 1024, Time = 0.00020579 0.00020695 0.00020787 s, AVG Performance = 5837.0423 Gflops
M N K = 1024 1024 1024, Time = 0.00020787 0.00020920 0.00020992 s, AVG Performance = 10265.0610 Gflops
M N K = 1536 1536 1024, Time = 0.00034406 0.00034447 0.00034509 s, AVG Performance = 14026.7298 Gflops
M N K = 2048 2048 1024, Time = 0.00065843 0.00065996 0.00066150 s, AVG Performance = 13015.7464 Gflops
M N K = 3072 3072 1024, Time = 0.00129741 0.00130109 0.00130253 s, AVG Performance = 14854.6887 Gflops
M N K = 4096 4096 1024, Time = 0.00208794 0.00208978 0.00209408 s, AVG Performance = 16441.8282 Gflops
M N K = 6144 6144 1024, Time = 0.00461210 0.00461885 0.00462234 s, AVG Performance = 16737.7892 Gflops
M N K = 8192 8192 1024, Time = 0.00686694 0.00707564 0.00795443 s, AVG Performance = 19424.2566 Gflops
M N K = 12288 12288 1024, Time = 0.01572765 0.01590118 0.01641571 s, AVG Performance = 19447.4602 Gflops
M N K = 16384 16384 1024, Time = 0.02786099 0.02943130 0.03024384 s, AVG Performance = 18679.2936 Gflops

3.2 Eliminating Bank Conflicts

(1) LDS.32

  Suppose a warp has just been scheduled and its 32 threads are about to read from SMEM. The warp issues one LDS.32 instruction (every thread fetches one value of 4 bytes, i.e. 32 bits). CUDA's shared memory then behaves according to the following rules (a small illustration follows the list):

  • One load instruction issued by a warp (not necessarily LDS.32) can return at most 128 bytes (32 values) from SMEM; that is the upper bound on the data a single instruction can fetch for the warp.
  • If each thread's value comes from a different bank, there is no bank conflict: one instruction is enough for every thread to get its data.
  • If threads read the same value in the same bank (the same address within the bank, the same "layer"), there is also no bank conflict thanks to the broadcast mechanism, and it is still one instruction.
  • If some threads read different values from the same bank (different addresses within the bank, i.e. different layers), a bank conflict occurs. When the threads of a warp want n different addresses in one bank, that is an n-way bank conflict: what could have been one instruction has to be issued n times, serially.
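
  As a small illustration of the three cases above (a toy kernel, not part of the benchmark; it assumes the standard layout of 32 banks of 4 bytes, where the word at byte address addr lives in bank (addr / 4) mod 32):

// Toy kernel illustrating the bank behaviour of LDS.32 accesses.
__global__ void bankConflictDemo(float *out)
{
    __shared__ float smem[32][32];
    int lane = threadIdx.x & 31;
    for (int i = lane; i < 32 * 32; i += 32) (&smem[0][0])[i] = i;   // initialise the whole tile
    __syncthreads();

    float x = smem[0][lane];        // 32 different banks              -> no conflict
    float y = smem[lane][0];        // bank 0, 32 different addresses  -> 32-way conflict
    float z = smem[0][0];           // one address read by every lane  -> broadcast, no conflict
    out[threadIdx.x] = x + y + z;   // keep the loads from being optimised away
}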

(2) Why is SMEM divided into banks at all?

  In short: the banks spread the load across SMEM's data paths so that accesses can proceed in parallel as much as possible instead of blocking one another.

(3) LDS.64 and LDS.128

  LDS.64: each thread fetches 8 bytes at once, i.e. 2 consecutive values.

  LDS.128: each thread fetches 16 bytes at once, i.e. 4 consecutive values.

  Take LDS.128 as an example: the warp as a whole wants 128 values, which exceeds the amount a single warp memory transaction can return. The warp is therefore split into 4 serial phases (4 serial memory transactions): threads 0–7, 8–15, 16–23, 24–31. Bank conflicts are then defined within each phase, i.e. within a quarter of the warp.
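
  For example, a FLOAT4-style read from Shared Memory compiles to LDS.128. In the toy kernel below (not from the benchmark code), each 8-thread phase reads 8 × 4 = 32 consecutive floats, which fall into 32 distinct banks, so every phase is conflict-free:

// Toy kernel: a 16-byte vector load per thread becomes LDS.128 and is served in 4 phases.
__global__ void lds128Demo(float4 *out)
{
    __shared__ float s[128];
    int lane = threadIdx.x & 31;
    for (int i = lane; i < 128; i += 32) s[i] = i;    // initialise the row
    __syncthreads();

    // Thread `lane` reads columns 4*lane .. 4*lane+3; within one phase (lanes 0-7, 8-15, ...)
    // the 8 threads cover 32 consecutive floats = 32 distinct banks, so no conflict per phase.
    float4 v = reinterpret_cast<float4 *>(&s[lane * 4])[0];
    out[threadIdx.x] = v;
}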

  Next, let us check whether V1 suffers from bank conflicts.

for (int k = 0; k < BK; ++k)
{
    #pragma unroll
    for (int m = 0; m < TM; ++m)
    {
        #pragma unroll
        for (int n = 0; n < TN; ++n)
        {
            int comp_a_smem_m = ty * TM + m;
            int comp_b_smem_n = tx * TN + n;
            r_c[m][n] += s_a[comp_a_smem_m][k] * s_b[k][comp_b_smem_n];
        }
    }
}

  A simple (and not entirely rigorous) look at the accesses to s_b: each thread reads 8 consecutive elements of s_b. The thread with tid = 0 touches banks 0–7, the thread with tid = 1 touches banks 8–15, and so on; the thread with tid = 4 then touches banks 0–7 again, exactly the same banks as tid = 0, so the accesses to s_b cause an 8-way bank conflict.
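
  In bank terms (32 banks of 4 bytes, so the float at column $j$ of one row of s_b sits in bank $j \bmod 32$), thread $tx$ reads columns $8\,tx$ through $8\,tx + 7$, and

  $\text{bank}(tx, n) = (8\,tx + n) \bmod 32 = (8(tx + 4) + n) \bmod 32 = \text{bank}(tx + 4, n)$,

  so threads whose tx values differ by 4 always hit the same banks at different addresses.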

  If FLOAT4 is used instead, the accesses to s_b show a 2-way bank conflict. The fix is to change which elements each thread computes, so that every thread reads two separate groups of 4 consecutive values rather than 8 consecutive values. Referring to the figure above: in V1 each thread block computes a 128×128 sub-matrix of C, each thread computes an 8×8 sub-matrix, and the block is 16×16 threads. In V2 each thread instead computes four 4×4 sub-matrices: the 128×128 block tile is split into four 64×64 quadrants (top-left, top-right, bottom-left, bottom-right), and a thread's four 4×4 tiles sit at the corresponding position inside each of the four quadrants.

  As for s_a, V1 has no bank conflict because the broadcast mechanism kicks in, but its access pattern cannot use FLOAT4, since the reads of s_a are not contiguous (they stride down a column). The fix is to store s_a transposed.

// V2: eliminate bank conflicts (s_a is stored transposed so FLOAT4 loads can be used)
__global__ void sgemm_V2(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c, const int M, const int N, const int K)
{
    const int BM = 128;
    const int BN = 128;
    const int BK = 8;
    const int TM = 8;
    const int TN = 8;

    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int tid = ty * blockDim.x + tx;

    __shared__ float s_a[BK][BM];   // note: transposed relative to V1
    __shared__ float s_b[BK][BN];

    float r_c[TM][TN] = {0.0};
    float r_load_a[4];
    float r_load_b[4];
    float r_comp_a[TM];
    float r_comp_b[TN];

    int load_a_smem_m = tid >> 1;
    int load_a_smem_k = (tid & 1) << 2;
    int load_b_smem_k = tid >> 5;
    int load_b_smem_n = (tid & 31) << 2;

    int load_a_gmem_m = by * BM + load_a_smem_m;
    int load_b_gmem_n = bx * BN + load_b_smem_n;

    int cnt = (K + BK - 1) / BK;
    for (int bk = 0; bk < cnt; ++bk)
    {
        int load_a_gmem_k = bk * BK + load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        int load_b_gmem_k = bk * BK + load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
        FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);

        // scatter the 4 loaded elements of a into the transposed s_a
        s_a[load_a_smem_k][load_a_smem_m] = r_load_a[0];
        s_a[load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
        s_a[load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
        s_a[load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
        FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);

        __syncthreads();

        #pragma unroll
        for (int k = 0; k < BK; ++k)
        {
            // each thread works on two 4-element strips of a and of b (the two halves of the tile)
            FLOAT4(r_comp_a[0]) = FLOAT4(s_a[k][ty * TM / 2]);
            FLOAT4(r_comp_a[4]) = FLOAT4(s_a[k][ty * TM / 2 + BM / 2]);
            FLOAT4(r_comp_b[0]) = FLOAT4(s_b[k][tx * TN / 2]);
            FLOAT4(r_comp_b[4]) = FLOAT4(s_b[k][tx * TN / 2 + BN / 2]);

            #pragma unroll
            for (int m = 0; m < TM; ++m)
            {
                #pragma unroll
                for (int n = 0; n < TN; ++n) r_c[m][n] += r_comp_a[m] * r_comp_b[n];
            }
        }

        __syncthreads();
    }

    // write back the upper half of r_c (left and right quarters)
    #pragma unroll
    for (int i = 0; i < TM / 2; ++i)
    {
        int store_c_gmem_m = by * BM + ty * TM / 2 + i;
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
        FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
    }
    // write back the lower half of r_c
    #pragma unroll
    for (int i = 0; i < TM / 2; ++i)
    {
        int store_c_gmem_m = by * BM + ty * TM / 2 + i + BM / 2;
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
        FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
    }
}

  The results are as follows; performance reaches 68.6% of the theoretical peak:

M N K =    128    128   1024, Time =   0.00017306   0.00017603   0.00017715 s, AVG Performance =   190.6225 Gflops
M N K = 192 192 1024, Time = 0.00014438 0.00015247 0.00016486 s, AVG Performance = 495.1511 Gflops
M N K = 256 256 1024, Time = 0.00014438 0.00014479 0.00014541 s, AVG Performance = 926.9590 Gflops
M N K = 384 384 1024, Time = 0.00014541 0.00015677 0.00016179 s, AVG Performance = 1926.3098 Gflops
M N K = 512 512 1024, Time = 0.00015258 0.00015421 0.00015872 s, AVG Performance = 3481.3280 Gflops
M N K = 768 768 1024, Time = 0.00015462 0.00015954 0.00016077 s, AVG Performance = 7571.5533 Gflops
M N K = 1024 1024 1024, Time = 0.00015770 0.00016025 0.00016278 s, AVG Performance = 13400.6000 Gflops
M N K = 1536 1536 1024, Time = 0.00024166 0.00024380 0.00024678 s, AVG Performance = 19819.2506 Gflops
M N K = 2048 2048 1024, Time = 0.00046899 0.00047144 0.00047411 s, AVG Performance = 18220.6316 Gflops
M N K = 3072 3072 1024, Time = 0.00092467 0.00093071 0.00093798 s, AVG Performance = 20766.1658 Gflops
M N K = 4096 4096 1024, Time = 0.00149402 0.00150313 0.00152064 s, AVG Performance = 22858.7997 Gflops
M N K = 6144 6144 1024, Time = 0.00325222 0.00327711 0.00329728 s, AVG Performance = 23590.7485 Gflops
M N K = 8192 8192 1024, Time = 0.00568320 0.00570696 0.00574771 s, AVG Performance = 24082.7044 Gflops
M N K = 12288 12288 1024, Time = 0.01249382 0.01263268 0.01291366 s, AVG Performance = 24479.1828 Gflops
M N K = 16384 16384 1024, Time = 0.02226381 0.02418770 0.02583859 s, AVG Performance = 22728.7352 Gflops

3.3 Pipelining: Double Buffering

  The previous kernels run a strictly serial load-then-compute pipeline; this optimisation overlaps the memory accesses with the computation. The figure below illustrates it nicely (image source).

  [Figure: double-buffering pipeline diagram (image-20250508191629076)]

  In the code this boils down to the following points:

  • Twice as much Shared Memory is needed as before.
  • The first tile is loaded before the main loop, the last tile is computed after it, and the main loop starts at bk = 1.
  • Because each iteration computes on one buffer while the next load targets the other, only one __syncthreads() per iteration is needed.
  • A GPU cannot reorder instructions the way an out-of-order CPU can, so inside the main loop we first issue the loads of the next tile from Global Memory into registers, then do the current tile's computation, and only afterwards write the loaded registers into Shared Memory. This way the LDG instructions reading Global Memory do not block the launch of the subsequent FFMA and other arithmetic instructions, which is exactly what Double Buffering is after.

  The fourth point is fairly subtle: it touches on how GPU instructions are pipelined and how memory accesses proceed asynchronously. At first I wondered: if the GPU cannot execute out of order, isn't this still serial? It is not; rather than summarising it in my own words, the explanation I found is captured in the screenshots below.

  [Screenshots of the quoted explanation (image-20250508194642841, image-20250508194700041). The gist: an LDG is issued asynchronously and only stalls the pipeline when a later instruction needs its result, so putting the Shared Memory store (the first consumer of the loaded registers) after the FFMAs hides the load latency behind the computation.]

// Pipelining via Double Buffering
__global__ void sgemm_V3(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c, const int M, const int N, const int K)
{
    const int BM = 128;
    const int BN = 128;
    const int BK = 8;
    const int TM = 8;
    const int TN = 8;

    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int tid = ty * blockDim.x + tx;

    // two buffers per tile: compute on one while the other is being filled
    __shared__ float s_a[2][BK][BM];
    __shared__ float s_b[2][BK][BN];

    float r_c[TM][TN] = {0.0};
    float r_load_a[4];
    float r_load_b[4];
    float r_comp_a[TM];
    float r_comp_b[TN];

    int load_a_smem_m = tid >> 1;
    int load_a_smem_k = (tid & 1) << 2;
    int load_b_smem_k = tid >> 5;
    int load_b_smem_n = (tid & 31) << 2;

    int load_a_gmem_m = by * BM + load_a_smem_m;
    int load_b_gmem_n = bx * BN + load_b_smem_n;

    // prologue: load tile 0 into buffer 0
    {
        int load_a_gmem_k = load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        int load_b_gmem_k = load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
        FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);

        s_a[0][load_a_smem_k][load_a_smem_m] = r_load_a[0];
        s_a[0][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
        s_a[0][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
        s_a[0][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
        FLOAT4(s_b[0][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
    }

    __syncthreads();   // make tile 0 visible to the whole block before the first compute

    int cnt = (K + BK - 1) / BK;
    for (int bk = 1; bk < cnt; ++bk)
    {
        int smem_sel = (bk - 1) & 1;   // buffer to compute on
        int smem_sel_next = bk & 1;    // buffer being filled for the next iteration

        // issue the global loads for the next tile first ...
        int load_a_gmem_k = bk * BK + load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        int load_b_gmem_k = bk * BK + load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
        FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);

        // ... then compute on the buffer loaded in the previous iteration
        #pragma unroll
        for (int k = 0; k < BK; ++k)
        {
            FLOAT4(r_comp_a[0]) = FLOAT4(s_a[smem_sel][k][ty * TM / 2]);
            FLOAT4(r_comp_a[4]) = FLOAT4(s_a[smem_sel][k][ty * TM / 2 + BM / 2]);
            FLOAT4(r_comp_b[0]) = FLOAT4(s_b[smem_sel][k][tx * TN / 2]);
            FLOAT4(r_comp_b[4]) = FLOAT4(s_b[smem_sel][k][tx * TN / 2 + BN / 2]);

            #pragma unroll
            for (int m = 0; m < TM; ++m)
            {
                #pragma unroll
                for (int n = 0; n < TN; ++n) r_c[m][n] += r_comp_a[m] * r_comp_b[n];
            }
        }

        // ... and only now write the loaded registers into the other buffer
        s_a[smem_sel_next][load_a_smem_k][load_a_smem_m] = r_load_a[0];
        s_a[smem_sel_next][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
        s_a[smem_sel_next][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
        s_a[smem_sel_next][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
        FLOAT4(s_b[smem_sel_next][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);

        __syncthreads();
    }

    // epilogue: compute on the last tile; it sits in buffer 1 whenever the tile count
    // (K + BK - 1) / BK is even, which holds for the K = 1024 benchmarks here
    #pragma unroll
    for (int k = 0; k < BK; ++k)
    {
        FLOAT4(r_comp_a[0]) = FLOAT4(s_a[1][k][ty * TM / 2]);
        FLOAT4(r_comp_a[4]) = FLOAT4(s_a[1][k][ty * TM / 2 + BM / 2]);
        FLOAT4(r_comp_b[0]) = FLOAT4(s_b[1][k][tx * TN / 2]);
        FLOAT4(r_comp_b[4]) = FLOAT4(s_b[1][k][tx * TN / 2 + BN / 2]);

        #pragma unroll
        for (int m = 0; m < TM; ++m)
        {
            #pragma unroll
            for (int n = 0; n < TN; ++n) r_c[m][n] += r_comp_a[m] * r_comp_b[n];
        }
    }

    #pragma unroll
    for (int i = 0; i < TM / 2; ++i)
    {
        int store_c_gmem_m = by * BM + ty * TM / 2 + i;
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
        FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
    }

    #pragma unroll
    for (int i = 0; i < TM / 2; ++i)
    {
        int store_c_gmem_m = by * BM + ty * TM / 2 + i + BM / 2;
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
        FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
    }
}

  The results are as follows; performance reaches 75.7% of the theoretical peak:

M N K =    128    128   1024, Time =   0.00011162   0.00011725   0.00012186 s, AVG Performance =   286.1834 Gflops
M N K = 192 192 1024, Time = 0.00011059 0.00011192 0.00011264 s, AVG Performance = 674.5471 Gflops
M N K = 256 256 1024, Time = 0.00011155 0.00011243 0.00011469 s, AVG Performance = 1193.8020 Gflops
M N K = 384 384 1024, Time = 0.00010957 0.00011162 0.00011366 s, AVG Performance = 2705.6147 Gflops
M N K = 512 512 1024, Time = 0.00010854 0.00010957 0.00011469 s, AVG Performance = 4899.8878 Gflops
M N K = 768 768 1024, Time = 0.00010957 0.00011038 0.00011162 s, AVG Performance = 10943.2485 Gflops
M N K = 1024 1024 1024, Time = 0.00010957 0.00011132 0.00011776 s, AVG Performance = 19291.3629 Gflops
M N K = 1536 1536 1024, Time = 0.00021299 0.00021944 0.00022221 s, AVG Performance = 22019.2705 Gflops
M N K = 2048 2048 1024, Time = 0.00043715 0.00044431 0.00045056 s, AVG Performance = 19333.1840 Gflops
M N K = 3072 3072 1024, Time = 0.00090624 0.00091166 0.00091648 s, AVG Performance = 21200.2324 Gflops
M N K = 4096 4096 1024, Time = 0.00149094 0.00151910 0.00167629 s, AVG Performance = 22618.4236 Gflops
M N K = 6144 6144 1024, Time = 0.00334234 0.00342294 0.00378880 s, AVG Performance = 22585.6909 Gflops
M N K = 8192 8192 1024, Time = 0.00587981 0.00618916 0.00649728 s, AVG Performance = 22206.4044 Gflops
M N K = 12288 12288 1024, Time = 0.01135514 0.01144668 0.01189990 s, AVG Performance = 27015.4839 Gflops
M N K = 16384 16384 1024, Time = 0.02017075 0.02207642 0.02387558 s, AVG Performance = 24902.4032 Gflops

4. cuBLAS performance

  For comparison, cuBLAS reaches 83.9% of the theoretical peak:

M N K =    128    128   1024, Time =   0.00001843   0.00035505   0.00335494 s, AVG Performance =    94.5063 Gflops
M N K = 192 192 1024, Time = 0.00002253 0.00026675 0.00245965 s, AVG Performance = 283.0250 Gflops
M N K = 256 256 1024, Time = 0.00002662 0.00003389 0.00008397 s, AVG Performance = 3959.8791 Gflops
M N K = 384 384 1024, Time = 0.00003174 0.00003840 0.00005530 s, AVG Performance = 7864.9753 Gflops
M N K = 512 512 1024, Time = 0.00005120 0.00005386 0.00006554 s, AVG Performance = 9967.4525 Gflops
M N K = 768 768 1024, Time = 0.00008602 0.00009175 0.00013824 s, AVG Performance = 13165.7144 Gflops
M N K = 1024 1024 1024, Time = 0.00012698 0.00013035 0.00014438 s, AVG Performance = 16474.4968 Gflops
M N K = 1536 1536 1024, Time = 0.00022211 0.00023521 0.00034304 s, AVG Performance = 20542.9712 Gflops
M N K = 2048 2048 1024, Time = 0.00042496 0.00042792 0.00044544 s, AVG Performance = 20073.5427 Gflops
M N K = 3072 3072 1024, Time = 0.00084992 0.00085535 0.00088269 s, AVG Performance = 22595.9154 Gflops
M N K = 4096 4096 1024, Time = 0.00146125 0.00146872 0.00149606 s, AVG Performance = 23394.2915 Gflops
M N K = 6144 6144 1024, Time = 0.00306682 0.00307189 0.00308838 s, AVG Performance = 25166.7152 Gflops
M N K = 8192 8192 1024, Time = 0.00457830 0.00470508 0.00521523 s, AVG Performance = 29210.7880 Gflops
M N K = 12288 12288 1024, Time = 0.01022771 0.01032509 0.01057178 s, AVG Performance = 29950.1034 Gflops
M N K = 16384 16384 1024, Time = 0.01822106 0.01895342 0.01967821 s, AVG Performance = 29005.6252 Gflops