add context_attention_unpadded

This commit is contained in:
yuehuayingxueluo
2024-01-03 18:50:26 +08:00
committed by FrankLeeeee
parent 07b5283b6a
commit 02c1bf8b2a
5 changed files with 37 additions and 29 deletions

View File

@@ -232,11 +232,7 @@ class InferenceEngine:
# Decode completed sentences.
for seq in finished_sequences:
if seq.prompt:
output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
output_list.append(seq.prompt + output_str)
else:
output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
output_list.append(output_str)
output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
output_list.append(output_str)
return output_list

View File

@@ -156,9 +156,9 @@ class RequestHandler:
def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config):
if generation_config.num_beams == 1:
if generation_config.do_sample:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = multinomial_sample(generation_config, probs)
else:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty)