https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism
上一章讲到了普通的数据层面的并行,这一章来看看Tensor Parallelism和Context Parallelism。
Tensor Parallelism(张量并行)
用到了一点神奇的线性代数知识。
复习一下矩阵乘法
那么对A的每一行/每一列,都可以拆出来和B单独做乘法,对B同理。
根据上面的说法,我们可以把在其中完整的张量拆分。分为行拆分和列拆分。
在Transformer中,同样是有两种内容:sequence和parameters,那么对于两种内容我们也可以做出拆分。sequence是输入值,对应上图的X,parameters则是上面linear module中的W和B。
Transformer块中parameters的Tensor Parallelism
又是新知识点!关于Transformer块的结构。当然这里不会仔细讲的,去问问GPT吧!他会告诉你的
Transformer:
左侧的Encoder,处理输入的原序列变成一个向量,其中每列都是one hot编码,然后加入位置信息(Positional Encoding)。接下来通过N层编码器层(多头自注意力+前馈神经网络)。负责提供上下文。
右侧的Decoder,根据期望输出的内容(或者已经predict出来的上下文),同样编码,经过解码器层(这里是带掩码的多头自注意力 + 多头交叉注意力 + 前馈网络),结合编码器给出的源序列,给出一个各个token出现的概率,最后通过softmax或者类似的(比如随机选择)输出接下来的Sequence。
注意力机制:
注意力机制,大概类似人类认知过程:看一张图片/一段文字时,我们会更关注重要的点,且忽略不重要的部分。
举个例子:The animal didn’t cross the street because it was too tired。在自注意力机制的计算下,it相关联的词是The animal。
为了计算注意力,模型会给每个词向量创建三个向量:查询(Q),键(K),值(V)。
Query:我要关注什么东西?
Key:这个东西有什么比较重要的地方(特征?)
Value:实际内容。
计算方法:⁍
- 计算分数:当前的Q向量与其他词的K向量点积求相关性(余弦公式知道吧?相关性越大,点积越接近1)
- 缩放并归一化,最后求和。
得到的向量是当前输入中更重要的单词。
这个地方还有个修饰词多头:就理解成多个专家,负责处理不同的注意力部分。(还真是多头)
前馈神经网络:传统MLP,非线性变换+Activation。
好了,知识点补充完成,我们继续看这个parameters的并行。
- 在注意力部分:Q,K,V采用列分片,每个GPU负责部分注意力头。输出投影可用行分片。最后合并到一起计算
- 限制:TP并行度不超过头数。
- 在前馈部分,“列分片 + 行分片”效率更高
这里,通信效率仍然可能是瓶颈。单节点(8卡)内TP通信较快,TP度数越高,单卡吞吐下降越多,但是可以支持更大的batch size和模型规模。
内存节省方面,显著降低,可以让大模型在有限GPU上训练,但是后续操作(Dropout,LayerNorm)仍需全量激活,还可以优化!
Transformer的序列并行
根据上面的内容我们知道,sequence的内存增加是平方级别的。(因为对于一个长度为L的序列,在注意力机制中需要L^2个位置来存放注意力分数(对于每个token我们需要评估他和其他所有token的相关性)所以提出了序列并行。
核心思想是把序列分段。主要需要解决的难点是如何在序列被切分的情况下,高效地完成全局注意力计算。
在注意力层中,考虑序列分段到各个GPU上的场景,数据科学家们提出的做法如下:
- 计算局部的Q,K,V。
- All-Gather来获取全局的K和V。
- 局部的Q和全局的K,V计算得到局部序列注意力输出,然后把全局K,V drop掉。
- Reduce-Scatter把局部序列注意力相加,然后再次分割发回GPU。
而其他层(Dropout、LayerNorm),需要解决TP留下的”历史遗留“:LayerNorm和Dropout没有需要分割的部分,并且计算冗余。
完美的特性——计算是局部的:对序列中的第 i 个 Token 进行 LayerNorm,只需要这个 Token 自身的 h 维向量。它完全不需要序列中第 j 个 Token 的任何信息。这意味着: 一旦输入张量沿着序列维度被切分,每张 GPU 就可以在自己的数据分片上独立、完整地执行 LayerNorm 操作,而完全不需要和其他 GPU 进行任何通信。
效果:
局限性和优化:通信瓶颈仍然存在,实现逻辑较为复杂,部分层不适用。
Context Parallelism(上下文并行)
当上下文长度超长(128K + )时,Sequence Parallelism也无法处理Activation Value等的增加。所以引入新策略——Ring Attention
Ring Attention是一种高效的通信方式:每个GPU异步发送自己的key/value到下一个GPU,同时计算本地部分的注意力分数,循环进行,最终完成全序列的注意力计算。
这种方式虽然高效,但在因果注意力(causal attention)下,计算负载可能不均衡,需要进一步优化(如Zig-Zag Ring Attention)。
Zig-Zag Ring Attention 不是简单地顺序分配token到各GPU,而是将早期和晚期token交错分配,使每个GPU都能处理不同位置的token,从而均衡了各GPU的计算量。
这种分配方式让注意力掩码(attention mask)下的计算任务在所有GPU间分布更均匀,避免了某些GPU计算负载过重、某些过轻的问题。
处理流程:
- 序列切分:长度切分
- 局部计算:每个GPU接收到自己的子序列后,独立地、并行地计算该子序列对应的查询(Query)、键(Key)和值(Value)。
- 全局信息同步:为了让每个GPU都能计算完整的注意力(即每个Query都能注意到全部的Key),需要通信:all-gather,将自己计算出的局部KV Cache广播给所有其他GPU。执行完 all-gather 后,每个GPU上就都有了完整的、来自全部序列的KV Cache。
- 注意力计算:局部Query与全局的Key/Value进行注意力计算。