[Inference]Fused the gate and up proj in mlp,and optimized the autograd process. (#5365)

* fused the gate and up proj in mlp

* fix code styles

* opt auto_grad

* rollback test_inference_engine.py

* modifications based on the review feedback.

* fix bugs in flash attn

* Change reshape to view

* fix test_rmsnorm_triton.py
This commit is contained in:
yuehuayingxueluo
2024-02-06 19:38:25 +08:00
committed by GitHub
parent 1dedb57747
commit 35382a7fbf
10 changed files with 484 additions and 50 deletions

View File

@@ -6,7 +6,6 @@ import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@torch.no_grad
def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
"""
Func: copy key/value into key/value cache.
@@ -41,7 +40,6 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
return cache
@torch.no_grad
def convert_kvcache(cache, lengths, block_tables, pad_id=0):
"""
Func: convert key/value cache for calculation
@@ -81,7 +79,6 @@ class PagedAttention:
"""
@staticmethod
@torch.no_grad
def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
"""
Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
@@ -97,14 +94,12 @@ class PagedAttention:
return padded_tensor
@staticmethod
@torch.no_grad
def generate_padding_mask(lengths, max_seq_len):
range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
padding_mask = range_tensor < lengths.unsqueeze(1)
return padding_mask
@staticmethod
@torch.no_grad
def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
"""
Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
@@ -122,7 +117,6 @@ class PagedAttention:
return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)
@staticmethod
@torch.no_grad
def nopad_context_forward(
q: torch.Tensor, # [num_tokens, num_heads, head_size]
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
@@ -191,7 +185,6 @@ class PagedAttention:
return attn_output
@staticmethod
@torch.no_grad
def pad_context_forward(
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
@@ -249,7 +242,6 @@ class PagedAttention:
return attn_output
@staticmethod
@torch.no_grad
def pad_decoding_forward(
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
@@ -306,7 +298,6 @@ class PagedAttention:
return attn_output
@staticmethod
@torch.no_grad
def no_pad_decoding_forward(
self,
q: torch.Tensor, # [num_tokens, num_heads, head_size]