介绍

FlashAttention V1 是 Transformer 领域里一个非常重要的优化工作,由 Tri Dao 等人在 2022 年提出。它的核心目标是:

不改变 Attention 数学结果(Exact Attention)的前提下,大幅降低显存访问(IO)开销。

为什么需要FlashAttention

在传统 Attention 的计算流程中,需要先生成一个大小为 N×N(N 为序列长度)的巨大注意力矩阵。这个矩阵是计算瓶颈,因为它太大,无法完全放入 GPU 的高速缓存(SRAM)中,只能反复在较慢的全局显存(HBM)和 SRAM 之间搬运,导致大量时间浪费在等待数据上。

因此传统 Attention 的瓶颈很多时候不是算力(FLOPs),而是:

  • GPU HBM(显存)读写太多

  • 中间矩阵太大

  • attention score 矩阵反复搬运

传统的Attention计算如下:

上图中一共包含八次HBM的矩阵读写操作。这八次读写操作分别为:

  • 第一行对Q,KQ,K的读取共两次,对SS的写入一次,读写总共三次;

  • 第二行对SS读取一次,对PP写入一次,读写总共两次;

  • 第一行对P,VP,V的读取共两次,对VV的写入一次,读写总共三次。

FlashAttention 的核心原理

为解决上述问题,FlashAttention 采用了两个关键技术:

  1. 分块计算:不一次性计算整个注意力矩阵,而是将 Q、K、V 矩阵切分成许多小块。每次只把一个小块加载到高速的 SRAM 中,计算完这部分结果后,再加载下一块。这样就避免了在 HBM 中生成巨大的 N×N 矩阵。

  2. 重计算:在反向传播(即模型训练中计算梯度的过程)时,很多中间结果本需要保存下来。FlashAttention 选择不保存这些占用大量显存的中间矩阵,而是在需要时,利用高速 SRAM 重新快速计算一遍。这虽然增加了一点计算量,但大幅降低了显存占用和读写开销。

Online Softmax

我这里应该先从online softmax讲起,但我懒得写了,看看其他人的文章吧~

FlashAttention V1

Flash Attention V2

下面是FA V2的伪代码: