CUDA-Operators-5-RMSNorm


1.RMSNorm

RMSNorm is a normalization operation that uses the root mean square (RMS) of the input to normalize it. Consistent with the code below (where $\epsilon$ sits inside the square root), the formula is $\mathrm{RMSNorm}(x)=\frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^2+\epsilon}}$, which in PyTorch reads:

```python
x = x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-5)
```
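As a quick sanity check of the formula, here is a dependency-free sketch for a single row (plain Python rather than PyTorch, for illustration only):

```python
import math

def rmsnorm(x, eps=1e-5):
    """Reference RMSNorm over one row: out_i = x_i / sqrt(mean(x^2) + eps)."""
    ms = sum(v * v for v in x) / len(x)   # mean of squares: (1/d) * sum x_i^2
    inv_rms = 1.0 / math.sqrt(ms + eps)   # epsilon inside the sqrt, as in the code
    return [v * inv_rms for v in x]
```

After normalization the mean of squares of the output is approximately 1, which is the defining property of RMSNorm.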

1.1.Implementation

The RMSNorm implementation uses two passes: Pass 1 computes the mean of squares, and Pass 2 normalizes each element. Computing the mean of squares requires a reduction, implemented here as a block-level reduction; the code is shown below.

```cuda
#define WARP_SIZE 32
// Vectorized 128-bit load/store helper (unused in this kernel).
#define FLOAT4(ptr) (reinterpret_cast<float4*>(&(ptr))[0])

template<const int WarpSize>
__device__ __inline__ float warpReduceSum(float val){
#pragma unroll
    // Butterfly reduction: after log2(WarpSize) XOR-shuffle steps,
    // every lane in the warp holds the full sum.
    for (int mask = WarpSize >> 1; mask >= 1; mask >>= 1){
        val += __shfl_xor_sync(0xffffffff, val, mask);
    }
    return val;
}

// One block handles BN rows; each row is processed by BK threads.
// Launch with blockDim = (BK, BN) and gridDim.x = N / BN.
template<const int BN, const int BK>
__global__ void rmsnorm_f32_kernel(float* x, float* out, const int N, const int K){
  constexpr int WARP_NUM = BK / WARP_SIZE;
  __shared__ float smem[BN][WARP_NUM];

  int tx = threadIdx.x;
  int ty = threadIdx.y;
  int bx = blockIdx.x;

  int lane = tx % WARP_SIZE;
  int warp_id = tx / WARP_SIZE;

  // Pass 1: each thread accumulates a strided partial sum of squares
  // over its row (starting from 0 so K < BK is also handled safely).
  float* cur_line_addr = x + (bx * BN + ty) * K;
  float val = 0.0f;
  for (int i = tx; i < K; i += BK){
    val += cur_line_addr[i] * cur_line_addr[i];
  }

  // Reduce within each warp, then stage one value per warp in shared memory.
  val = warpReduceSum<WARP_SIZE>(val);
  if (lane == 0)
    smem[ty][warp_id] = val;
  __syncthreads();

  // Thread 0 of each row combines the per-warp sums into the mean of squares.
  if (tx == 0){
    float norm = smem[ty][0] / K;
    for (int i = 1; i < WARP_NUM; ++i){
      norm += smem[ty][i] / K;
    }
    smem[ty][0] = norm;
  }
  __syncthreads();

  // Pass 2: every thread normalizes its strided elements using rsqrt.
  float norm_val = rsqrtf(smem[ty][0] + 1e-5f);
  for (int i = tx; i < K; i += BK){
    out[(bx * BN + ty) * K + i] = cur_line_addr[i] * norm_val;
  }
}
```
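The kernel's control flow can be mirrored in plain Python (a sketch for illustration; `warp_reduce_sum` simulates the `__shfl_xor_sync` butterfly with explicit lists, and `BK` "threads" are emulated with strided loops):

```python
import math

WARP_SIZE = 32

def warp_reduce_sum(vals):
    """Simulate the XOR-shuffle butterfly over one warp: after
    log2(len(vals)) steps, every lane holds the full sum."""
    mask = len(vals) // 2
    while mask >= 1:
        vals = [vals[i] + vals[i ^ mask] for i in range(len(vals))]
        mask //= 2
    return vals

def rmsnorm_row(row, BK=64, eps=1e-5):
    """Two-pass RMSNorm of one row, following the kernel's structure."""
    K = len(row)
    # Pass 1: each of the BK "threads" accumulates a strided partial sum.
    partial = [sum(row[i] * row[i] for i in range(t, K, BK)) for t in range(BK)]
    # Warp-level reductions, then the serial cross-warp sum done by thread 0.
    warp_sums = [warp_reduce_sum(partial[w:w + WARP_SIZE])[0]
                 for w in range(0, BK, WARP_SIZE)]
    mean_sq = sum(warp_sums) / K
    # Pass 2: normalize every element with the shared rsqrt value.
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in row]
```

This mirrors the data flow only; it says nothing about the memory-coalescing behavior that the strided access pattern gives on the GPU.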

In the implementation above, each block processes BN rows, with each row computed by BK threads. This performs well when K is small, as shown in the figure below. When K is very large, however, each thread must make K / BK trips to global memory; this can be optimized by having each block compute a (BN, BK′) tile and then reducing across blocks.
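The split-K idea can be sketched as a two-stage reduction in Python (the two-kernel structure and the `chunk` parameter are assumptions for illustration, not code from the post):

```python
import math

def rmsnorm_split_k(row, chunk=1024, eps=1e-5):
    """Split-K sketch: stage 1 computes one partial sum of squares per
    chunk (one hypothetical block each); stage 2 combines the partials
    and normalizes. On a GPU, stage 2 would be a second, tiny kernel."""
    K = len(row)
    # Stage 1: one partial sum per (hypothetical) block.
    partials = [sum(v * v for v in row[s:s + chunk]) for s in range(0, K, chunk)]
    # Stage 2: combine partials and normalize.
    inv_rms = 1.0 / math.sqrt(sum(partials) / K + eps)
    return [v * inv_rms for v in row]
```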

(Figure: RMSNorm kernel benchmark results)


This post is licensed under CC BY 4.0 by the author.
