[triton] added copyright information for flash attention (#2835)

* [triton] added copyright information for flash attention

* polish code
Frank Lee
2023-02-21 11:25:57 +08:00
committed by GitHub
parent 7ea6bc7f69
commit 918bc94b6b
2 changed files with 27 additions and 8 deletions


@@ -1,8 +1,12 @@
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
The Triton-based flash attention implementation is copied from the OpenAI Triton repository (https://github.com/openai/triton).
The original source file is https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
Reference:
1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf
"""
import math
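For reference, the computation that the fused Triton kernel implements is ordinary scaled dot-product attention, softmax(Q Kᵀ / √d) V. A minimal unfused NumPy sketch (hypothetical helper, not part of this commit) shows the baseline the kernel accelerates:

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference (unfused) attention: softmax(Q K^T / sqrt(d)) V.

    Materializes the full (N, N) score matrix, which is exactly the
    memory cost that Flash Attention avoids.
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                  # (N, N) attention scores
    scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)             # rows now sum to 1
    return p @ v                                   # (N, d) output

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
out = naive_attention(q, k, v)
```

This version is fine for small N but allocates O(N²) memory; the fused kernel below computes the same result tile by tile.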
@@ -56,7 +60,8 @@ except ImportError:
print('please install xformers from https://github.com/facebookresearch/xformers')
if HAS_TRITON:
# the following functions are adapted from the OpenAI Triton tutorial
# https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
@triton.jit
def _fwd_kernel(
Q,
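The `_fwd_kernel` signature is truncated here, but the core idea it implements (per Dao et al.) is blockwise attention with an online softmax: K/V are visited one tile at a time while only a running row-max, normalizer, and output accumulator are kept, so the full (N, N) score matrix is never materialized. A hedged NumPy sketch of that accumulation scheme (an illustration of the algorithm, not the Triton kernel itself):

```python
import numpy as np

def flash_attention_blocked(q, k, v, block=4):
    """Blockwise attention with online softmax (Flash Attention style).

    State per query row: running max m, softmax denominator l, and an
    unnormalized output accumulator acc; all are rescaled whenever a new
    tile raises the running max.
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(n, -np.inf)        # running row-wise max of scores
    l = np.zeros(n)                # running softmax denominator
    acc = np.zeros((n, d))         # unnormalized output accumulator
    for start in range(0, k.shape[0], block):
        kj = k[start:start + block]
        vj = v[start:start + block]
        s = (q @ kj.T) * scale                 # scores for this K/V tile
        m_new = np.maximum(m, s.max(axis=1))   # updated running max
        alpha = np.exp(m - m_new)              # rescale factor for old state
        p = np.exp(s - m_new[:, None])         # tile probabilities (unnormalized)
        l = l * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ vj
        m = m_new
    return acc / l[:, None]                    # normalize once at the end
```

Under these assumptions the blocked result matches a full-matrix softmax attention exactly, which is the invariant the fused kernel relies on.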