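SGEMM 计算 $C = \alpha AB + \beta C$。A 的形状是 $M \times K$,B 的形状是 $K \times N$,C 的形状是 $M \times N$。C 的每个元素需要 $K$ 次乘法和 $K$ 次加法,因此总共大约要进行 $2MNK$ 次浮点运算。我们关注每秒浮点运算次数(FLOPS)。为了便于计算,令 $\alpha = 1$,$\beta = 0$,使用 FP32。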
1 2 3 4 5 6 7 8 9 10 11 12 13
// CPU implementation (reference SGEMM): C = A * B with alpha = 1, beta = 0,
// all matrices row-major FP32.
//   a: M x K input, b: K x N input, c: M x N output (overwritten, not accumulated).
#ifndef OFFSET
// Row-major linear index of element (row, col) in a matrix with leading dimension ld.
#define OFFSET(row, col, ld) ((row) * (ld) + (col))
#endif
void cpuSgemm(float *a, float *b, float *c, const int M, const int N, const int K) {
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            // Dot product of row m of A with column n of B.
            float psum = 0.0f;
            for (int k = 0; k < K; ++k)
                psum += a[OFFSET(m, k, K)] * b[OFFSET(k, n, N)];
            c[OFFSET(m, n, N)] = psum;
        }
    }
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Naive GPU implementation: each thread computes one element of C = A * B
// (alpha = 1, beta = 0, row-major FP32; A is M x K, B is K x N, C is M x N).
//
// __restrict__: the developer promises the compiler that, within the pointer's
// scope, memory accessed through that pointer is not modified through any other
// pointer or reference. The compiler may therefore assume there is no aliasing
// and generate more efficient code.
//
// NOTE(review): the size parameters are ordered (M, K, N) here, whereas
// cpuSgemm takes (M, N, K); make sure the launch site passes them in this order.
__global__ void naiveSgemm(float* __restrict__ a, float* __restrict__ b, float* __restrict__ c,
                           const int M, const int K, const int N) {
    // tidx selects a column of C (bounded by N), tidy selects a row of C (bounded by M).
    int tidx = blockIdx.x * blockDim.x + threadIdx.x;
    int tidy = blockIdx.y * blockDim.y + threadIdx.y;

    if (tidx < N && tidy < M) {
        float sum = 0.0f;
        #pragma unroll
        for (int k = 0; k < K; ++k)
            sum += a[OFFSET(tidy, k, K)] * b[OFFSET(k, tidx, N)];
        // C is M x N row-major: row index is tidy, column index is tidx.
        // (Previously written to OFFSET(tidx, tidy, N), i.e. the transposed
        // position, which is wrong unless C is symmetric and indexes past the
        // end of c when N > M.)
        c[OFFSET(tidy, tidx, N)] = sum;
    }
}
M N K = 128 128 1024, Time = 0.00008499 0.00008908 0.00009715 s, AVG Performance = 376.6978 Gflops
M N K = 192 192 1024, Time = 0.00008909 0.00009093 0.00009216 s, AVG Performance = 830.2995 Gflops
M N K = 256 256 1024, Time = 0.00008909 0.00008970 0.00009114 s, AVG Performance = 1496.2557 Gflops
M N K = 384 384 1024, Time = 0.00016486 0.00017285 0.00017510 s, AVG Performance = 1747.1090 Gflops
M N K = 512 512 1024, Time = 0.00031949 0.00032122 0.00032973 s, AVG Performance = 1671.3371 Gflops
M N K = 768 768 1024, Time = 0.00063898 0.00065208 0.00065434 s, AVG Performance = 1852.4714 Gflops
M N K = 1024 1024 1024, Time = 0.00106189 0.00106394 0.00106598 s, AVG Performance = 2018.4210 Gflops
M N K = 1536 1536 1024, Time = 0.00281805 0.00282910 0.00284774 s, AVG Performance = 1707.9060 Gflops
M N K = 2048 2048 1024, Time = 0.00560026 0.00560722 0.00561152 s, AVG Performance = 1531.9420 Gflops
M N K = 3072 3072 1024, Time = 0.01216819 0.01321758 0.01334579 s, AVG Performance = 1462.2455 Gflops
M N K = 4096 4096 1024, Time = 0.02087629 0.02089912 0.02097050 s, AVG Performance = 1644.0758 Gflops
M N K = 6144 6144 1024, Time = 0.04926976 0.04949125 0.04952781 s, AVG Performance = 1562.0824 Gflops
M N K = 8192 8192 1024, Time = 0.08597197 0.08599675 0.08600986 s, AVG Performance = 1598.1878 Gflops
M N K = 12288 12288 1024, Time = 0.19243827 0.19409776 0.19477504 s, AVG Performance = 1593.2056 Gflops
M N K = 16384 16384 1024, Time = 0.34301132 0.34313369 0.34316391 s, AVG Performance = 1602.1622 Gflops
GPU 对 Global Memory(全局内存)的访问通过 32、64 或 128 字节的内存事务完成,最小粒度是 32-byte 的 sector。当同一个 warp 内各线程访问的地址连续且对齐时,硬件会把这些访问合并(coalesce)成尽可能少的事务。例如,一个 warp 读取 32 个连续的 float,正好是 128 字节,只需要 4 个 sector,从而提高带宽利用率。
// Body of the shared-memory tiled SGEMM kernel. The kernel signature and the declarations of
// tid, bx, by, tx, ty, s_a, s_b, r_c and BM/BN/BK/TM/TN are not part of this excerpt.
    // Row of the s_a tile this thread loads into, and the starting column (0 or 4).
    int load_a_smem_m = tid >> 1;
    int load_a_smem_k = (tid & 1) << 2;
    // Row of the s_b tile this thread loads into, and the starting column (multiple of 4).
    int load_b_smem_k = tid >> 5;
    int load_b_smem_n = (tid & 31) << 2;

    // Row of global A / column of global B that feed those tile positions for this block.
    int load_a_gmem_m = by * BM + load_a_smem_m;
    int load_b_gmem_n = bx * BN + load_b_smem_n;

    // NOTE(review): there are no bounds checks. This presumably assumes M % BM == 0,
    // N % BN == 0, K % BK == 0, TN % 4 == 0, and 16-byte alignment for the FLOAT4
    // (float4 reinterpret) accesses. Confirm at the launch site.
    int cnt = (K + BK - 1) / BK;
    // Loop condition must test bk; it previously tested an unrelated `i`.
    for (int bk = 0; bk < cnt; ++bk) {
        // Each thread copies one FLOAT4 of A and one FLOAT4 of B into the tiles.
        int load_a_gmem_k = bk * BK + load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        FLOAT4(s_a[load_a_smem_m][load_a_smem_k]) = FLOAT4(a[load_a_gmem_addr]);

        int load_b_gmem_k = bk * BK + load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr]);

        // The whole block must finish writing the tiles before anyone reads them.
        __syncthreads();

        // Accumulate this thread's TM x TN sub-tile of C from the current BK slice.
        #pragma unroll
        for (int k = 0; k < BK; ++k) {
            #pragma unroll
            for (int m = 0; m < TM; ++m) {
                #pragma unroll
                for (int n = 0; n < TN; ++n) {
                    int comp_a_smem_m = ty * TM + m;
                    int comp_b_smem_n = tx * TN + n;
                    r_c[m][n] += s_a[comp_a_smem_m][k] * s_b[k][comp_b_smem_n];
                }
            }
        }

        // Everyone must finish reading before the next iteration overwrites the tiles.
        __syncthreads();
    }

    // Write the TM x TN results back to C, four floats per store.
    #pragma unroll
    for (int i = 0; i < TM; ++i) {
        int store_c_gmem_m = by * BM + ty * TM + i;
        #pragma unroll
        for (int j = 0; j < TN; j += 4) {
            int store_c_gmem_n = bx * BN + tx * TN + j;
            int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
            FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][j]);
        }
    }
}
计算结果如下,性能达到了理论峰值的54.5%:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
M N K = 128 128 1024, Time = 0.00019149 0.00019769 0.00020950 s, AVG Performance = 169.7357 Gflops
M N K = 192 192 1024, Time = 0.00019354 0.00019711 0.00019968 s, AVG Performance = 383.0212 Gflops
M N K = 256 256 1024, Time = 0.00019866 0.00019927 0.00019968 s, AVG Performance = 673.5457 Gflops
M N K = 384 384 1024, Time = 0.00019763 0.00020429 0.00020685 s, AVG Performance = 1478.2557 Gflops
M N K = 512 512 1024, Time = 0.00020480 0.00020531 0.00020582 s, AVG Performance = 2614.9028 Gflops
M N K = 768 768 1024, Time = 0.00020579 0.00020695 0.00020787 s, AVG Performance = 5837.0423 Gflops
M N K = 1024 1024 1024, Time = 0.00020787 0.00020920 0.00020992 s, AVG Performance = 10265.0610 Gflops
M N K = 1536 1536 1024, Time = 0.00034406 0.00034447 0.00034509 s, AVG Performance = 14026.7298 Gflops
M N K = 2048 2048 1024, Time = 0.00065843 0.00065996 0.00066150 s, AVG Performance = 13015.7464 Gflops
M N K = 3072 3072 1024, Time = 0.00129741 0.00130109 0.00130253 s, AVG Performance = 14854.6887 Gflops
M N K = 4096 4096 1024, Time = 0.00208794 0.00208978 0.00209408 s, AVG Performance = 16441.8282 Gflops
M N K = 6144 6144 1024, Time = 0.00461210 0.00461885 0.00462234 s, AVG Performance = 16737.7892 Gflops
M N K = 8192 8192 1024, Time = 0.00686694 0.00707564 0.00795443 s, AVG Performance = 19424.2566 Gflops
M N K = 12288 12288 1024, Time = 0.01572765 0.01590118 0.01641571 s, AVG Performance = 19447.4602 Gflops
M N K = 16384 16384 1024, Time = 0.02786099 0.02943130 0.03024384 s, AVG Performance = 18679.2936 Gflops
共享内存被划分为 32 个 bank。如果同一个 warp 内某些线程要读取的数据位于同一个 bank,并且是该 bank 上的不同地址(即不同的 layer),就发生了 bank conflict。如果这些线程访问的是同一个地址,数据会被广播,不产生冲突。同一个 warp 内的线程访问同一个 bank 下的 n 个不同地址时,就发生了 n-way bank conflict(n 路 bank conflict):本该一次完成的访问会被拆分成 n 次串行访问,有效带宽降为原来的 1/n。
// Body of the register-tiled SGEMM kernel (signature and the declarations of tid, bx, by,
// tx, ty, s_a, s_b, r_load_a, r_load_b, r_comp_a, r_comp_b, r_c are not part of this excerpt).
// NOTE(review): statements here are collapsed onto single lines. A `#pragma unroll` followed
// by code on the same line turns the rest of that line into part of the pragma directive, so
// the loops below would not compile as written; restore one statement/directive per line.
// Per-thread tile coordinates: A rows (tid >> 1) with column offset 0 or 4; B rows (tid >> 5)
// with column offset a multiple of 4.
int load_a_smem_m = tid >> 1; int load_a_smem_k = (tid & 1) << 2; int load_b_smem_k = tid >> 5; int load_b_smem_n = (tid & 31) << 2;
// Matching global row of A / global column of B for this block.
int load_a_gmem_m = by * BM + load_a_smem_m; int load_b_gmem_n = bx * BN + load_b_smem_n;
// K loop: each iteration fetches one FLOAT4 of A and one FLOAT4 of B from global memory into
// the registers r_load_a / r_load_b (no bounds checks: presumably K % BK == 0 etc. — confirm).
// NOTE(review): r_load_a / r_load_b are loaded but never consumed in this excerpt, and
// r_comp_a / r_comp_b are consumed but never filled. The code that writes the loaded values
// into s_a / s_b, the barrier after it, and the inner loop over the tile's k dimension that
// loads r_comp_a / r_comp_b appear to be elided between the next two lines. The second `}`
// on the accumulation line closes that missing loop. TODO: restore from the original source.
int cnt = (K + BK - 1) / BK; for(int bk = 0; bk < cnt; ++bk) { int load_a_gmem_k = bk * BK + load_a_smem_k; int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K); int load_b_gmem_k = bk * BK + load_b_smem_k; int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N); FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]); FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
// Outer-product update of the per-thread TM x TN accumulator r_c.
#pragma unroll for(int m = 0; m < TM; ++m) { #pragma unroll for(int n = 0; n < TN; ++n) r_c[m][n] += r_comp_a[m] * r_comp_b[n]; } }
// Barrier before the next K step reuses the tiles; closes the bk loop.
__syncthreads(); }
// Write back C. Each thread stores to two row bands (offset ty * TM / 2 and + BM / 2) and two
// column bands (offset tx * TN / 2 and + BN / 2), using r_c[.][0..3] and r_c[.][4..7].
// This presumably assumes TN == 8 (two FLOAT4 halves per row) — confirm against the declarations.
#pragma unroll for(int i = 0; i < TM / 2; ++i) { int store_c_gmem_m = by * BM + ty * TM / 2 + i; int store_c_gmem_n = bx * BN + tx * TN / 2; int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N); FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]); FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]); } #pragma unroll for(int i = 0; i < TM / 2; ++i) { int store_c_gmem_m = by * BM + ty * TM / 2 + i + BM / 2; int store_c_gmem_n = bx * BN + tx * TN / 2; int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N); FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]); FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]); }
}
计算结果如下,性能达到了理论峰值的68.6%:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
M N K = 128 128 1024, Time = 0.00017306 0.00017603 0.00017715 s, AVG Performance = 190.6225 Gflops
M N K = 192 192 1024, Time = 0.00014438 0.00015247 0.00016486 s, AVG Performance = 495.1511 Gflops
M N K = 256 256 1024, Time = 0.00014438 0.00014479 0.00014541 s, AVG Performance = 926.9590 Gflops
M N K = 384 384 1024, Time = 0.00014541 0.00015677 0.00016179 s, AVG Performance = 1926.3098 Gflops
M N K = 512 512 1024, Time = 0.00015258 0.00015421 0.00015872 s, AVG Performance = 3481.3280 Gflops
M N K = 768 768 1024, Time = 0.00015462 0.00015954 0.00016077 s, AVG Performance = 7571.5533 Gflops
M N K = 1024 1024 1024, Time = 0.00015770 0.00016025 0.00016278 s, AVG Performance = 13400.6000 Gflops
M N K = 1536 1536 1024, Time = 0.00024166 0.00024380 0.00024678 s, AVG Performance = 19819.2506 Gflops
M N K = 2048 2048 1024, Time = 0.00046899 0.00047144 0.00047411 s, AVG Performance = 18220.6316 Gflops
M N K = 3072 3072 1024, Time = 0.00092467 0.00093071 0.00093798 s, AVG Performance = 20766.1658 Gflops
M N K = 4096 4096 1024, Time = 0.00149402 0.00150313 0.00152064 s, AVG Performance = 22858.7997 Gflops
M N K = 6144 6144 1024, Time = 0.00325222 0.00327711 0.00329728 s, AVG Performance = 23590.7485 Gflops
M N K = 8192 8192 1024, Time = 0.00568320 0.00570696 0.00574771 s, AVG Performance = 24082.7044 Gflops
M N K = 12288 12288 1024, Time = 0.01249382 0.01263268 0.01291366 s, AVG Performance = 24479.1828 Gflops
M N K = 16384 16384 1024, Time = 0.02226381 0.02418770 0.02583859 s, AVG Performance = 22728.7352 Gflops
// Main loop of the double-buffered SGEMM kernel (the signature, the tile/register
// declarations, and load_a_smem_k / load_b_smem_k / load_a_gmem_m / load_b_gmem_n are not part
// of this excerpt).
// NOTE(review): statements here are collapsed onto single lines. A `#pragma unroll` followed
// by code on the same line turns the rest of that line into part of the pragma directive, so
// the loops below would not compile as written; restore one statement/directive per line.
// The loop starts at bk = 1, so tile 0 is presumably loaded before the loop (not shown).
// smem_sel alternates 0/1 and selects the buffer for tile bk - 1; smem_sel_next selects
// the other buffer for tile bk. Neither is used in the visible code — the statements using
// them appear to be elided.
int cnt = (K + BK - 1) / BK; for(int bk = 1; bk < cnt; ++bk) { int smem_sel = (bk - 1) & 1; int smem_sel_next = bk & 1;
// Prefetch the next tile (index bk) from global memory into registers r_load_a / r_load_b.
// NOTE(review): the computation on buffer smem_sel (inner k loop that fills r_comp_a /
// r_comp_b), the stores of r_load_a / r_load_b into buffer smem_sel_next, and the barrier
// are elided here. No __syncthreads() is visible in this excerpt. TODO: restore from the
// original source.
int load_a_gmem_k = bk * BK + load_a_smem_k; int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K); int load_b_gmem_k = bk * BK + load_b_smem_k; int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N); FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]); FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
// Outer-product update of the per-thread TM x TN accumulator r_c. The second `}` closes an
// enclosing loop whose opening line is elided.
#pragma unroll for(int m = 0; m < TM; ++m) { #pragma unroll for(int n = 0; n < TN; ++n) r_c[m][n] += r_comp_a[m] * r_comp_b[n]; } }
// Write back the upper row band of C (rows by * BM + ty * TM / 2 + i). Two FLOAT4 stores go
// BN / 2 columns apart, taken from r_c[i][0..3] and r_c[i][4..7].
#pragma unroll for(int i = 0; i < TM / 2; ++i) { int store_c_gmem_m = by * BM + ty * TM / 2 + i; int store_c_gmem_n = bx * BN + tx * TN / 2; int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N); FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]); FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]); }
// Write back the lower row band of C (BM / 2 rows further down), taken from r_c[TM / 2 + i].
#pragma unroll for(int i = 0; i < TM / 2; ++i) { int store_c_gmem_m = by * BM + ty * TM / 2 + i + BM / 2; int store_c_gmem_n = bx * BN + tx * TN / 2; int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N); FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]); FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]); }
}
计算结果如下,性能达到了理论峰值的75.7%:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
M N K = 128 128 1024, Time = 0.00011162 0.00011725 0.00012186 s, AVG Performance = 286.1834 Gflops
M N K = 192 192 1024, Time = 0.00011059 0.00011192 0.00011264 s, AVG Performance = 674.5471 Gflops
M N K = 256 256 1024, Time = 0.00011155 0.00011243 0.00011469 s, AVG Performance = 1193.8020 Gflops
M N K = 384 384 1024, Time = 0.00010957 0.00011162 0.00011366 s, AVG Performance = 2705.6147 Gflops
M N K = 512 512 1024, Time = 0.00010854 0.00010957 0.00011469 s, AVG Performance = 4899.8878 Gflops
M N K = 768 768 1024, Time = 0.00010957 0.00011038 0.00011162 s, AVG Performance = 10943.2485 Gflops
M N K = 1024 1024 1024, Time = 0.00010957 0.00011132 0.00011776 s, AVG Performance = 19291.3629 Gflops
M N K = 1536 1536 1024, Time = 0.00021299 0.00021944 0.00022221 s, AVG Performance = 22019.2705 Gflops
M N K = 2048 2048 1024, Time = 0.00043715 0.00044431 0.00045056 s, AVG Performance = 19333.1840 Gflops
M N K = 3072 3072 1024, Time = 0.00090624 0.00091166 0.00091648 s, AVG Performance = 21200.2324 Gflops
M N K = 4096 4096 1024, Time = 0.00149094 0.00151910 0.00167629 s, AVG Performance = 22618.4236 Gflops
M N K = 6144 6144 1024, Time = 0.00334234 0.00342294 0.00378880 s, AVG Performance = 22585.6909 Gflops
M N K = 8192 8192 1024, Time = 0.00587981 0.00618916 0.00649728 s, AVG Performance = 22206.4044 Gflops
M N K = 12288 12288 1024, Time = 0.01135514 0.01144668 0.01189990 s, AVG Performance = 27015.4839 Gflops
M N K = 16384 16384 1024, Time = 0.02017075 0.02207642 0.02387558 s, AVG Performance = 24902.4032 Gflops
4. cuBLAS 性能
性能达到了理论峰值的83.9%:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
M N K = 128 128 1024, Time = 0.00001843 0.00035505 0.00335494 s, AVG Performance = 94.5063 Gflops
M N K = 192 192 1024, Time = 0.00002253 0.00026675 0.00245965 s, AVG Performance = 283.0250 Gflops
M N K = 256 256 1024, Time = 0.00002662 0.00003389 0.00008397 s, AVG Performance = 3959.8791 Gflops
M N K = 384 384 1024, Time = 0.00003174 0.00003840 0.00005530 s, AVG Performance = 7864.9753 Gflops
M N K = 512 512 1024, Time = 0.00005120 0.00005386 0.00006554 s, AVG Performance = 9967.4525 Gflops
M N K = 768 768 1024, Time = 0.00008602 0.00009175 0.00013824 s, AVG Performance = 13165.7144 Gflops
M N K = 1024 1024 1024, Time = 0.00012698 0.00013035 0.00014438 s, AVG Performance = 16474.4968 Gflops
M N K = 1536 1536 1024, Time = 0.00022211 0.00023521 0.00034304 s, AVG Performance = 20542.9712 Gflops
M N K = 2048 2048 1024, Time = 0.00042496 0.00042792 0.00044544 s, AVG Performance = 20073.5427 Gflops
M N K = 3072 3072 1024, Time = 0.00084992 0.00085535 0.00088269 s, AVG Performance = 22595.9154 Gflops
M N K = 4096 4096 1024, Time = 0.00146125 0.00146872 0.00149606 s, AVG Performance = 23394.2915 Gflops
M N K = 6144 6144 1024, Time = 0.00306682 0.00307189 0.00308838 s, AVG Performance = 25166.7152 Gflops
M N K = 8192 8192 1024, Time = 0.00457830 0.00470508 0.00521523 s, AVG Performance = 29210.7880 Gflops
M N K = 12288 12288 1024, Time = 0.01022771 0.01032509 0.01057178 s, AVG Performance = 29950.1034 Gflops
M N K = 16384 16384 1024, Time = 0.01822106 0.01895342 0.01967821 s, AVG Performance = 29005.6252 Gflops