[Inference/Kernel] Add get_cos_and_sin Kernel (#5528)

* Add get_cos_and_sin kernel

* Fix code comments

* Fix code typos

* Merge the common code of the get_cos_and_sin kernel

* Fix a typo

* Change 'assert allclose' to 'assert equal'

Author: yuehuayingxueluo
Date: 2024-04-01 13:47:14 +08:00 (committed by GitHub)
Commit: 04aca9e55b (parent: 934e31afb2)
5 changed files with 295 additions and 6 deletions
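
For orientation: inference_ops.get_cos_and_sin gathers the per-token RoPE cos/sin rows from the cached tables into preallocated output buffers in a single fused kernel, replacing the Python-side get_xine_cache gather on the CUDA path. The following is a minimal PyTorch sketch of the semantics implied by the call site in the diff below; the argument order and the prefill/decode split are read off that call rather than the kernel's documentation, and get_cos_and_sin_ref is a hypothetical name.

    import torch


    def get_cos_and_sin_ref(cos_cache, sin_cache, cos_out, sin_out,
                            sequence_lengths, max_seq_len, is_prompts):
        """Assumed reference semantics of inference_ops.get_cos_and_sin."""
        assert int(sequence_lengths.max()) <= max_seq_len
        if is_prompts:
            # Prefill: copy the rows for positions 0..len-1 of every sequence,
            # packed back-to-back to match the flattened hidden_states layout
            # (hence cos_out/sin_out of shape [sum(lengths), hidden_dim]).
            offset = 0
            for seq_len in sequence_lengths.tolist():
                cos_out[offset : offset + seq_len] = cos_cache[:seq_len]
                sin_out[offset : offset + seq_len] = sin_cache[:seq_len]
                offset += seq_len
        else:
            # Decode: each sequence contributes one token, at position len - 1.
            positions = (sequence_lengths - 1).long()
            cos_out.copy_(cos_cache[positions])
            sin_out.copy_(sin_cache[positions])

Fusing this gather into one kernel launch is why the caller below preallocates cos and sin with torch.empty and passes them in, rather than materializing intermediate tensors per sequence in Python.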

@@ -101,12 +101,22 @@ def llama_model_forward(
         use_cuda_kernel = False
     hidden_states = self.embed_tokens(input_tokens_ids)
 
-    if use_cuda_kernel and inputmetadata.dtype != torch.float32 and use_flash_attn2:
-        cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
+    if use_cuda_kernel:
+        if inputmetadata.dtype != torch.float32 and use_flash_attn2:
+            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
+
+        hidden_dim = self._cos_cached.size(-1)
+        total_length = hidden_states.size(0)
+        cos = torch.empty((total_length, hidden_dim), dtype=self._cos_cached.dtype, device=self._cos_cached.device)
+        sin = torch.empty((total_length, hidden_dim), dtype=self._sin_cached.dtype, device=self._sin_cached.device)
+        inference_ops.get_cos_and_sin(
+            self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
+        )
+        cos_sin = (cos, sin)
     else:
         cu_seqlens = None
-
-    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
+        cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
 
     sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
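
As a hedged usage sketch, the new call pattern can be exercised end to end and cross-checked against the get_xine_cache fallback, echoing the 'assert equal' commit message: the kernel only copies cached rows, with no floating-point arithmetic, so exact equality rather than allclose is the right check. The import paths below and the use of the batch's maximum length as kv_seq_len are assumptions, and this is a sketch, not the repository's actual test.

    import torch
    from colossalai.kernel.kernel_loader import InferenceOpsLoader  # assumed loader path
    from colossalai.kernel.triton import get_xine_cache  # assumed import path

    inference_ops = InferenceOpsLoader().load()

    max_seq_len, head_dim = 64, 32
    cos_cached = torch.randn(max_seq_len, head_dim, dtype=torch.float16, device="cuda")
    sin_cached = torch.randn(max_seq_len, head_dim, dtype=torch.float16, device="cuda")
    sequence_lengths = torch.tensor([5, 9, 3], dtype=torch.int32, device="cuda")

    # Prefill: one cos/sin row per packed token position across all sequences.
    total_length = int(sequence_lengths.sum())
    cos = torch.empty(total_length, head_dim, dtype=cos_cached.dtype, device="cuda")
    sin = torch.empty(total_length, head_dim, dtype=sin_cached.dtype, device="cuda")
    inference_ops.get_cos_and_sin(
        cos_cached, sin_cached, cos, sin,
        sequence_lengths, int(sequence_lengths.max()), True,  # kv_seq_len assumed = max length
    )

    # Pure gather, so outputs should match the Triton fallback bit for bit.
    ref_cos, ref_sin = get_xine_cache(sequence_lengths, cos_cached, sin_cached, True)
    assert torch.equal(cos, ref_cos) and torch.equal(sin, ref_sin)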