完全分片数据并行 (Fully Sharded Data Parallelism,FSDP) 是一种训练范式,在该范式中优化器状态、梯度和模型参数都会被跨设备分片。前向传播时,每个 FSDP 单元执行
all gather 以获取完整的权重,然后用它们进行计算并在计算后丢弃掉其他设备的分片。随后是反向传播,然后就是损失计算。反向传播时,每个 FSDP 单元执行
all gather 操作以获取完整的权重,并执行计算以获得本地 batch 的梯度。这些梯度通过
reduce scatter 在设备上进行均值计算并分片,这样每个设备都可以更新其对应分片的参数。有关 PyTorch FSDP 的更多信息,请参阅此博文:
使用 PyTorch 完全分片数据并行技术加速大模型训练