mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-11 08:58:34 +00:00
[tutorial] edited hands-on practices (#1899)
* Add handson to ColossalAI. * Change names of handsons and edit sequence parallel example. * Edit wrong folder name * resolve conflict * delete readme
This commit is contained in:
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
This is a Triton implementation of the Flash Attention algorithm
|
||||
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
|
||||
"""
|
||||
|
||||
import torch
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func
|
||||
except ImportError:
|
||||
raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')
|
||||
|
||||
|
||||
|
||||
def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len):
|
||||
"""
|
||||
Arguments:
|
||||
qkv: (batch*seq, 3, nheads, headdim)
|
||||
batch_size: int.
|
||||
seq_len: int.
|
||||
sm_scale: float. The scaling of QK^T before applying softmax.
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
"""
|
||||
max_s = seq_len
|
||||
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
|
||||
device=qkv.device)
|
||||
out = flash_attn_unpadded_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_s, 0.0,
|
||||
softmax_scale=sm_scale, causal=False
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen):
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch*seq, nheads, headdim)
|
||||
kv: (batch*seq, 2, nheads, headdim)
|
||||
batch_size: int.
|
||||
seq_len: int.
|
||||
sm_scale: float. The scaling of QK^T before applying softmax.
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
"""
|
||||
cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
|
||||
cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
|
||||
out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0, sm_scale)
|
||||
return out
|
Reference in New Issue
Block a user