/ Transformer  自注意力  大模型  LLM  KV-Cache  Flash-Attention  深度学习  AI工程 

深入理解 Transformer:从自注意力机制到大模型工程优化实战


封面

一、为什么 Transformer 彻底改变了 AI 格局

2017 年,Google 发表论文《Attention Is All You Need》,提出了完全基于注意力机制的 Transformer 架构,摒弃了 RNN/LSTM 的顺序依赖,开创了并行化训练大规模模型的新纪元。如今,ChatGPT、Claude、Gemini 等大语言模型(LLM)无一例外地以 Transformer 作为核心骨架。

  • 并行化训练:告别 RNN 的时序依赖,可在 GPU/TPU 上高效并行

  • 长程依赖建模:Self-Attention 直接建立序列任意两个位置的关联

  • 可扩展性强:从 1 亿到 1 万亿参数,架构基本不变

二、自注意力机制(Self-Attention)原理精讲

自注意力的核心思想是:对于序列中的每个 token,计算它与所有其他 token 的相关性权重,加权求和得到新的表示。整个过程可用三个矩阵 Q(Query)、K(Key)、V(Value) 描述:

import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q, K, V: (batch, heads, seq_len, d_k)
    """
    d_k = Q.size(-1)
    # 计算注意力分数
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Softmax 归一化
    attn_weights = F.softmax(scores, dim=-1)
    # 加权求和
    output = torch.matmul(attn_weights, V)
    return output, attn_weights

公式为:Attention(Q,K,V) = softmax(QK^T / √d_k) · V。除以 √d_k 是为了防止点积过大导致梯度消失。

三、多头注意力与位置编码

单头注意力只能关注一种语义关系,多头注意力(Multi-Head Attention)将 Q/K/V 分别投影到 h 个子空间,并行计算后拼接,让模型同时捕捉多种特征:

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        Q = self.W_q(x).view(B, T, self.num_heads, self.d_k).transpose(1,2)
        K = self.W_k(x).view(B, T, self.num_heads, self.d_k).transpose(1,2)
        V = self.W_v(x).view(B, T, self.num_heads, self.d_k).transpose(1,2)
        out, _ = scaled_dot_product_attention(Q, K, V, mask)
        out = out.transpose(1,2).contiguous().view(B, T, C)
        return self.W_o(out)

由于 Transformer 没有时序结构,需要位置编码(Positional Encoding)告知模型 token 的顺序。原始论文使用正弦/余弦位置编码;现代 LLM(如 LLaMA)则使用旋转位置编码(RoPE)以支持更长上下文。

四、工程优化:KV Cache 与 Flash Attention

在推理阶段,Transformer 的两大核心优化技术不可不知:

  • KV Cache:自回归生成时,每个新 token 不必重新计算所有历史 K/V,而是将其缓存复用,将推理复杂度从 O(n²) 降到 O(n)。代价是显存占用随序列长度线性增长,这也是大模型推理显存瓶颈的主要来源。

  • Flash Attention:标准注意力需要将完整的注意力矩阵(O(n²) 大小)写入 GPU HBM,Flash Attention 通过分块计算(Tiling)将中间结果保留在 SRAM 中,大幅减少 HBM 读写次数,速度提升 2-4 倍,显存降低 5-20 倍,是目前主流训练框架(如 vLLM、HuggingFace Transformers)的默认选项。

# 安装 Flash Attention 2
pip install flash-attn --no-build-isolation

# 在 HuggingFace 模型中启用
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3-8B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
)

五、从理解到实践:搭建迷你 GPT

掌握了上述原理后,可以用不到 200 行代码实现一个迷你 GPT(Decoder-only Transformer),在字符级语言模型任务上验证效果:

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Pre-LN(现代 LLM 常用)
        x = x + self.drop(self.attn(self.ln1(x), mask))
        x = x + self.drop(self.ff(self.ln2(x)))
        return x

推荐结合 Andrej Karpathy 的 nanoGPT 项目进行完整实验,该项目以最简洁的方式复现了 GPT-2,是学习 Transformer 工程实现的最佳起点。

理解 Transformer 不仅是追赶 AI 热点的必修课,更是参与大模型时代工程落地的基础门票。从注意力机制的数学本质,到 KV Cache、Flash Attention 的工程优化,每一层都值得深挖。

发布评论

热门评论区: