cuda_learning_03

1. The GPU memory hierarchy

2. Understanding the GPU memory hierarchy through the reduction operation

Problem

  Compute A * B = C

  [M, K] * [K, N] = [M, N]
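
  As a concrete reference (my own sketch, not from the source), the same computation on the CPU is just a triple loop over m, n, and k; every kernel below should reproduce this result, so it doubles as a correctness check:

// straightforward CPU reference: C[m][n] = sum_k A[m][k] * B[k][n]
// (row-major layout, same convention as the OFFSET macro used by the kernels below)
void cpuSgemm(const float *a, const float *b, float *c, int M, int K, int N)
{
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
        {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k) sum += a[m * K + k] * b[k * N + n];
            c[m * N + n] = sum;
        }
}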

naiveSgemm

  The most naive GEMM: it only exploits the GPU's thread-level parallelism, with each thread computing one element of C by reading A and B directly from global memory. Source: (二、naive gemm)

#define OFFSET(row, col, ld) ((row) * (ld) + (col))

// __restrict__ promises the compiler that the memory reached through this pointer
// does not overlap with memory reached through the other pointers; in short, it is an optimization hint
__global__ void naiveSgemm(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
                           const int M, const int K, const int N)
{
    int tidx = blockIdx.x * blockDim.x + threadIdx.x;   // global column of C
    int tidy = blockIdx.y * blockDim.y + threadIdx.y;   // global row of C

    if (tidx < N && tidy < M)
    {
        float sum = 0.0f;
        #pragma unroll
        for (int k = 0; k < K; ++k) sum += a[OFFSET(tidy, k, K)] * b[OFFSET(k, tidx, N)];
        c[OFFSET(tidy, tidx, N)] = sum;
    }
}
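
  The source does not show how this kernel is launched; a minimal host-side sketch, assuming M = N = K = 512, a 32 x 32 thread block, and device pointers d_a, d_b, d_c that have already been allocated and filled:

// hypothetical launch for naiveSgemm: one thread per element of C
const int M = 512, K = 512, N = 512;
dim3 blockDim(32, 32);                                   // x -> columns of C, y -> rows of C
dim3 gridDim((N + blockDim.x - 1) / blockDim.x,
             (M + blockDim.y - 1) / blockDim.y);
naiveSgemm<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, K, N);
cudaDeviceSynchronize();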

smemSgemm(ver1)

  Use shared memory to speed things up: each block stages tiles of A and B in shared memory so the threads of the block can reuse them. Source: (Shared Memory 优化)

const int BM = 32, BN = 32;
const int M = 512, N = 512, K = 512;
dim3 blockDim(BN, BM);
dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);
const int block_x = 32;
const int block_y = 32;

__global__ void smemSgemm(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
                          const int M, const int K, const int N)
{
    int tx = threadIdx.x, ty = threadIdx.y;          // position inside the block / tile
    int tidx = blockIdx.x * blockDim.x + tx;         // global column of C
    int tidy = blockIdx.y * blockDim.y + ty;         // global row of C
    float sum = 0.0f;

    __shared__ float smema[block_y][block_x];        // tile of A: block_y rows x block_x columns along k
    __shared__ float smemb[block_y][block_x];        // tile of B: block_x rows along k x block_x columns

    int cnt = (K + block_x - 1) / block_x;           // number of tiles along the k dimension
    for (int i = 0; i < cnt; ++i)
    {
        // each thread loads one element of A and one of B into shared memory
        smema[ty][tx] = a[OFFSET(tidy, i * block_x + tx, K)];
        smemb[ty][tx] = b[OFFSET(i * block_x + ty, tidx, N)];

        __syncthreads();

        for (int j = 0; j < block_x; ++j) sum += smema[ty][j] * smemb[j][tx];

        __syncthreads();
    }

    c[OFFSET(tidy, tidx, N)] = sum;
}
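
  A rough back-of-the-envelope estimate (mine, not from the source) of why the tiles help: in the naive kernel each thread reads $2K$ floats from global memory, so a $32 \times 32$ block reads $2K \cdot 1024$ floats in total. With the shared-memory tiles, the same block loads only $2 \cdot 32 \cdot K$ floats across all $K / 32$ iterations, i.e. each element of A and B that the block needs is fetched from global memory once instead of 32 times; the reuse now happens in shared memory, cutting global traffic by a factor of 32.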

regSgemm(ver1)

  Register optimization. Source: (Register 优化)

  Although it is called a register optimization, it really improves the matrix-multiply performance through shared memory plus instruction-level parallelism (ILP). Compared with the version above, the only change is that each thread now computes two elements of C. In the example above, [block_x, block_y] = [32, 32] and the shared-memory tiles are [32, 32]; in this example the thread-block shape changes to [32, 16] while the shared-memory tile size stays the same.

const int block_x = 32;
const int block_y = 16;

// each block still computes a 32x32 tile of C, but with only 32x16 threads,
// so every thread produces two elements of C (rows ty and ty + 16 of the tile)
__global__ void regSgemm(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
                         const int M, const int K, const int N)
{
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int tidx = blockIdx.x * blockDim.x + tx;         // global column of C
    int tidy = blockIdx.y * (blockDim.y * 2) + ty;   // global row of C; each block covers 2 * block_y = 32 rows
    float val[2] = {0.0f, 0.0f};                     // two independent accumulators -> more ILP

    __shared__ float smema[block_x][block_x];
    __shared__ float smemb[block_x][block_x];

    int cnt = (K + block_x - 1) / block_x;
    for (int i = 0; i < cnt; ++i)
    {
        // each thread loads two elements of A and two of B into the 32x32 tiles
        smema[ty][tx]      = a[OFFSET(tidy,      i * block_x + tx, K)];
        smema[ty + 16][tx] = a[OFFSET(tidy + 16, i * block_x + tx, K)];

        smemb[ty][tx]      = b[OFFSET(i * block_x + ty,      tidx, N)];
        smemb[ty + 16][tx] = b[OFFSET(i * block_x + ty + 16, tidx, N)];

        __syncthreads();

        for (int j = 0; j < block_x; ++j)
        {
            val[0] += smema[ty][j]      * smemb[j][tx];
            val[1] += smema[ty + 16][j] * smemb[j][tx];
        }

        __syncthreads();
    }

    c[OFFSET(tidy,      tidx, N)] = val[0];
    c[OFFSET(tidy + 16, tidx, N)] = val[1];
}
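
  The launch changes accordingly (again my own hedged sketch, with the same assumed device pointers d_a, d_b, d_c and M = N = K = 512): the grid still covers 32 x 32 tiles of C, but each block now has only 32 x 16 threads:

// hypothetical launch for regSgemm: each block owns a 32x32 tile of C,
// but has 32x16 threads, so every thread produces two elements
dim3 blockDim(block_x, block_y);                         // (32, 16)
dim3 gridDim((N + block_x - 1) / block_x,
             (M + 2 * block_y - 1) / (2 * block_y));     // tile height is 2 * block_y = 32
regSgemm<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, K, N);
cudaDeviceSynchronize();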

regSgemm(ver2)

  Register optimization. Source: (GEMM再优化:从SMEM到register)


  When tiling, I think BM and BN are fixed first, then BK; also, BK does not have to equal TM or TN. BK mainly determines the shared-memory footprint, since the A tile is BM * BK. Once BM = BN = 128 is chosen, the grid needs (M / BM) * (N / BN) thread blocks. Next comes the thread-block size: one block computes 128 * 128 elements of C and a block may not exceed 1024 threads, so 16 * 16 is a good choice, which means each thread computes TM * TN = (128 / 16) * (128 / 16) elements.
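
  Plugging in M = N = K = 512 just to make the numbers concrete (my own arithmetic): the grid is $(512 / 128) \cdot (512 / 128) = 16$ blocks, each block has $16 \cdot 16 = 256$ threads, each thread computes $8 \cdot 8 = 64$ elements of C, and the two tiles occupy $(128 \cdot 8 + 8 \cdot 128) \cdot 4\,\text{B} = 8\,\text{KB}$ of shared memory per block.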

  As for each thread's work: since we split along the k dimension, the outer loop runs $\lceil K / BK \rceil$ times. In every iteration, the data needed for that step is first loaded from global memory into shared memory; the A tile is BM * BK and the block has 16 * 16 threads, so each thread loads BM * BK / (16 * 16) = 4 elements (and likewise 4 for the BK * BN tile of B), and we only have to assign the indices. Each thread then computes its TM * TN elements by walking along the k dimension (the actual matrix multiply), accumulating the results in registers. When the loop finishes, the results held in registers are written back to global memory.
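
  A hedged sketch of the corresponding launch (not shown in the source; d_a, d_b, d_c and M = N = K = 512 are assumed as before, and the kernel below keeps the name smemSgemm used in the source):

// hypothetical launch for the register-tiled kernel below (BM = BN = 128, BK = 8, TM = TN = 8)
const int BM = 128, BN = 128;
dim3 blockDim(16, 16);                               // 256 threads, each owning an 8x8 sub-tile of C
dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);  // one block per 128x128 tile of C
smemSgemm<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, K, N);
cudaDeviceSynchronize();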

#define OFFSET(row, col, ld) ((row) * (ld) + (col))
// reinterpret 4 consecutive floats as one float4 so a single instruction moves 128 bits
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])

__global__ void smemSgemm(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
                          const int M, const int K, const int N)
{
    const int BM = 128;
    const int BN = 128;
    const int BK = 8;
    const int TM = 8;
    const int TN = 8;

    const int bidx = blockIdx.x;
    const int bidy = blockIdx.y;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int tid = ty * blockDim.x + tx;

    __shared__ float s_a[BM][BK];
    __shared__ float s_b[BK][BN];

    float r_c[TM][TN] = {0.0f};

    // Where the 4 floats each thread fetches from global memory land in shared memory:
    // s_a[load_a_smem_m][load_a_smem_k .. +3] and s_b[load_b_smem_k][load_b_smem_n .. +3].
    // s_a is 128 * 8 and s_b is 8 * 128; with 16 * 16 = 256 threads, each thread loads 4 floats of each.
    int load_a_smem_m = tid >> 1;          // tid / 2
    int load_a_smem_k = (tid & 1) << 2;    // tid % 2 == 0 ? 0 : 4

    int load_b_smem_k = tid >> 5;          // tid / 32
    int load_b_smem_n = (tid & 31) << 2;   // (tid % 32) * 4

    // Which global positions to read from: the k coordinates are derived from load_a_smem_k / load_b_smem_k
    // inside the loop, so only load_a_gmem_m and load_b_gmem_n need to be computed here.
    int load_a_gmem_m = bidy * BM + load_a_smem_m;
    int load_b_gmem_n = bidx * BN + load_b_smem_n;
    int load_a_gmem_k, load_b_gmem_k;
    int load_a_gmem_addr, load_b_gmem_addr;

    // number of tiles along the k dimension
    int cnt = (K + BK - 1) / BK;
    for (int i = 0; i < cnt; ++i)
    {
        // load this iteration's tiles from global memory into shared memory, 4 floats per thread per matrix
        load_a_gmem_k = i * BK + load_a_smem_k;
        load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        FLOAT4(s_a[load_a_smem_m][load_a_smem_k]) = FLOAT4(a[load_a_gmem_addr]);

        load_b_gmem_k = i * BK + load_b_smem_k;
        load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr]);

        __syncthreads();

        // compute: each thread accumulates its TM * TN elements
        #pragma unroll
        for (int k = 0; k < BK; ++k)
        {
            #pragma unroll
            for (int m = 0; m < TM; ++m)
            {
                #pragma unroll
                for (int n = 0; n < TN; ++n)
                {
                    int comp_a_smem_m = ty * TM + m;
                    int comp_b_smem_n = tx * TN + n;
                    r_c[m][n] += s_a[comp_a_smem_m][k] * s_b[k][comp_b_smem_n];
                }
            }
        }

        __syncthreads();
    }

    // write the accumulated registers back to global memory, 4 floats at a time
    #pragma unroll
    for (int i = 0; i < TM; ++i)
    {
        int store_c_gmem_m = bidy * BM + ty * TM + i;
        #pragma unroll
        for (int j = 0; j < TN; j += 4)
        {
            int store_c_gmem_n = bidx * BN + tx * TN + j;
            int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
            FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][j]);
        }
    }
}
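
  To close the loop, a minimal and entirely hypothetical host driver that runs the tiled kernel above and spot-checks it against the cpuSgemm reference sketched at the top of this note (it assumes cpuSgemm and only this version of smemSgemm are in the same file; error handling is omitted):

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <cmath>

int main()
{
    const int M = 512, K = 512, N = 512;
    const int BM = 128, BN = 128;

    // host buffers with random data, plus a CPU reference result
    float *h_a = (float *)malloc(M * K * sizeof(float));
    float *h_b = (float *)malloc(K * N * sizeof(float));
    float *h_c = (float *)malloc(M * N * sizeof(float));
    float *h_ref = (float *)malloc(M * N * sizeof(float));
    for (int i = 0; i < M * K; ++i) h_a[i] = rand() / (float)RAND_MAX;
    for (int i = 0; i < K * N; ++i) h_b[i] = rand() / (float)RAND_MAX;
    cpuSgemm(h_a, h_b, h_ref, M, K, N);

    // device buffers
    float *d_a, *d_b, *d_c;
    cudaMalloc((void **)&d_a, M * K * sizeof(float));
    cudaMalloc((void **)&d_b, K * N * sizeof(float));
    cudaMalloc((void **)&d_c, M * N * sizeof(float));
    cudaMemcpy(d_a, h_a, M * K * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_b, h_b, K * N * sizeof(float), cudaMemcpyHostToDevice);

    // launch the register-tiled kernel and copy C back (the memcpy synchronizes)
    dim3 blockDim(16, 16);
    dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);
    smemSgemm<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, K, N);
    cudaMemcpy(h_c, d_c, M * N * sizeof(float), cudaMemcpyDeviceToHost);

    // compare against the CPU reference
    float max_err = 0.0f;
    for (int i = 0; i < M * N; ++i)
    {
        float err = fabsf(h_c[i] - h_ref[i]);
        if (err > max_err) max_err = err;
    }
    printf("max abs error: %f\n", max_err);

    cudaFree(d_a); cudaFree(d_b); cudaFree(d_c);
    free(h_a); free(h_b); free(h_c); free(h_ref);
    return 0;
}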