大数跨境
0
0

通透!RNN vs Transformer !!

通透!RNN vs Transformer !! 机器学习和人工智能AI
2025-11-30
0

哈喽,大家好~

咱们今天从并行处理能力、记忆结构、依赖建模方式等等方面的对比,完整阐述RNN 和 Transformer 的不同~

首先,RNN(循环神经网络)与 Transformer 是两种深度学习中处理序列数据的重要模型结构。

一、基本结构对比

1. RNN

RNN 是一种按时间步递归更新的神经网络,适用于处理序列数据,例如文本、时间序列等。

核心公式:

在时间步   时,隐藏状态的更新:

其中:

  • :当前输入
  • :前一时刻的隐藏状态(记忆)
  • :权重矩阵
  • :激活函数(通常为 tanh 或 ReLU)

输出通常是:

RNN 的核心特性是隐藏状态作为时间的记忆传递机制

2. Transformer

Transformer 是一种完全基于注意力机制(Self-Attention)的模型架构,彻底抛弃了递归结构。它依赖位置编码(Positional Encoding)来注入顺序信息,并可并行处理序列。

核心公式:多头自注意力机制(Multi-Head Attention)

对于输入序列  ,先线性变换得到:

计算注意力权重:

多头机制:

每个头为:

二、其他对比

并行处理能力对比

特性
RNN
Transformer
是否支持并行
否(时间步依赖)
是(所有位置可同时计算)
训练速度
慢(逐步计算)
快(高度并行)
长序列处理能力
差(梯度消失/爆炸)
强(全局注意力)

其中:

RNN 必须从      逐步处理,前一状态是后一状态的输入,无法并行。

Transformer 没有时间依赖,所有位置的表示可以一次性计算出来,极大提高训练速度。

记忆结构对比

特性
RNN
Transformer
记忆机制
隐藏状态(
注意力上下文向量(基于所有位置的加权和)
可访问历史范围
有限(随时间衰减)
全局(所有位置)
长期依赖建模
较弱
强(任意位置之间都能交互)
  • RNN 的隐藏状态容易因梯度消失/爆炸而丧失早期信息,即记忆不稳定。
  • Transformer 利用注意力机制,在每一层都能与全序列交互,有效捕捉长距离依赖。

依赖建模方式对比

RNN:

  • 依赖是时间上的递归建模
  • 关系是隐式建模,无法直接控制依赖强弱。

Transformer:

  • 使用注意力权重显式建模依赖:
  •  直接表示位置   对位置   的关注程度,是可解释的依赖建模。

位置建模

特性
RNN
Transformer
顺序感知方式
隐含在递归中(天然有顺序)
显式位置编码
位置编码方式
无需额外编码
使用 Positional Encoding(正余弦 or Learnable)

Transformer 的位置编码:

完整案例

这里案例,会详细对比 RNN 与 Transformer 在:

  • 并行处理能力
  • 记忆结构
  • 长依赖建模方式

方面的表现差异,结合训练过程的图形可视化给出清晰结论。

数据集

我们生成一种有规律但包含长依赖结构的序列任务,例如:

给定一个长度为 20 的数字序列,预测第 21 项。规则是:

这种任务依赖于多个、远间隔的位置,适合测试记忆能力和长依赖建模。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# 配置
torch.manual_seed(42)
device = torch.device("cuda"if torch.cuda.is_available() else"cpu")

# 1. 数据
def generate_data(n_samples=1000, seq_len=20):
    X = np.random.rand(n_samples, seq_len)
    y = X[:, 2] + X[:, 6] + X[:, 14] + np.sin(X[:, 0]) + np.cos(X[:, 17])
    return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32).unsqueeze(1)

X, y = generate_data(5000)
X_train, y_train = X[:4000], y[:4000]
X_test, y_test = X[4000:], y[4000:]

# 2. RNN 模型
class RNNModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = x.unsqueeze(-1)  # shape: (B, T, 1)
        _, h_n = self.rnn(x)  # h_n: (num_layers, B, hidden_dim)
        out = self.fc(h_n[-1])
        return out

# 3. Transformer 模型
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)].to(x.device)

class TransformerModel(nn.Module):
    def __init__(self, input_dim=1, d_model=64, nhead=4, num_layers=2):
        super().__init__()
        self.linear_in = nn.Linear(input_dim, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, 1)

    def forward(self, x):
        x = x.unsqueeze(-1)  # (B, T, 1)
        x = self.linear_in(x)
        x = self.pos_enc(x)
        x = self.transformer(x)
        out = self.fc_out(x[:, -1])
        return out

# 4. 训练函数
def train_model(model, X_train, y_train, X_test, y_test, epochs=50):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    train_loss, test_loss = [], []

    for epoch in tqdm(range(epochs)):
        model.train()
        optimizer.zero_grad()
        pred = model(X_train.to(device))
        loss = criterion(pred, y_train.to(device))
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())

        model.eval()
        with torch.no_grad():
            val = model(X_test.to(device))
            loss_val = criterion(val, y_test.to(device))
            test_loss.append(loss_val.item())

    return train_loss, test_loss, model

# 5. 训练 & 可视化
rnn_model = RNNModel(1641)
transformer_model = TransformerModel()

rnn_train_loss, rnn_test_loss, rnn_model = train_model(rnn_model, X_train, y_train, X_test, y_test)
tr_train_loss, tr_test_loss, transformer_model = train_model(transformer_model, X_train, y_train, X_test, y_test)

# 6. 图形展示
plt.figure(figsize=(126))
plt.plot(rnn_train_loss, label='RNN Train Loss', color='red')
plt.plot(rnn_test_loss, label='RNN Test Loss', color='orange')
plt.plot(tr_train_loss, label='Transformer Train Loss', color='blue')
plt.plot(tr_test_loss, label='Transformer Test Loss', color='green')
plt.title('Loss Curve Comparison')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.legend()
plt.tight_layout()
plt.show()

# 预测值可视化
def plot_predictions(model, X, y, title):
    model.eval()
    with torch.no_grad():
        preds = model(X.to(device)).cpu().numpy()
    plt.figure(figsize=(105))
    plt.scatter(range(len(y)), y.numpy(), label='True', alpha=0.6, color='black')
    plt.scatter(range(len(preds)), preds, label='Predicted', alpha=0.6, color='cyan')
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()

plot_predictions(rnn_model, X_test, y_test, "RNN Predictions")
plot_predictions(transformer_model, X_test, y_test, "Transformer Predictions")

# 注意力热力图
def plot_attention_heatmap():
    dummy = torch.rand(12064)
    attn_layer = nn.MultiheadAttention(644, batch_first=True)
    attn_output, attn_weights = attn_layer(dummy, dummy, dummy)
    plt.figure(figsize=(86))
    sns.heatmap(attn_weights[0].cpu().detach().numpy(), cmap='coolwarm')
    plt.title("Sample Attention Map (Transformer)")
    plt.xlabel("Key Position")
    plt.ylabel("Query Position")
    plt.tight_layout()
    plt.show()

plot_attention_heatmap()

图1:Loss Curve

可以看到, RNN 和 Transformer 在训练与测试集上的损失曲线,Transformer 更快收敛、过拟合更少。

图2:预测值分布(RNN)

预测点呈现偏离,说明无法很好捕捉非线性依赖。

图3:预测值分布(Transformer)

预测点与真实值高度重合,说明 Transformer 更好学习了复杂依赖。

图4:Attention 热力图

Transformer 模型关注多个位置,与任务的「x3, x7, x15, x18」匹配,显示其优越的依赖建模能力。

本项目详细实现并可视化对比了 RNN 与 Transformer 两大架构在处理序列预测任务中的关键差异。从实验结果与图表中我们可以得出结论:

  • Transformer 明显优于 RNN 在长期依赖建模方面。
  • 在资源允许下,Transformer 应作为首选方案,尤其是当任务需要理解全局上下文。
  • RNN 仍然适用于低资源、低时延场景。

最终,深刻理解两种结构的内在机制与适用场景,才是构建高效模型系统的关键。

最后

宝子们,快来领取16大块的内容,124个算法问题的总结,完整的机器学习小册,免费领取~
领取:备注「算法小册」即可~
顺便加个好友,微观朋友圈,满满的干货!~

【声明】内容源于网络
0
0
机器学习和人工智能AI
让我们一起期待 AI 带给我们的每一场变革!推送最新行业内最新最前沿人工智能技术!
内容 333
粉丝 0
机器学习和人工智能AI 让我们一起期待 AI 带给我们的每一场变革!推送最新行业内最新最前沿人工智能技术!
总阅读220
粉丝0
内容333