[Fix] resolve conflicts of rebasing feat/speculative-decoding (#5557)
- resolve conflicts of rebasing feat/speculative-decoding
Committed by: ocd_with_naming
Parent: e1acb58423
Commit: e60d430cf5
@@ -109,13 +109,11 @@ def llama_model_forward(
         # For speculative-decoding Prefill and Verifying Stage
         if inputmetadata.is_prompts:
             # output tensor shape is the same as normal Prefill Stage
             o_tensor_size = (sequence_lengths.sum().item(), inputmetadata.num_heads * inputmetadata.head_dim)
             rotary_indexes = [torch.arange(0, length) for length in sequence_lengths]
         else:
             # the number of tokens to be verified in parallel plus the correct token in the last step
             n_tokens = inputmetadata.num_tokens_to_verify + 1
             assert n_tokens == hidden_states.size(0)
             o_tensor_size = (batch_size * n_tokens, inputmetadata.num_heads * inputmetadata.head_dim)
             rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
         rotary_indexes = torch.cat(rotary_indexes, dim=-1)
         cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
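For intuition, the two rotary_indexes branches above differ only in which position ids they gather: prefill covers every position of every sequence, while the verifying stage covers only the last n_tokens positions of each sequence, grouped by verification step. A minimal standalone sketch (toy sequence_lengths and n_tokens; the inputmetadata object from the diff is assumed away):

    import torch

    # Toy batch: sequences of lengths 5 and 7; verify 2 draft tokens + 1 accepted token.
    sequence_lengths = torch.tensor([5, 7])
    n_tokens = 3  # num_tokens_to_verify + 1 in the diff

    # Prefill stage: rotary positions 0 .. length-1 for every sequence.
    prefill = torch.cat([torch.arange(0, length) for length in sequence_lengths], dim=-1)
    print(prefill)  # tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6])

    # Verifying stage: the last n_tokens positions of each sequence, grouped
    # by step i, matching the token layout the verification pass expects.
    verify = torch.cat(
        [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths],
        dim=-1,
    )
    print(verify)  # tensor([2, 4, 3, 5, 4, 6])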
@@ -135,15 +133,6 @@ def llama_model_forward(
     else:
         cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)

-    # TODO (yuanheng-zhao): revise the logic here
-    # if batch.is_prompts:
-    #     output_tensor = torch.zeros(
-    #         (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-    #     )
-    # else:
-    #     output_tensor = torch.zeros(
-    #         (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-    #     )
     sm_scale = 1.0 / (inputmetadata.head_dim**0.5)

     norm_output = torch.empty_like(hidden_states)
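The sm_scale computed above is the standard scaled-dot-product temperature, 1/sqrt(head_dim), which the model forward passes down to its attention kernels. As a rough plain-PyTorch illustration of where that scale enters (a sketch, not the project's fused attention kernel):

    import torch

    head_dim = 64
    q = torch.randn(4, head_dim)   # 4 query tokens
    k = torch.randn(6, head_dim)   # 6 cached key tokens
    v = torch.randn(6, head_dim)

    # Same scale as in the diff: keeps dot-product logits near unit variance.
    sm_scale = 1.0 / (head_dim**0.5)

    scores = (q @ k.T) * sm_scale           # (4, 6) attention logits
    weights = torch.softmax(scores, dim=-1)  # row-wise attention weights
    out = weights @ v                        # (4, head_dim) attention output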
@@ -239,7 +228,6 @@ def llama_decoder_layer_forward(
         sequence_lengths=sequence_lengths,
         cos_sin=cos_sin,
         fd_inter_tensor=fd_inter_tensor,
         is_prompts=is_prompts,
         kv_seq_len=kv_seq_len,
         output_tensor=output_tensor,
         sm_scale=sm_scale,