Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference/Kernel] Add get_cos_and_sin Kernel (#5528)
* Add get_cos_and_sin kernel
* Fix code comments
* Fix code typos
* Merge common code of the get_cos_and_sin kernel
* Fix a typo
* Change 'assert allclose' to 'assert equal'
@@ -101,12 +101,22 @@ def llama_model_forward(
     use_cuda_kernel = False
 
     hidden_states = self.embed_tokens(input_tokens_ids)
 
-    if use_cuda_kernel and inputmetadata.dtype != torch.float32 and use_flash_attn2:
-        cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
+    if use_cuda_kernel:
+        if inputmetadata.dtype != torch.float32 and use_flash_attn2:
+            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
+
+        hidden_dim = self._cos_cached.size(-1)
+        total_length = hidden_states.size(0)
+        cos = torch.empty((total_length, hidden_dim), dtype=self._cos_cached.dtype, device=self._cos_cached.device)
+        sin = torch.empty((total_length, hidden_dim), dtype=self._sin_cached.dtype, device=self._sin_cached.device)
+        inference_ops.get_cos_and_sin(
+            self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
+        )
+        cos_sin = (cos, sin)
     else:
         cu_seqlens = None
-
-    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
+        cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
 
     sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
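For readers without the kernel source at hand, here is a minimal pure-PyTorch sketch of what the call site above suggests get_cos_and_sin computes: one row of the cached cos/sin tables per token, written into the preallocated (total_length, hidden_dim) buffers. The name get_cos_and_sin_ref and the prefill/decode indexing are illustrative assumptions inferred from the diff, not code from the commit.

import torch

def get_cos_and_sin_ref(cos_cache, sin_cache, lengths, is_prompts):
    # Reference gather matching the kernel's apparent semantics;
    # cos_cache/sin_cache are (max_positions, hidden_dim) lookup tables.
    if is_prompts:
        # Prefill: positions 0..len-1 of every sequence, concatenated,
        # so the output has sum(lengths) rows (= hidden_states.size(0)).
        idx = torch.cat([torch.arange(n, device=lengths.device) for n in lengths.tolist()])
    else:
        # Decode: only the newest position (len - 1) of each sequence,
        # i.e. one row per sequence in the batch.
        idx = (lengths - 1).long()
    return cos_cache.index_select(0, idx), sin_cache.index_select(0, idx)

# Example: two prompt sequences of lengths 3 and 5.
cos_cache, sin_cache = torch.randn(64, 128), torch.randn(64, 128)
lengths = torch.tensor([3, 5])
cos, sin = get_cos_and_sin_ref(cos_cache, sin_cache, lengths, is_prompts=True)
assert cos.shape == (8, 128)  # rows 0..2, then rows 0..4, of the cache

Doing this gather in a single CUDA launch, rather than the per-sequence Python slicing inside get_xine_cache, is the point of the new use_cuda_kernel branch.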
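The commit note about changing 'assert allclose' to 'assert equal' follows from the kernel being a pure row copy, so bit-exact equality against a reference gather is the right check. Below is a hedged test sketch; the colossalai.kernel.kernel_loader.InferenceOpsLoader import path and the positional argument order (copied from the call site in the diff) are assumptions, so align them with the repo's actual test before use.

import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader  # assumed loader path

inference_ops = InferenceOpsLoader().load()

max_pos, hidden_dim = 64, 128
cos_cache = torch.randn(max_pos, hidden_dim, dtype=torch.float16, device="cuda")
sin_cache = torch.randn(max_pos, hidden_dim, dtype=torch.float16, device="cuda")
lengths = torch.tensor([3, 5], dtype=torch.int32, device="cuda")

total_length = int(lengths.sum())  # prefill: one row per token
cos = torch.empty(total_length, hidden_dim, dtype=cos_cache.dtype, device=cos_cache.device)
sin = torch.empty_like(cos)

# Argument order mirrors the call site in the diff:
# (cos_cache, sin_cache, cos_out, sin_out, lengths, max_seq_len, is_prompts)
inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, int(lengths.max()), True)

# Compare against the reference gather from the sketch above; a row copy
# should match bit-for-bit, hence torch.equal rather than torch.allclose.
ref_cos, ref_sin = get_cos_and_sin_ref(cos_cache, sin_cache, lengths, is_prompts=True)
assert torch.equal(cos, ref_cos) and torch.equal(sin, ref_sin)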