mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-31 16:40:41 +00:00
[triton] added copyright information for flash attention (#2835)
* [triton] added copyright information for flash attention * polish code
This commit is contained in:
@@ -1,8 +1,12 @@
|
||||
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)

The triton-based flash attention implementation is copied from the OpenAI/triton repository.

You can find the repository in Triton https://github.com/openai/triton
You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py

Reference:
1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf
"""
|
||||
import math
@@ -56,7 +60,8 @@ except ImportError:
    print('please install xformers from https://github.com/facebookresearch/xformers')


if HAS_TRITON:
    # the following functions are adapted from the OpenAI Triton tutorial
    # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
    @triton.jit
    def _fwd_kernel(
        Q,
||||
|
Reference in New Issue
Block a user