Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
[Inference] add logit processor and request handler (#5166)
* add logit processor and request handler
* add
* add
* add
* fix
* add search tokens and update func
* finish request handler
* add running list test
* fix test
* fix some bug
* add
* add
* fix bugs
* fix some bugs
* fix bug
* fix
* fix
* add copy fun
* del useless attn
* fix request status

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
colossalai/inference/sampler.py (new file, 62 lines added)
@@ -0,0 +1,62 @@
from typing import List, Tuple

import torch


def greedy_sample(
    generation_config,
    logprobs: torch.Tensor,
) -> torch.Tensor:
    """
    Sample tokens greedily (argmax over the vocabulary).
    """
    results = torch.argmax(logprobs, dim=-1).cpu()
    return results


def multinomial_sample(
    generation_config,
    probs: torch.Tensor,
) -> torch.Tensor:
    """
    Sample tokens randomly via multinomial sampling over the probability distribution.
    """
    max_best_of = generation_config.best_of
    random_results = torch.multinomial(probs, num_samples=max_best_of, replacement=True).cpu()
    return random_results


def beam_search_sample(
    generation_config,
    logprobs: torch.Tensor,
    is_prompt: bool = False,
) -> List[Tuple[List[int], List[int]]]:
    """
    Sample tokens with beam search.

    We sample 2 * beam_width candidates to make sure that with high probability we can get
    `beam_width` candidates in addition to the finished sequences for the next iteration.

    See
        https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    for details, as well as the HF reference:
        https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065

    # NOTE: this beam search sample function is wrong now.
    """

    beam_width = generation_config.best_of
    results = []
    if is_prompt:
        # Prompt phase: all candidates stem from the single prompt sequence (parent 0).
        parent_ids = [0] * (2 * beam_width)
        _, next_token_ids = torch.topk(logprobs[0], 2 * beam_width)
        next_token_ids = next_token_ids.tolist()
    else:
        # Generation phase.
        # NOTE: `seq_group_logprobs`, `next_token_ids`, and `parent_ids` are never defined in
        # this branch, so it currently fails; see the NOTE in the docstring.
        # cumulative_logprobs = [seq_data[seq_id].cumulative_logprob for seq_id in seq_ids]
        cumulative_logprobs = torch.tensor(logprobs, dtype=torch.float, device=seq_group_logprobs.device)
        seq_group_logprobs = seq_group_logprobs + cumulative_logprobs.unsqueeze(dim=1)
        _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)

    results.append((next_token_ids, parent_ids))
    return results
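The generation-phase branch above is acknowledged as broken by the in-code NOTE. A minimal sketch of one way that branch could select candidates, assuming the caller supplies a 1-D tensor cumulative_logprobs of per-beam running scores (not part of the committed signature) and that logprobs has shape [beam_width, vocab_size]:

    # Sketch only, not the committed implementation.
    vocab_size = logprobs.size(-1)
    scores = logprobs + cumulative_logprobs.unsqueeze(dim=1)    # add each beam's running score
    _, topk_ids = torch.topk(scores.flatten(), 2 * beam_width)  # candidates over (beam, token) pairs
    topk_ids = topk_ids.tolist()
    parent_ids = [i // vocab_size for i in topk_ids]            # which beam each candidate extends
    next_token_ids = [i % vocab_size for i in topk_ids]         # which token each candidate appends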
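For context, a minimal usage sketch of the two simpler samplers, assuming the module is importable as colossalai.inference.sampler and that a plain namespace with a best_of field can stand in for the real generation config:

    import torch
    from types import SimpleNamespace

    from colossalai.inference.sampler import greedy_sample, multinomial_sample

    config = SimpleNamespace(best_of=2)             # hypothetical stand-in for generation_config
    logits = torch.randn(4, 32000)                  # [batch_size, vocab_size]
    logprobs = torch.log_softmax(logits, dim=-1)
    probs = torch.softmax(logits, dim=-1)

    greedy_ids = greedy_sample(config, logprobs)    # shape [4]: argmax token per sequence
    random_ids = multinomial_sample(config, probs)  # shape [4, 2]: best_of samples per sequence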