Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 03:52:01 +00:00
[Inference] Fused the gate and up proj in MLP, and optimized the autograd process. (#5365)
* fused the gate and up proj in mlp
* fix code styles
* opt auto_grad
* rollback test_inference_engine.py
* modifications based on the review feedback.
* fix bugs in flash attn
* Change reshape to view
* fix test_rmsnorm_triton.py
@@ -6,7 +6,6 @@ import torch.nn.functional as F
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
 
-@torch.no_grad
 def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
     """
     Func: copy key/value into key/value cache.
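For orientation, a minimal sketch of what a paged-KV-cache copy like copy_to_cache involves for a single decode step. The cache layout [num_blocks, num_kv_heads, block_size, head_size] and the decode-only restriction are assumptions for illustration, not necessarily what this file implements:

import torch

def copy_to_cache_sketch(source, cache, lengths, block_tables):
    # Assumed shapes (illustrative only):
    #   source:       [bsz, 1, num_kv_heads, head_size]  -- the newly produced K or V token
    #   cache:        [num_blocks, num_kv_heads, block_size, head_size]
    #   lengths:      [bsz] sequence lengths *including* the new token
    #   block_tables: [bsz, max_blocks_per_seq] physical block ids per sequence
    block_size = cache.size(2)
    slot = lengths - 1                                   # logical position of the new token
    rows = torch.arange(len(lengths), device=cache.device)
    block_idx = block_tables[rows, slot // block_size]   # physical block holding that position
    offset = slot % block_size                           # offset inside the block
    cache[block_idx, :, offset, :] = source.squeeze(1)   # scatter one token per sequence
    return cache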
@@ -41,7 +40,6 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
     return cache
 
 
-@torch.no_grad
 def convert_kvcache(cache, lengths, block_tables, pad_id=0):
     """
     Func: convert key/value cache for calculation
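A hedged sketch of the inverse direction, gathering each sequence's cached tokens back into a padded tensor; the same assumed block layout as above applies:

import torch

def convert_kvcache_sketch(cache, lengths, block_tables, pad_id=0):
    # Gather blocks back into a padded [bsz, max_len, num_kv_heads, head_size] tensor.
    bsz = len(lengths)
    num_kv_heads, block_size, head_size = cache.shape[1:]
    max_len = int(lengths.max())
    out = cache.new_full((bsz, max_len, num_kv_heads, head_size), pad_id)
    for i in range(bsz):
        seq_len = int(lengths[i])
        num_blocks = (seq_len + block_size - 1) // block_size
        # [num_blocks, num_kv_heads, block_size, head_size] -> [num_blocks * block_size, num_kv_heads, head_size]
        blocks = cache[block_tables[i, :num_blocks]].permute(0, 2, 1, 3)
        out[i, :seq_len] = blocks.reshape(-1, num_kv_heads, head_size)[:seq_len]
    return out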
@@ -81,7 +79,6 @@ class PagedAttention:
     """
 
     @staticmethod
-    @torch.no_grad
     def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
         """
         Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
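The docstring above describes the shape transform only; a minimal sketch, assuming the packed input is [total_tokens, num_heads * head_size]:

import torch

def pad_and_reshape_sketch(tensor, seq_lengths, max_seq_len, num_heads, head_size):
    # Packed [total_tokens, num_heads * head_size] -> padded [bsz, max_seq_len, num_heads, head_size]
    bsz = len(seq_lengths)
    padded = tensor.new_zeros(bsz, max_seq_len, num_heads, head_size)
    offset = 0
    for i, seq_len in enumerate(seq_lengths):
        seq_len = int(seq_len)
        chunk = tensor[offset : offset + seq_len]          # this sequence's tokens
        padded[i, :seq_len] = chunk.view(seq_len, num_heads, head_size)
        offset += seq_len
    return padded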
@@ -97,14 +94,12 @@ class PagedAttention:
         return padded_tensor
 
     @staticmethod
-    @torch.no_grad
     def generate_padding_mask(lengths, max_seq_len):
         range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
         padding_mask = range_tensor < lengths.unsqueeze(1)
         return padding_mask
 
     @staticmethod
-    @torch.no_grad
     def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
         """
         Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
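generate_padding_mask is shown in full above; repeat_kv only by its docstring. The standard expand-and-reshape equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep) looks like this (a sketch of the common pattern, not necessarily the exact body used here):

import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
    # [batch, num_kv_heads, seq_len, head_dim] -> [batch, num_kv_heads * n_rep, seq_len, head_dim]
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)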
@@ -122,7 +117,6 @@ class PagedAttention:
         return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)
 
     @staticmethod
-    @torch.no_grad
     def nopad_context_forward(
         q: torch.Tensor,  # [num_tokens, num_heads, head_size]
         k: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
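As a reference point for the un-padded (packed) prefill path, a naive per-sequence sketch using PyTorch 2.x scaled_dot_product_attention; the real method presumably fuses this, and the v/context_lengths arguments are assumptions:

import torch
import torch.nn.functional as F

def nopad_context_forward_sketch(q, k, v, context_lengths):
    # q: [num_tokens, num_heads, head_size]; k, v: [num_tokens, num_kv_heads, head_size]
    n_rep = q.size(1) // k.size(1)
    outputs, offset = [], 0
    for seq_len in context_lengths.tolist():
        qi = q[offset : offset + seq_len].transpose(0, 1)        # [num_heads, seq_len, head_size]
        ki = k[offset : offset + seq_len].transpose(0, 1).repeat_interleave(n_rep, dim=0)
        vi = v[offset : offset + seq_len].transpose(0, 1).repeat_interleave(n_rep, dim=0)
        out = F.scaled_dot_product_attention(qi, ki, vi, is_causal=True)
        outputs.append(out.transpose(0, 1))                      # back to [seq_len, num_heads, head_size]
        offset += seq_len
    return torch.cat(outputs, dim=0)                             # [num_tokens, num_heads, head_size]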
@@ -191,7 +185,6 @@ class PagedAttention:
         return attn_output
 
     @staticmethod
-    @torch.no_grad
     def pad_context_forward(
         q: torch.Tensor,  # [batch_size, seq_len, num_heads, head_size]
         k: torch.Tensor,  # [batch_size, seq_len, num_kv_heads, head_size]
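The padded prefill variant can avoid the per-sequence loop by combining a causal mask with a padding mask; a hedged sketch under the same assumptions (argument names beyond q/k are not from this file):

import torch
import torch.nn.functional as F

def pad_context_forward_sketch(q, k, v, context_lengths):
    # q: [bsz, seq_len, num_heads, head_size]; k, v: [bsz, seq_len, num_kv_heads, head_size]
    seq_len = q.size(1)
    n_rep = q.size(2) // k.size(2)
    q = q.transpose(1, 2)                                        # [bsz, num_heads, seq_len, head_size]
    k = k.transpose(1, 2).repeat_interleave(n_rep, dim=1)
    v = v.transpose(1, 2).repeat_interleave(n_rep, dim=1)
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).tril()
    valid = torch.arange(seq_len, device=q.device)[None, :] < context_lengths[:, None]
    mask = causal[None, None] & valid[:, None, None, :]          # [bsz, 1, seq_len, seq_len]
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2)                                   # [bsz, seq_len, num_heads, head_size]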
@@ -249,7 +242,6 @@ class PagedAttention:
         return attn_output
 
     @staticmethod
-    @torch.no_grad
     def pad_decoding_forward(
         q: torch.Tensor,  # [bsz, 1, num_heads, head_size]
         k: torch.Tensor,  # [bsz, 1, num_kv_heads, head_size]
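And for the decoding step, one new query token attends to the whole padded cache; a sketch under the same assumptions, with the padded cache produced by a convert_kvcache-style gather:

import torch
import torch.nn.functional as F

def pad_decoding_forward_sketch(q, k_cache, v_cache, lengths):
    # q: [bsz, 1, num_heads, head_size]; k_cache, v_cache: [bsz, max_len, num_kv_heads, head_size]
    max_len = k_cache.size(1)
    n_rep = q.size(2) // k_cache.size(2)
    q = q.transpose(1, 2)                                        # [bsz, num_heads, 1, head_size]
    k = k_cache.transpose(1, 2).repeat_interleave(n_rep, dim=1)  # [bsz, num_heads, max_len, head_size]
    v = v_cache.transpose(1, 2).repeat_interleave(n_rep, dim=1)
    valid = torch.arange(max_len, device=q.device)[None, :] < lengths[:, None]
    mask = valid[:, None, None, :]                               # keys beyond each length are masked out
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2)                                   # [bsz, 1, num_heads, head_size]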
@@ -306,7 +298,6 @@ class PagedAttention:
         return attn_output
 
     @staticmethod
-    @torch.no_grad
     def no_pad_decoding_forward(
         self,
         q: torch.Tensor,  # [num_tokens, num_heads, head_size]