add context_attention_unpadded

This commit is contained in:
yuehuayingxueluo
2024-01-03 18:50:26 +08:00
committed by FrankLeeeee
parent 07b5283b6a
commit 02c1bf8b2a
5 changed files with 37 additions and 29 deletions

View File

@@ -156,9 +156,9 @@ class RequestHandler:
def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config):
if generation_config.num_beams == 1:
if generation_config.do_sample:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = multinomial_sample(generation_config, probs)
else:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty)