Puzzle 8: Long Softmax
前置知识-了解softmax的数学原理
1. Softmax 是什么?
Softmax 是一种常用于分类问题的数学函数,它的作用是把一个任意实数向量(例如神经网络的输出 logits)转换成概率分布。
假设我们有一个向量
Softmax 的定义是:
也就是说:
- 每个元素都先取指数 (保证结果为正);
- 再除以所有指数和(保证总和为 1);
- 因此输出的是一个“概率分布”向量。
2. 为什么要用 Softmax
Softmax 有两个核心作用:
- 归一化 (Normalization)
输出结果的所有分量都在 (0,1) 之间,而且总和为 1。
👉 这就让我们可以把输出当作属于各类别的概率。 - 放大差异 (Exponential Scaling)
因为用了指数函数,较大的值会被“放大”,较小的值会被“压缩”。
这有助于模型更明确地区分类别。
3. 例子
假设神经网络输出:
z = [2, 1, 0]
计算 Softmax:

表示模型认为:
- 类别 1 的概率为 66.5%
- 类别 2 的概率为 24.5%
- 类别 3 的概率为 9.0%
4. 数值稳定性(实际实现)
在代码中通常写成:
因为直接计算
可能会溢出(如果 z 很大)。
5. 与交叉熵 (Cross-Entropy) 的关系
在分类任务中常用 Softmax + CrossEntropyLoss:
- Softmax 把 logits 转为概率;
- Cross-Entropy 衡量预测分布和真实分布的距离;
- 它们通常在框架中一起实现(比如 PyTorch 的 CrossEntropyLoss 内部已经包含 Softmax)。
6. 英文总结
Softmax converts a vector of raw scores (logits) into probabilities that sum to 1.
It emphasizes larger values while keeping the output interpretable as probabilities.
Formula:
Used for:
- The output layer of a multi-class classifier
- Combined with cross-entropy loss
矩阵的softmax
1. 矩阵的 Softmax 是什么意思?
当我们说“对一个矩阵求 Softmax”时,首先要明确——
Softmax 是对哪个维度进行的。
因为 Softmax 操作本质上是“沿着一个维度,把一组数变成概率分布”。
所以在矩阵(二维数据)里,常见的有两种情况:
- 按行 (row-wise) 求 Softmax
- 按列 (column-wise) 求 Softmax
2.例子说明
假设矩阵
情况 1:按行求 Softmax(最常见)
对每一行单独做 Softmax:
第 1 行:
$$
[1,2,3] \Rightarrow e^{[1,2,3]} = [2.72, 7.39, 20.09]
$$
第 2 行:
得到结果矩阵:
情况 2:按列求 Softmax
对每一列单独做 Softmax:
| 列号 | 原值 | Softmax 后 |
|---|---|---|
| 第 1 列 | [1,2] | [0.27, 0.73] |
| 第 2 列 | [2,4] | [0.12, 0.88] |
| 第 3 列 | [3,6] | [0.05, 0.95] |
得到:
3、在代码中实现(以 NumPy 为例)
import numpy as np
A = np.array([[1, 2, 3],
[2, 4, 6]])
# 按行求 Softmax
row_softmax = np.exp(A - np.max(A, axis=1, keepdims=True))
row_softmax = row_softmax / np.sum(row_softmax, axis=1, keepdims=True)
# 按列求 Softmax
col_softmax = np.exp(A - np.max(A, axis=0, keepdims=True))
col_softmax = col_softmax / np.sum(col_softmax, axis=0, keepdims=True)
4. 总结
| 类型 | 操作维度 | 常见用途 |
|---|---|---|
| 行 Softmax | 对每一行做 | 每行代表一个样本的 logits(分类任务常用) |
| 列 Softmax | 对每一列做 | 每列代表一个特征在不同样本上的权重 |
算子题目及实现
Softmax of a batch of logits.
输入matrix 的大小是
.
对矩阵按行求softmax
-
Uses one program block axis. Block size B0 represents the batch of x of length N0.
Block logit length T. Process it B1 < T elements at a time.

Note softmax needs to be computed in numerically stable form as in Python. In addition in Triton they recommend not using exp but instead using exp2. You need the identity
Advanced: there one way to do this with 3 loops. You can also do it with 2 loops if you are clever. Hint: you will find this identity useful:
题目框架
def softmax_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4 200"]:
x_max = x.max(1, keepdim=True)[0]
x = x - x_max
x_exp = x.exp()
return x_exp / x_exp.sum(1, keepdim=True)
@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
pid_0 = tl.program_id(0)
log2_e = 1.44269504
return
test(softmax_kernel, softmax_spec, B={"B0": 1, "B1":32},
nelem={"N0": 4, "N1": 32, "T": 200})
My Solution 1
按照这种思路,我就一次实现了softmax的功能正确性
def softmax_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4 200"]:
x_max = x.max(1, keepdim=True)[0]
x = x - x_max
x_exp = x.exp()
return x_exp / x_exp.sum(1, keepdim=True)
@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
pid_0 = tl.program_id(0)
log2_e = 1.44269504
block_idx = pid_0 * B0 + tl.arange(0, B0)
mask_b = block_idx < N0
# max_i
init = tl.full((B0,), 0, tl.float32)
m = tl.full((B0, ), 0, tl.float32)
for i in range(0, T, B1):
col_idx = tl.arange(0, B1) + i
mask_col = col_idx < T
offset_x = block_idx[:, None] * T + col_idx[None, :]
mask_x = mask_b[:, None] & mask_col[None, :]
x = tl.load(x_ptr + offset_x, mask=mask_x)
m = tl.max(x, axis=1)
m = tl.maximum(m, init)
# sum of e(x-m)
x_exp = tl.full((B0, B1), 0, tl.float32)
x_sum = tl.full((B0, ), 0, tl.float32)
for i in range(0, T, B1):
col_idx = tl.arange(0, B1) + i
mask_col = col_idx < T
offset_x = block_idx[:, None] * T + col_idx[None, :]
mask_x = mask_b[:, None] & mask_col[None, :]
x = tl.load(x_ptr + offset_x, mask=mask_x)
x_ = x - m
x_exp = tl.exp(x_)
x_sum = tl.sum(x_exp, axis=1) + x_sum
# e(x-m)/x_sum
x_exp = tl.full((B0, B1), 0, tl.float32)
for i in range(0, T, B1):
col_idx = tl.arange(0, B1) + i
mask_col = col_idx < T
offset_x = block_idx[:, None] * T + col_idx[None, :]
mask_x = mask_b[:, None] & mask_col[None, :]
x = tl.load(x_ptr + offset_x, mask=mask_x)
x_ = x - m
x_exp = tl.exp(x_)/x_sum
tl.store(z_ptr + offset_x, x_exp, mask=mask_x)
return
test(softmax_kernel, softmax_spec, B={"B0": 1, "B1":32},
nelem={"N0": 4, "N1": 32, "T": 200})
可视化
可以看书数据在列方向以block(B1)大小进行计算
TODO
当然这题应该还有更优化的解题方法,后面可以继续做一些优化的实现

