[Fix] resolve conflicts of rebasing feat/speculative-decoding (#5557)

- resolve conflicts of rebasing feat/speculative-decoding
Author: Yuanheng Zhao
Date: 2024-04-07 14:53:30 +08:00
Committed by: ocd_with_naming
Parent: e1acb58423
Commit: e60d430cf5

6 changed files with 47 additions and 35 deletions


@@ -109,13 +109,11 @@ def llama_model_forward(
    # For speculative-decoding Prefill and Verifying Stage
    if inputmetadata.is_prompts:
        # output tensor shape is the same as in the normal Prefill Stage
        o_tensor_size = (sequence_lengths.sum().item(), inputmetadata.num_heads * inputmetadata.head_dim)
        rotary_indexes = [torch.arange(0, length) for length in sequence_lengths]
    else:
        # the number of tokens to be verified in parallel, plus the correct token from the last step
        n_tokens = inputmetadata.num_tokens_to_verify + 1
        assert n_tokens == hidden_states.size(0)
        o_tensor_size = (batch_size * n_tokens, inputmetadata.num_heads * inputmetadata.head_dim)
        rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
    rotary_indexes = torch.cat(rotary_indexes, dim=-1)
    cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
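
To make the verify-stage index math concrete, here is a minimal standalone sketch. The values are toy inputs; only the indexing pattern mirrors the hunk above:

    import torch

    # Toy values; only the indexing pattern mirrors the hunk above.
    sequence_lengths = torch.tensor([7, 9])   # current lengths of two sequences
    n_tokens = 3                              # num_tokens_to_verify + 1

    # For each offset i, take position (length - n_tokens + i) of every sequence:
    # the last n_tokens positions, concatenated in the same order as the
    # list comprehension in the diff (offset-major, then batch).
    rotary_indexes = torch.cat(
        [(sequence_lengths - n_tokens + i).view(-1) for i in range(n_tokens)], dim=-1
    )
    print(rotary_indexes)                     # tensor([4, 6, 5, 7, 6, 8])

    # A toy cos cache indexed by absolute position is then gathered row-wise,
    # standing in for self._cos_cached[rotary_indexes]:
    cos_cached = torch.randn(16, 4)           # (max_position, rot_dim)
    cos = cos_cached[rotary_indexes]          # (batch_size * n_tokens, rot_dim)

In other words, each sequence contributes its last n_tokens positions, so the gathered cos/sin rows cover exactly the draft tokens being verified plus the token carried over from the previous step.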
@@ -135,15 +133,6 @@ def llama_model_forward(
    else:
        cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
    # TODO (yuanheng-zhao): revise the logic here
    # if batch.is_prompts:
    #     output_tensor = torch.zeros(
    #         (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
    #     )
    # else:
    #     output_tensor = torch.zeros(
    #         (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
    #     )
    sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
    norm_output = torch.empty_like(hidden_states)
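
For context, sm_scale above is the standard 1/sqrt(head_dim) softmax scaling from scaled dot-product attention. A minimal sketch of where it enters the computation, with toy shapes; the repo dispatches to fused attention kernels rather than this naive form:

    import torch

    # Toy shapes; the repo uses fused attention kernels rather than this naive form.
    head_dim = 64
    q = torch.randn(1, 8, 1, head_dim)        # (batch, heads, q_len, head_dim)
    k = torch.randn(1, 8, 16, head_dim)       # (batch, heads, kv_len, head_dim)
    v = torch.randn(1, 8, 16, head_dim)

    sm_scale = 1.0 / (head_dim**0.5)          # same scale as in the hunk above
    scores = (q @ k.transpose(-1, -2)) * sm_scale
    out = torch.softmax(scores, dim=-1) @ v   # (1, 8, 1, head_dim)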
@@ -239,7 +228,6 @@ def llama_decoder_layer_forward(
        sequence_lengths=sequence_lengths,
        cos_sin=cos_sin,
        fd_inter_tensor=fd_inter_tensor,
        is_prompts=is_prompts,
        kv_seq_len=kv_seq_len,
        output_tensor=output_tensor,
        sm_scale=sm_scale,
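
Stepping back, the n_tokens = num_tokens_to_verify + 1 bookkeeping exists because the target model scores all draft tokens plus the last accepted token in one parallel forward pass. A schematic greedy-verification sketch of that idea, with all names hypothetical and no relation to the repo's actual API:

    import torch

    def greedy_verify(target_logits: torch.Tensor, draft_tokens: torch.Tensor) -> torch.Tensor:
        """target_logits: (n_tokens, vocab_size) from one parallel forward pass;
        draft_tokens: (n_tokens - 1,) tokens proposed by the draft model."""
        predicted = target_logits.argmax(dim=-1)                 # target model's greedy picks
        # Position i of the pass predicts draft position i, so a draft token is
        # kept only if it matches and every earlier draft token matched too.
        keep = (predicted[:-1] == draft_tokens).long().cumprod(dim=0)
        n_accepted = int(keep.sum().item())
        # Append the target model's own prediction right after the accepted prefix.
        return torch.cat([draft_tokens[:n_accepted], predicted[n_accepted : n_accepted + 1]])

    # Toy usage: vocab of 10, verify 3 draft tokens (n_tokens = 4).
    logits = torch.randn(4, 10)
    drafts = torch.randint(0, 10, (3,))
    print(greedy_verify(logits, drafts))

Under this scheme every verify step emits between 1 and n_tokens tokens, which is why the output tensor in the first hunk is sized (batch_size * n_tokens, num_heads * head_dim) rather than per-sequence.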