
Flash Attention 是一种优化的注意力机制, 旨在提高深度学习模型中注意力计算的效率. 它通优化访存机制来加速训练和推理过程.
目前的GPU架构中, HBM 容量大但处理速度慢, SRAM 虽然容量小但操作速度快.
标准的注意力机制使用 HBM 来存储、读取和写入注意力分数矩阵(attention score matrix, 矩阵存储 Q/K/V). 具体步骤为将这些从 HBM 加载到 GPU 的片上 SRAM, 然后执行注意力机制的单个步骤, 然后写回 HBM, 并重复此过程.
而 Flash Attention 则是采用分块计算(Tiling)技术,将大型注意力矩阵划分为多个块(tile),在 SRAM 中逐块执行计算。通过:
总的来说, Flash Attention 是一种强大的工具, 能够在不牺牲性能的情况下提高模型的效率, 但在实现和使用时需要考虑其复杂性和硬件要求.

当使用 H100 显卡且序列长度是512时(数据来自论文测试),PyTorch 的标准处理速度是 62 Tflops,而 Flash Attention 则可以达到 157 Tflops,Flash Attention 2 则可以达到215 Tflops。在 FP16/BF16 精度下,实际加速比可达标准实现的 3-4 倍。