CUDA-Operators-4-Softmax

1. Basic Softmax Implementation

Softmax normalizes a vector of values into a probability distribution whose entries sum to 1. It is commonly used as the last layer of a neural network to produce the output of a multi-class classifier.

In the raw Softmax, $\sum e^{x_i}$ easily overflows numerically, so in practice the Safe Softmax is used: subtracting $\max(x)$ from each $x_i$ prevents overflow. The formulas are shown below.

\[m = \max(x)\] \[Softmax(x)_i = \frac{e^{x_i-m}}{\sum_j e^{x_j-m}}\]

Safe Softmax requires three loops, the key ones being the two reductions (max and sum). Completing the computation takes three reads of the input and one write of the output; the algorithm is shown below.

Algorithm 1: Safe Softmax
1  $m \gets -\infty$
2  for $i \gets 1$ to $N$ do
3      $m \gets \max(m, x_i)$
4  end for
5  $s \gets 0$
6  for $i \gets 1$ to $N$ do
7      $s \gets s + e^{x_i - m}$
8  end for
9  for $i \gets 1$ to $N$ do
10     $y_i \gets \dfrac{e^{x_i - m}}{s}$
11 end for
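A minimal CPU reference of Algorithm 1, handy as a baseline for verifying the GPU kernels below (the function name is an assumption, not from the original post):

#include <math.h>

// Hypothetical CPU baseline: for each of the N rows, pass 1 finds the max,
// pass 2 accumulates exp(x - max), pass 3 normalizes.
void softmax_forward_cpu(float* out, const float* inp, int N, int C) {
  for (int n = 0; n < N; n++) {
    const float* x = inp + n * C;
    float* y = out + n * C;
    float m = -INFINITY;
    for (int j = 0; j < C; j++) m = fmaxf(m, x[j]);
    float s = 0.0f;
    for (int j = 0; j < C; j++) {
      y[j] = expf(x[j] - m);
      s += y[j];
    }
    for (int j = 0; j < C; j++) y[j] /= s;
  }
}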

1.1 Naive GPU Implementation (V1)

Since the normalization step depends on the reduction results, the most naive implementation lets one thread compute one row of the result, so one block computes BLOCK_SIZE rows. The code is shown below.

/**
 * Every block calculates `BLOCK_SIZE` lines of the result; one thread calculates one line
 * inp is (N, C)
 * out is (N, C)
 * BLOCK_SIZE = 32
 * dim3 block(BLOCK_SIZE)
 * dim3 grid((N + BLOCK_SIZE -1) / BLOCK_SIZE)
 */
__global__ void softmax_forward_naive_f32_kernel(
  float* out, const float* inp, int N, int C
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) {
        const float* inp_row = inp + i * C;
        float* out_row = out + i * C;

        float maxval = -INFINITY;
        for (int j = 0; j < C; j++) {
            if (inp_row[j] > maxval) {
                maxval = inp_row[j];
            }
        }
        double sum = 0.0;
        for (int j = 0; j < C; j++) {
            out_row[j] = expf(inp_row[j] - maxval);
            sum += out_row[j];
        }
        for (int j = 0; j < C; j++) {
            out_row[j] /= (float)sum;
        }
    }
}
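A minimal host-side launch sketch for this kernel (the device pointers d_out/d_inp and the error check are assumptions for illustration):

// Hypothetical launch of the V1 kernel: one thread per row, BLOCK_SIZE rows per block.
const int BLOCK_SIZE = 32;
dim3 block(BLOCK_SIZE);
dim3 grid((N + BLOCK_SIZE - 1) / BLOCK_SIZE);
softmax_forward_naive_f32_kernel<<<grid, block>>>(d_out, d_inp, N, C);
cudaError_t err = cudaGetLastError();  // check the launch configuration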

1.2 Shared Memory Implementation (V2)

The V1 implementation is too naive: each thread computes an entire row, so (1) a single thread carries a heavy workload and parallelism is insufficient; (2) all three passes read from global memory, so memory access efficiency is low.

In the optimized version below, each block computes one row and shared memory caches the reduction results. The code is shown below.

/**
 * Every block calculates one line.
 * inp is (N, C)
 * 
 * BLOCK_SIZE = 512
 * dim3 grid(N)
 * dim3 block(BLOCK_SIZE)
 * size_t smem_size = BLOCK_SIZE * sizeof(float)
 * softmax_forward_smem_f32_kernel<<<grid, block, smem_size>>>(out, inp, N, C);
 */
__global__ void softmax_forward_smem_f32_kernel(
  float* out, const float* inp, int N, int C
){
  extern __shared__ float shared[];
  int idx = blockIdx.x;   // ranges [0, N)
  int tid = threadIdx.x;  // ranges [0, block_size)
  int block_size = blockDim.x;
  const float* x = inp + idx * C;

  float maxval = -INFINITY;
  for (int i = tid; i < C; i += block_size) {
      maxval = fmaxf(maxval, x[i]);
  }
  shared[tid] = maxval;
  __syncthreads();

  // max reductions
  for (int stride = block_size / 2; stride >= 1; stride /= 2) {
      if (tid < stride) {
          shared[tid] = fmaxf(shared[tid], shared[tid + stride]);
      }
      __syncthreads();
  }

  float offset = shared[0];
  float sumval = 0.0f;
  for (int i = tid; i < C; i += block_size) {
      sumval += expf(x[i] - offset);
  }
  shared[tid] = sumval;
  __syncthreads();
  
  // sum reduction
  for (int stride = block_size / 2; stride >= 1; stride /= 2) {
      if (tid < stride) {
          shared[tid] += shared[tid + stride];
      }
      __syncthreads();
  }

  // broadcast the sum to all threads in the block
  float sum = shared[0];
  for (int i = tid; i < C; i += block_size) {
      out[idx * C + i] = expf(x[i] - offset) / sum;
  }
}
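One caveat: the tree reductions above halve stride starting from block_size / 2, so they are only correct when the block size is a power of two. A minimal host-side launch wrapper sketch that makes this assumption explicit (the wrapper name and the assert are illustrative, not from the original post):

#include <cassert>

// Hypothetical launcher for the V2 kernel; the tree reduction assumes a power-of-two block size.
void launch_softmax_forward_smem_f32(float* d_out, const float* d_inp, int N, int C,
                                     int block_size = 512) {
  assert(block_size > 0 && (block_size & (block_size - 1)) == 0);
  dim3 grid(N);                                    // one block per row
  dim3 block(block_size);
  size_t smem_size = block_size * sizeof(float);   // one float per thread for the reductions
  softmax_forward_smem_f32_kernel<<<grid, block, smem_size>>>(d_out, d_inp, N, C);
}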

1.3 Shared Memory + Warp Reduce Optimization (V3)

The reductions can be implemented with warp reduce, which operates directly in registers. The tiling is grid((N + BN - 1) / BN), block(BC, BN): each block computes BN rows of the result, and each thread handles (C + BC - 1) / BC elements. The thread-level reduction flow is listed below; the block-level reduction follows the same pattern: reduce within each warp, write the per-warp results to shared memory, then broadcast from shared memory.

  • Each thread first reduces its (C + BC - 1) / BC elements, i.e. the row is pre-reduced into threads 0…BC-1;
  • Each warp among threads 0…BC-1 performs a warp reduce, and each warp's result is cached in shared memory;
  • Thread 0 then reduces the per-warp results and writes the final value into its shared-memory slot;
  • Whenever the value is needed, threads read it back from shared memory, i.e. the shared-memory broadcast step (see the standalone shuffle sketch after Fig 1).

Fig 1. Warp reduction example
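Before the full kernel, here is a minimal standalone sketch of the per-warp step using shuffle intrinsics (the helper names are illustrative; the kernel below inlines the same shuffles instead of calling helpers, and uses __shfl_down_sync for the max so only lane 0 holds that result):

// XOR (butterfly) warp reduction: after log2(32) = 5 steps every lane holds the warp-wide result.
__device__ __forceinline__ float warp_reduce_max_f32(float val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset));
  }
  return val;
}

__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, offset);
  }
  return val;
}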

The full kernel code is shown below.

#define WarpSize 32
/**
 * Every block calculates BN lines
 * BN = 4
 * BC = 256 -> a multiple of 32
 * inp(N, C)
 * grid((N + BN - 1) / BN)
 * block(BC, BN)
 */
template <const int BN, const int BC>
__global__ void softmax_forward_warp_smem_f32_kernel(
  float* out, const float* inp, int N, int C
){
  // shared memory is used for inter-warp reduction
  __shared__ float smem[BN][2][BC];
  int tx = threadIdx.x;
  int ty = threadIdx.y;
  int bx = blockIdx.x;

  int n = bx * BN + ty;
  int lane_id = tx % WarpSize;
  int warp_id = tx / WarpSize;

  if (n >= N) return;

  const int niter = BC / WarpSize;
  const float* gmem = inp + n * C;

  // 1.max reduction
  float max = -INFINITY;
  for (int i = tx; i < C; i += BC){
    max = fmaxf(max, gmem[i]);
  }
  
  max = fmaxf(max, __shfl_down_sync(0xffffffff, max, 16));
  max = fmaxf(max, __shfl_down_sync(0xffffffff, max, 8));
  max = fmaxf(max, __shfl_down_sync(0xffffffff, max, 4));
  max = fmaxf(max, __shfl_down_sync(0xffffffff, max, 2));
  max = fmaxf(max, __shfl_down_sync(0xffffffff, max, 1));

  if (lane_id == 0) smem[ty][0][warp_id] = max;
  __syncthreads();

  if (tx == 0) {
    float maxval = smem[ty][0][0];
    for (int i = 1; i < niter; ++i){
      maxval = fmaxf(maxval, smem[ty][0][i]);
    }
    smem[ty][0][0] = maxval;
  }
  __syncthreads();

  // 2.sum reduction
  float offset = smem[ty][0][0];  // broadcast maxval
  float sum = 0;
  for (int i = tx; i < C; i += BC){
    sum += expf(gmem[i] - offset);
  }

  sum += __shfl_xor_sync(0xffffffff, sum, 16);
  sum += __shfl_xor_sync(0xffffffff, sum, 8);
  sum += __shfl_xor_sync(0xffffffff, sum, 4);
  sum += __shfl_xor_sync(0xffffffff, sum, 2);
  sum += __shfl_xor_sync(0xffffffff, sum, 1);

  if (lane_id == 0) smem[ty][1][warp_id] = sum;
  __syncthreads();
  if (tx == 0){
    float sumval = smem[ty][1][0];
    for (int i = 1; i < niter; ++i){
      sumval += smem[ty][1][i];
    }
    smem[ty][1][0] = sumval;
  }
  __syncthreads();

  // 3.norm
  float sumval = smem[ty][1][0];
  for (int i = tx; i < C; i += BC){
    out[n * C + i] = expf(gmem[i] - offset) / sumval;
  }
}
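A minimal launch sketch for the templated V3 kernel, using the BN/BC values from the header comment (the device pointers are assumptions):

// Hypothetical launch: each block computes BN = 4 rows with BC = 256 threads per row.
constexpr int BN = 4;
constexpr int BC = 256;
dim3 block(BC, BN);
dim3 grid((N + BN - 1) / BN);
softmax_forward_warp_smem_f32_kernel<BN, BC><<<grid, block>>>(d_out, d_inp, N, C);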

2. Advanced Softmax Optimizations

Online softmax [4] computes the statistics in a single accumulating pass, eliminating one loop while remaining equivalent to safe softmax. The core idea exploits properties of the exponential: fold the step-i max into the step i-1 exp-sum by rescaling it, then add the current $e^{x_i - m_i}$ to obtain the step-i exp-sum.

Algorithm 2: Online Softmax
1  $m_0 \gets -\infty$
2  $s_0 \gets 0$
3  for $i \gets 1$ to $N$ do
4      $m_i \gets \max(m_{i-1}, x_i)$
5      $s_i \gets s_{i-1} \cdot e^{\,m_{i-1} - m_i\,} + e^{\,x_i - m_i\,}$
6  end for
7  for $i \gets 1$ to $N$ do
8      $y_i \gets \dfrac{e^{x_i - m_N}}{s_N}$
9  end for
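A minimal CPU reference of Algorithm 2, useful for checking the GPU kernels against (the function name is an assumption):

#include <math.h>

// Hypothetical single-row reference: one fused pass for max/sum, one pass for normalization.
void online_softmax_row(float* y, const float* x, int N) {
  float m = -INFINITY;  // running max
  float s = 0.0f;       // running sum of exp(x_i - m)
  for (int i = 0; i < N; i++) {
    float m_new = fmaxf(m, x[i]);
    s = s * expf(m - m_new) + expf(x[i] - m_new);  // rescale the old sum to the new max
    m = m_new;
  }
  for (int i = 0; i < N; i++) {
    y[i] = expf(x[i] - m) / s;
  }
}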

NOTE: logSoftmax can be optimized with the same exp-sum property shown above.

\[logSoftmax(x)_i = \log\left(\frac{e^{x_i-x_{max}}}{\sum_j e^{x_j-x_{max}}}\right) = x_i - x_{max} - \log\left(\sum_j e^{x_j-x_{max}}\right)\]
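A small sketch of the corresponding log-softmax pass, reusing the row max m and the exp-sum s produced by the online pass above (the helper name is an assumption):

#include <math.h>

// Hypothetical: given the row max m and s = sum_j exp(x_j - m), log-softmax needs
// no division, only one log per row.
void log_softmax_row_from_stats(float* y, const float* x, int N, float m, float s) {
  float log_s = logf(s);
  for (int i = 0; i < N; i++) {
    y[i] = x[i] - m - log_s;
  }
}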

2.1 Cooperative Groups + Struct Implementation (V4)

Below is a simple implementation using cooperative groups plus a struct. When the C dimension is large, a single thread's workload becomes too heavy, thread resources become strained, and parallelism drops.

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

struct __align__(8) SumMax
{
  float maxval;
  float sum;
};

__device__ __inline__ SumMax reduce_sum_max_op(SumMax a, SumMax b) {
  bool a_bigger = (a.maxval > b.maxval);
  SumMax bigger_m = a_bigger ? a : b;
  SumMax smaller_m = a_bigger ? b : a;
  SumMax res;
  res.maxval = bigger_m.maxval;
  res.sum = bigger_m.sum + smaller_m.sum * expf(smaller_m.maxval - bigger_m.maxval);
  return res;
}

/**
 * Every warp calculate one line.
 * BLOCK_SIZE → a multiple of 32
 * dim3 grid((N + BLOCK_SIZE / 32 - 1) / (BLOCK_SIZE / 32))
 * dim3 block(BLOCK_SIZE)
 */
__global__ void softmax_forward_online_kernel(
  float* out, const float* inp, int N, int C
){
  cg::thread_block block = cg::this_thread_block();
  cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
  int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
  if (idx >= N) {
    return;
  }
  const float* x = inp + idx * C;
  SumMax sm_partial;
  sm_partial.maxval = -INFINITY;
  sm_partial.sum = 0.0f;

  for (int i = warp.thread_rank(); i < C; i += warp.size()) {
    sm_partial = reduce_sum_max_op(sm_partial, { x[i], 1.0f });
  }

  SumMax sm_total = cg::reduce(warp, sm_partial, reduce_sum_max_op);

  // __stcs → st.global.cs: streaming store that skips L1 and writes to global memory (L2/DRAM only)
  for (int i = warp.thread_rank(); i < C; i += warp.size()) {
    __stcs(out + idx * C + i, expf(x[i] - sm_total.maxval) / sm_total.sum);
  }
}
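A minimal launch sketch matching the header comment: each warp handles one row, so a block of BLOCK_SIZE threads covers BLOCK_SIZE / 32 rows (BLOCK_SIZE = 256 and the device pointers are assumptions):

// Hypothetical launch of the cooperative-groups kernel.
const int BLOCK_SIZE = 256;                    // must be a multiple of 32
const int rows_per_block = BLOCK_SIZE / 32;    // one warp per row
dim3 block(BLOCK_SIZE);
dim3 grid((N + rows_per_block - 1) / rows_per_block);
softmax_forward_online_kernel<<<grid, block>>>(d_out, d_inp, N, C);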

The tricky part of the code above is the reduce_sum_max_op function; this reduce essentially merges two partial exp-sums. Suppose A and B are two SumMax structs; the goal is to compute the safe exp-sum under the overall max of A and B.

\[A.sum = \sum_{x_i\in A} e^{x_i-A.maxval}\] \[B.sum = \sum_{x_j\in B} e^{x_j-B.maxval}\]

First, compare the two maxima. Suppose A's max is larger; the goal is then to rewrite B's sum in the form $\sum e^{x_j-A.maxval}$. By the properties of the exponential:

\[\sum_{x_j\in B} e^{x_j-A.maxval} = \sum_{x_j\in B} e^{x_j-B.maxval}\cdot e^{B.maxval-A.maxval} = B.sum\cdot e^{B.maxval-A.maxval}\]

So the merged sum is computed as:

\[sum' = A.sum + B.sum\cdot e^{B.maxval-A.maxval}\]
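A small host-side sketch that checks this merge rule numerically against a direct safe exp-sum (the host struct, merge function, and example values are assumptions for illustration; they mirror the device-side SumMax and reduce_sum_max_op above):

#include <math.h>
#include <stdio.h>

// Hypothetical host re-implementation of the merge operator.
typedef struct { float maxval; float sum; } SumMaxHost;

SumMaxHost merge(SumMaxHost a, SumMaxHost b) {
  SumMaxHost big   = (a.maxval > b.maxval) ? a : b;
  SumMaxHost small = (a.maxval > b.maxval) ? b : a;
  SumMaxHost r = { big.maxval, big.sum + small.sum * expf(small.maxval - big.maxval) };
  return r;
}

int main() {
  float x[4] = {1.0f, 3.0f, -2.0f, 0.5f};
  // Partial A covers x[0..1], partial B covers x[2..3].
  SumMaxHost A = { 3.0f, expf(1.0f - 3.0f) + expf(3.0f - 3.0f) };
  SumMaxHost B = { 0.5f, expf(-2.0f - 0.5f) + expf(0.5f - 0.5f) };
  SumMaxHost M = merge(A, B);
  // Direct safe exp-sum over all four elements for comparison.
  float m = 3.0f, s = 0.0f;
  for (int i = 0; i < 4; i++) s += expf(x[i] - m);
  printf("merged: max=%f sum=%f | direct: max=%f sum=%f\n", M.maxval, M.sum, m, s);
  return 0;
}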

Reference

[1] LeetCUDA.

[2] ops(2): SoftMax 算子的 CUDA 实现.

[3] [Attention优化][2w字]📚原理篇: 从Online-Softmax到FlashAttention V1/V2/V3.

[4] Online normalizer calculation for softmax.

This post is licensed under CC BY 4.0 by the author.
