大数跨境
0
0

Triton puzzle 8 Long Softmax

Triton puzzle 8 Long Softmax Angela的外贸日常
2025-10-21
4
导读:Puzzle 8: Long Softmax 前置知识-了解softmax的数学原理 1. Softm

Puzzle 8: Long Softmax

前置知识-了解softmax的数学原理

1. Softmax 是什么?

Softmax 是一种常用于分类问题的数学函数,它的作用是把一个任意实数向量(例如神经网络的输出 logits)转换成概率分布
假设我们有一个向量

Softmax 的定义是:

也就是说:

  • 每个元素都先取指数 (保证结果为正);
  • 再除以所有指数和(保证总和为 1);
  • 因此输出的是一个“概率分布”向量。

2. 为什么要用 Softmax

Softmax 有两个核心作用:

  1. 归一化 (Normalization)
    输出结果的所有分量都在 (0,1) 之间,而且总和为 1。
    👉 这就让我们可以把输出当作属于各类别的概率
  2. 放大差异 (Exponential Scaling)
    因为用了指数函数,较大的值会被“放大”,较小的值会被“压缩”。
    这有助于模型更明确地区分类别。

3. 例子

假设神经网络输出:

z = [2, 1, 0]

计算 Softmax:


表示模型认为:

  • 类别 1 的概率为 66.5%
  • 类别 2 的概率为 24.5%
  • 类别 3 的概率为 9.0%

4. 数值稳定性(实际实现)

在代码中通常写成:

因为直接计算 可能会溢出(如果 z 很大)。

5. 与交叉熵 (Cross-Entropy) 的关系

在分类任务中常用 Softmax + CrossEntropyLoss

  • Softmax 把 logits 转为概率;
  • Cross-Entropy 衡量预测分布和真实分布的距离;
  • 它们通常在框架中一起实现(比如 PyTorch 的 CrossEntropyLoss 内部已经包含 Softmax)。

6. 英文总结

Softmax converts a vector of raw scores (logits) into probabilities that sum to 1.
It emphasizes larger values while keeping the output interpretable as probabilities.
Formula:

Used for:

  • The output layer of a multi-class classifier
  • Combined with cross-entropy loss

矩阵的softmax

1. 矩阵的 Softmax 是什么意思?

当我们说“对一个矩阵求 Softmax”时,首先要明确——
Softmax 是对哪个维度进行的。
因为 Softmax 操作本质上是“沿着一个维度,把一组数变成概率分布”。
所以在矩阵(二维数据)里,常见的有两种情况:

  1. 按行 (row-wise) 求 Softmax
  2. 按列 (column-wise) 求 Softmax

2.例子说明

假设矩阵

情况 1:按行求 Softmax(最常见)

对每一行单独做 Softmax:
第 1 行:
$$
[1,2,3] \Rightarrow e^{[1,2,3]} = [2.72, 7.39, 20.09]
$$

第 2 行:

得到结果矩阵:

情况 2:按列求 Softmax

对每一列单独做 Softmax:

列号 原值 Softmax 后
第 1 列 [1,2] [0.27, 0.73]
第 2 列 [2,4] [0.12, 0.88]
第 3 列 [3,6] [0.05, 0.95]

得到:


3、在代码中实现(以 NumPy 为例)

import numpy as np

A = np.array([[1, 2, 3],
              [2, 4, 6]])

# 按行求 Softmax
row_softmax = np.exp(A - np.max(A, axis=1, keepdims=True))
row_softmax = row_softmax / np.sum(row_softmax, axis=1, keepdims=True)

# 按列求 Softmax
col_softmax = np.exp(A - np.max(A, axis=0, keepdims=True))
col_softmax = col_softmax / np.sum(col_softmax, axis=0, keepdims=True)

4. 总结

类型 操作维度 常见用途
行 Softmax 对每一行做 每行代表一个样本的 logits(分类任务常用)
列 Softmax 对每一列做 每列代表一个特征在不同样本上的权重

算子题目及实现

Softmax of a batch of logits.
输入matrix 的大小是 .
对矩阵按行求softmax

-
Uses one program block axis. Block size B0 represents the batch of x of length N0.
Block logit length T.   Process it B1 < T elements at a time.

Note softmax needs to be computed in numerically stable form as in Python. In addition in Triton they recommend not using exp but instead using exp2. You need the identity


Advanced: there one way to do this with 3 loops. You can also do it with 2 loops if you are clever. Hint: you will find this identity useful:

题目框架

def softmax_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4 200"]:
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp()
    return x_exp / x_exp.sum(1, keepdim=True)

@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    log2_e = 1.44269504
    return

test(softmax_kernel, softmax_spec, B={"B0": 1, "B1":32},
     nelem={"N0": 4, "N1": 32, "T": 200})

My Solution 1

思路

按照softmax的计算公式,实现最初始的解法,可以实现功能正确性
一个block块实现了(B0, T)的softmax, T长度以B1为一个块来循环操作

  1. 一次for 循环求行的最大值m
  2. 一次for 循环求行的和
  3. 一次for 循环求得

按照这种思路,我就一次实现了softmax的功能正确性

def softmax_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4 200"]:
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp()
    return x_exp / x_exp.sum(1, keepdim=True)

@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    log2_e = 1.44269504
    block_idx = pid_0 * B0 + tl.arange(0, B0)
    mask_b = block_idx < N0 

    # max_i 
    init = tl.full((B0,), 0, tl.float32)
    m = tl.full((B0, ), 0, tl.float32)
    for i in range(0, T, B1): 
      col_idx = tl.arange(0, B1) + i
      mask_col = col_idx < T
      offset_x = block_idx[:, None] * T + col_idx[None, :]
      mask_x = mask_b[:, None] & mask_col[None, :]
      x = tl.load(x_ptr + offset_x, mask=mask_x)
      m = tl.max(x, axis=1) 
      m = tl.maximum(m, init)

    # sum of  e(x-m)
    x_exp = tl.full((B0, B1), 0, tl.float32)
    x_sum = tl.full((B0, ), 0, tl.float32)
    for i in range(0, T, B1): 
      col_idx = tl.arange(0, B1) + i
      mask_col = col_idx < T
      offset_x = block_idx[:, None] * T + col_idx[None, :]
      mask_x = mask_b[:, None] & mask_col[None, :]
      x = tl.load(x_ptr + offset_x, mask=mask_x)
      x_ = x - m 
      x_exp = tl.exp(x_)
      x_sum = tl.sum(x_exp, axis=1) + x_sum

    
    # e(x-m)/x_sum
    x_exp = tl.full((B0, B1), 0, tl.float32)
    for i in range(0, T, B1): 
      col_idx = tl.arange(0, B1) + i
      mask_col = col_idx < T
      offset_x = block_idx[:, None] * T + col_idx[None, :]
      mask_x = mask_b[:, None] & mask_col[None, :]
      x = tl.load(x_ptr + offset_x, mask=mask_x)
      x_ = x - m 
      x_exp = tl.exp(x_)/x_sum
      tl.store(z_ptr + offset_x, x_exp, mask=mask_x)

    return

test(softmax_kernel, softmax_spec, B={"B0": 1, "B1":32},
     nelem={"N0": 4, "N1": 32, "T": 200})

可视化

可以看书数据在列方向以block(B1)大小进行计算

TODO

当然这题应该还有更优化的解题方法,后面可以继续做一些优化的实现

【声明】内容源于网络
0
0
Angela的外贸日常
跨境分享间 | 长期积累专业经验
内容 45910
粉丝 1
Angela的外贸日常 跨境分享间 | 长期积累专业经验
总阅读246.0k
粉丝1
内容45.9k