mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
add context_attention_unpadded
This commit is contained in:
committed by
FrankLeeeee
parent
07b5283b6a
commit
02c1bf8b2a
@@ -232,11 +232,7 @@ class InferenceEngine:
|
||||
|
||||
# Decode completed sentences.
|
||||
for seq in finished_sequences:
|
||||
if seq.prompt:
|
||||
output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
|
||||
output_list.append(seq.prompt + output_str)
|
||||
else:
|
||||
output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
|
||||
output_list.append(output_str)
|
||||
output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
|
||||
output_list.append(output_str)
|
||||
|
||||
return output_list
|
||||
|
@@ -156,9 +156,9 @@ class RequestHandler:
|
||||
def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config):
|
||||
if generation_config.num_beams == 1:
|
||||
if generation_config.do_sample:
|
||||
sample_tokens = greedy_sample(generation_config, logprobs)
|
||||
else:
|
||||
sample_tokens = multinomial_sample(generation_config, probs)
|
||||
else:
|
||||
sample_tokens = greedy_sample(generation_config, logprobs)
|
||||
else:
|
||||
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty)
|
||||
|
||||
|
@@ -5,6 +5,7 @@ import torch
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
|
||||
|
||||
from colossalai.inference.struct import BatchInfo
|
||||
from colossalai.kernel.triton import context_attention_unpadded
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
@@ -53,7 +54,6 @@ def llama_causal_lm_forward(
|
||||
v_caches=v_caches,
|
||||
)
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
@@ -157,15 +157,17 @@ def llama_attn_forward(
|
||||
key_states = key_states.view(-1, self.num_heads, self.head_dim)
|
||||
value_states = value_states.view(-1, self.num_heads, self.head_dim)
|
||||
|
||||
# TODO: The code below will be uncommented after the development of attention-related kernel is completed.
|
||||
# memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths)
|
||||
# if is_prompts:
|
||||
# attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
|
||||
# else:
|
||||
# attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
|
||||
# decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size)
|
||||
_, _, _, block_size = k_cache.shape
|
||||
|
||||
# NOTE: context_attention_unpadded is unsed for testing accuracy and we can only use aligned inputs.
|
||||
# The code below will be uncommented after the development of attention-related kernel is completed.
|
||||
if is_prompts:
|
||||
attn_output = context_attention_unpadded(
|
||||
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
|
||||
)
|
||||
# else:
|
||||
# attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
|
||||
|
||||
attn_output = query_states
|
||||
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
@@ -21,7 +21,6 @@ def multinomial_sample(
|
||||
"""
|
||||
Sample tokens in a random phase.
|
||||
"""
|
||||
# max_best_of = generation_config.best_of
|
||||
random_results = torch.multinomial(probs, num_samples=1, replacement=True).cpu()
|
||||
return random_results
|
||||
|
||||
|
Reference in New Issue
Block a user