Logo
Overview

GPU 集群训练优化 阅读笔记(二)

July 11, 2025

https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism

上一章讲到了普通的数据层面的并行,这一章来看看Tensor Parallelism和Context Parallelism。

Tensor Parallelism(张量并行)

用到了一点神奇的线性代数知识。

💡

复习一下矩阵乘法

那么对A的每一行/每一列,都可以拆出来和B单独做乘法,对B同理。

根据上面的说法,我们可以把在其中完整的张量拆分。分为行拆分和列拆分。

原始表达
列拆分(我们可以把W放在不同的GPU上进行计算)
行拆分,甚至还能把X拆出去

在Transformer中,同样是有两种内容:sequence和parameters,那么对于两种内容我们也可以做出拆分。sequence是输入值,对应上图的X,parameters则是上面linear module中的W和B。

Transformer块中parameters的Tensor Parallelism
💡

又是新知识点!关于Transformer块的结构。当然这里不会仔细讲的,去问问GPT吧!他会告诉你的

Transformer

来自Wikipedia的一个典型Transformer块,来源

左侧的Encoder,处理输入的原序列变成一个向量,其中每列都是one hot编码,然后加入位置信息(Positional Encoding)。接下来通过N层编码器层(多头自注意力+前馈神经网络)。负责提供上下文。

右侧的Decoder,根据期望输出的内容(或者已经predict出来的上下文),同样编码,经过解码器层(这里是带掩码的多头自注意力 + 多头交叉注意力 + 前馈网络),结合编码器给出的源序列,给出一个各个token出现的概率,最后通过softmax或者类似的(比如随机选择)输出接下来的Sequence。

注意力机制

注意力机制,大概类似人类认知过程:看一张图片/一段文字时,我们会更关注重要的点,且忽略不重要的部分。

举个例子:The animal didn’t cross the street because it was too tired。在自注意力机制的计算下,it相关联的词是The animal。

为了计算注意力,模型会给每个词向量创建三个向量:查询(Q),键(K),值(V)。

Query:我要关注什么东西?

Key:这个东西有什么比较重要的地方(特征?)

Value:实际内容。

计算方法:⁍

  • 计算分数:当前的Q向量与其他词的K向量点积求相关性(余弦公式知道吧?相关性越大,点积越接近1)
  • 缩放并归一化,最后求和。

得到的向量是当前输入中更重要的单词。

这个地方还有个修饰词多头:就理解成多个专家,负责处理不同的注意力部分。(还真是多头)

前馈神经网络:传统MLP,非线性变换+Activation。

好了,知识点补充完成,我们继续看这个parameters的并行。

  • 在注意力部分:Q,K,V采用列分片,每个GPU负责部分注意力头。输出投影可用行分片。最后合并到一起计算
    • 限制:TP并行度不超过头数。
  • 在前馈部分,“列分片 + 行分片”效率更高

这里,通信效率仍然可能是瓶颈。单节点(8卡)内TP通信较快,TP度数越高,单卡吞吐下降越多,但是可以支持更大的batch size和模型规模。

内存节省方面,显著降低,可以让大模型在有限GPU上训练,但是后续操作(Dropout,LayerNorm)仍需全量激活,还可以优化!

可以看到,模型的Parameters内存占用显著降低。但是随着Sequence的增加,内存占用还是在增加。
Transformer的序列并行

根据上面的内容我们知道,sequence的内存增加是平方级别的。(因为对于一个长度为L的序列,在注意力机制中需要L^2个位置来存放注意力分数(对于每个token我们需要评估他和其他所有token的相关性)所以提出了序列并行。

核心思想是把序列分段。主要需要解决的难点是如何在序列被切分的情况下,高效地完成全局注意力计算。

在注意力层中,考虑序列分段到各个GPU上的场景,数据科学家们提出的做法如下:

  • 计算局部的Q,K,V。
  • All-Gather来获取全局的K和V。
  • 局部的Q和全局的K,V计算得到局部序列注意力输出,然后把全局K,V drop掉。
  • Reduce-Scatter把局部序列注意力相加,然后再次分割发回GPU。

而其他层(Dropout、LayerNorm),需要解决TP留下的”历史遗留“:LayerNorm和Dropout没有需要分割的部分,并且计算冗余。

完美的特性——计算是局部的:对序列中的第 i 个 Token 进行 LayerNorm,只需要这个 Token 自身的 h 维向量。它完全不需要序列中第 j 个 Token 的任何信息。这意味着: 一旦输入张量沿着序列维度被切分,每张 GPU 就可以在自己的数据分片上独立、完整地执行 LayerNorm 操作,而完全不需要和其他 GPU 进行任何通信。

效果:

对比上一张图,内存占用也显著降低了。

局限性和优化:通信瓶颈仍然存在,实现逻辑较为复杂,部分层不适用。

Context Parallelism(上下文并行)

当上下文长度超长(128K + )时,Sequence Parallelism也无法处理Activation Value等的增加。所以引入新策略——Ring Attention

Ring Attention是一种高效的通信方式:每个GPU异步发送自己的key/value到下一个GPU,同时计算本地部分的注意力分数,循环进行,最终完成全序列的注意力计算。

这种方式虽然高效,但在因果注意力(causal attention)下,计算负载可能不均衡,需要进一步优化(如Zig-Zag Ring Attention)。

Zig-Zag Ring Attention 不是简单地顺序分配token到各GPU,而是将早期和晚期token交错分配,使每个GPU都能处理不同位置的token,从而均衡了各GPU的计算量。

这种分配方式让注意力掩码(attention mask)下的计算任务在所有GPU间分布更均匀,避免了某些GPU计算负载过重、某些过轻的问题。

处理流程:

  • 序列切分:长度切分
  • 局部计算:每个GPU接收到自己的子序列后,独立地、并行地计算该子序列对应的查询(Query)、键(Key)和值(Value)。
  • 全局信息同步:为了让每个GPU都能计算完整的注意力(即每个Query都能注意到全部的Key),需要通信:all-gather,将自己计算出的局部KV Cache广播给所有其他GPU。执行完 all-gather 后,每个GPU上就都有了完整的、来自全部序列的KV Cache。
  • 注意力计算:局部Query与全局的Key/Value进行注意力计算。

comment

留言 / 评论

如果暂时没有看到评论,请点击下方按钮重新加载。