Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 03:20:52 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,9 +1,8 @@
 import torch
 from torch import nn

 try:
     import triton
     import triton.language as tl

     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
@@ -13,9 +12,10 @@ if HAS_TRITON:
     from .qkv_matmul_kernel import qkv_gemm_4d_kernel
     from .softmax import softmax_kernel

-    def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                                              input_mask: torch.Tensor, scale: float):
-        r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
+    def self_attention_forward_without_fusion(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float
+    ):
+        r"""A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
         Args:
             q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
             k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
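For orientation, here is a plain-PyTorch sketch of the computation that qkv_gemm_4d_kernel and softmax_kernel are composed to perform; the function name and the additive-mask handling are illustrative assumptions, not taken from the kernel code.

import torch

def attention_reference(q, k, v, input_mask, scale):
    # Shapes follow the docstring above: (batch, seq_len, num_heads, head_size).
    q = q.transpose(1, 2)                                  # (batch, num_heads, seq_len, head_size)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale  # (batch, num_heads, seq_len, seq_len)
    if input_mask is not None:
        scores = scores + input_mask                       # assumed additive mask
    probs = torch.softmax(scores, dim=-1)
    out = torch.matmul(probs, v)                           # (batch, num_heads, seq_len, head_size)
    return out.transpose(1, 2).contiguous()                # back to (batch, seq_len, num_heads, head_size)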
@@ -65,7 +65,7 @@ if HAS_TRITON:
             score_output.stride(2),
             score_output.stride(3),
             scale=scale,
-            # currently manually setting, later on we can use auto-tune config to match best setting
+            # currently manually setting, later on we can use auto-tune config to match best setting
             BLOCK_SIZE_M=64,
             BLOCK_SIZE_N=32,
             BLOCK_SIZE_K=32,
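The comment in this hunk notes that the BLOCK_SIZE_* values are hard-coded for now and could later come from auto-tuning. As a rough sketch of what that would look like, triton's autotuner can pick the block size and warp count per problem size; the example below uses a deliberately trivial copy kernel rather than the 4D GEMM, and all names are illustrative.

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["n_elements"],  # re-tune when the problem size changes
)
@triton.jit
def copy_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)

def copy(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    n_elements = x.numel()
    # BLOCK_SIZE / num_warps are not passed at launch; the autotuner supplies them.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    copy_kernel[grid](x, y, n_elements)
    return y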
@@ -79,7 +79,6 @@ if HAS_TRITON:
         n_rows, n_cols = score_output.shape

         if n_rows <= 350000:
-
             block_size = max(triton.next_power_of_2(n_cols), 2)
             num_warps = 4
             if block_size >= 4096:
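This hunk is cut off after the block_size >= 4096 branch. The visible lines follow the usual triton softmax launch heuristic: pad the row width to the next power of two and scale the warp count with it. A minimal sketch of that heuristic is below; the 2048 threshold and the 8/16 warp values are assumptions, only num_warps = 4 and the 4096 comparison appear in the diff.

import triton

def softmax_launch_config(n_cols: int):
    # One triton program handles a full row, so pad the row width to a power of two.
    block_size = max(triton.next_power_of_2(n_cols), 2)
    num_warps = 4           # value visible in the diff
    if block_size >= 2048:  # threshold assumed; not visible in the hunk
        num_warps = 8
    if block_size >= 4096:  # the branch the hunk breaks off at
        num_warps = 16      # value assumed; not visible in the hunk
    return block_size, num_warps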
@@ -142,15 +141,9 @@ if HAS_TRITON:
         )
         return output.view(batches, -1, d_model)

-    def self_attention_compute_using_triton(qkv,
-                                            input_mask,
-                                            layer_past,
-                                            alibi,
-                                            scale,
-                                            head_size,
-                                            triangular=False,
-                                            use_flash=False):
-
+    def self_attention_compute_using_triton(
+        qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False
+    ):
         assert qkv.is_contiguous()
         assert alibi is None, "current triton self-attention does not support alibi"
         batches = qkv.shape[0]
@@ -158,8 +151,8 @@ if HAS_TRITON:
         num_of_heads = d_model // head_size

         q = qkv[:, :, :d_model]
-        k = qkv[:, :, d_model:d_model * 2]
-        v = qkv[:, :, d_model * 2:]
+        k = qkv[:, :, d_model : d_model * 2]
+        v = qkv[:, :, d_model * 2 :]
         q = q.view(batches, -1, num_of_heads, head_size)
         k = k.view(batches, -1, num_of_heads, head_size)
         v = v.view(batches, -1, num_of_heads, head_size)
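The slicing in this last hunk splits a fused qkv projection into its three parts before reshaping to per-head layout. A standalone sketch of just that step follows; the helper name is illustrative, and layer_past, triangular and use_flash handling are omitted.

import torch

def split_fused_qkv(qkv: torch.Tensor, head_size: int):
    # qkv: (batch, seq_len, 3 * d_model), sliced exactly as in the hunk above.
    batches, _, fused_dim = qkv.shape
    d_model = fused_dim // 3
    num_of_heads = d_model // head_size
    q = qkv[:, :, :d_model]
    k = qkv[:, :, d_model : d_model * 2]
    v = qkv[:, :, d_model * 2 :]
    q = q.view(batches, -1, num_of_heads, head_size)
    k = k.view(batches, -1, num_of_heads, head_size)
    v = v.view(batches, -1, num_of_heads, head_size)
    return q, k, v

For example, split_fused_qkv(torch.randn(2, 128, 3 * 512), head_size=64) returns three tensors of shape (2, 128, 8, 64).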