Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 12:30:42 +00:00
[Feat] Inference RPC Server Support (#5705)

* RPC support source
* KV cache logical/physical disaggregation
* Sampler refactor
* colossalai launch built-in
* Unit tests
* RPyC support

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
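For orientation, the "RPyC support" item refers to the RPyC library used as the RPC transport. The sketch below shows what a minimal RPyC service looks like in general; InferenceService and exposed_generate are illustrative names assumed for this example, not the interface this PR actually adds.

# Minimal RPyC service sketch. InferenceService / exposed_generate are
# hypothetical names for illustration, not this PR's actual interface.
import rpyc
from rpyc.utils.server import ThreadedServer


class InferenceService(rpyc.Service):
    # RPyC exposes methods whose names start with "exposed_"; clients call
    # them without the prefix, e.g. conn.root.generate("hello").
    def exposed_generate(self, prompt: str, max_new_tokens: int = 16) -> str:
        # A real server would run the model here; we echo for illustration.
        return f"{prompt} [+{max_new_tokens} generated tokens]"


if __name__ == "__main__":
    ThreadedServer(InferenceService, port=18861).start()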
@@ -1,6 +1,9 @@
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from transformers.generation import GenerationConfig
 
+from colossalai.inference.logit_processors import logit_processor
+
+
 def greedy_sample(
@@ -59,3 +62,47 @@ def beam_search_sample(
 
     results.append((next_token_ids, parent_ids))
     return results
+
+
+def _sample(probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig, is_prompt: bool = False):
+    if generation_config.num_beams == 1:
+        if generation_config.do_sample:
+            sample_tokens = multinomial_sample(generation_config, probs)
+        else:
+            sample_tokens = greedy_sample(generation_config, logprobs)
+    else:
+        sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=is_prompt)
+
+    return sample_tokens
+
+
+def search_tokens(
+    generation_config: GenerationConfig,
+    logits,
+    is_prompt: bool = False,
+    batch_token_ids: Optional[List[List[int]]] = None,
+):
+    """
+    Sample tokens for finished requests.
+    """
+    # NOTE: need to decide the granularity to process logits (sequence or batch)
+    config_dict = generation_config.to_dict()
+    # process repetition_penalty, no_repeat_ngram_size
+    for type in ["repetition_penalty", "no_repeat_ngram_size"]:
+        if type in config_dict and config_dict[type] is not None:
+            logits = logit_processor(type, logits, config_dict[type], batch_token_ids)
+
+    # do logit processor
+    if generation_config.do_sample:
+        # process temperature, top_k, top_p
+        for type in ["temperature", "top_k", "top_p"]:
+            if type in config_dict and config_dict[type] is not None:
+                logits = logit_processor(type, logits, config_dict[type])
+
+    # calculate probs
+    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
+
+    # sample the next tokens
+    sample_tokens = _sample(probs, logprobs, generation_config, is_prompt)
+    return sample_tokens
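For reference, the temperature and top_k steps that search_tokens routes through logit_processor conventionally warp the logits as sketched below. This is a plain-torch sketch of the standard technique; ColossalAI's own processors may differ in detail, and apply_temperature / apply_top_k are hypothetical helper names.

# Plain-torch sketch of conventional temperature and top-k logit warping.
# apply_temperature / apply_top_k are hypothetical names, not ColossalAI APIs.
import torch


def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # temperature < 1 sharpens the distribution, > 1 flattens it
    return logits / temperature


def apply_top_k(logits: torch.Tensor, k: int) -> torch.Tensor:
    # mask every logit strictly below the k-th largest value in each row
    kth_largest = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_largest, float("-inf"))


logits = torch.randn(2, 100)
warped = apply_top_k(apply_temperature(logits, 0.7), k=50)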
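Likewise, the dispatch in _sample reduces to the standard greedy-versus-multinomial choice. The self-contained sketch below mirrors that flow in plain torch so it runs without ColossalAI installed; pick_next_tokens is a hypothetical helper, and the beam-search branch is omitted.

# Self-contained sketch of the _sample dispatch above, beam search omitted.
# pick_next_tokens is a hypothetical helper, not part of this PR.
import torch
from transformers.generation import GenerationConfig


def pick_next_tokens(logits: torch.Tensor, cfg: GenerationConfig) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
    if cfg.do_sample:
        # multinomial_sample path: draw one token per sequence from probs
        return torch.multinomial(probs, num_samples=1).squeeze(-1)
    # greedy_sample path: take the highest-logprob token per sequence
    return torch.argmax(logprobs, dim=-1)


cfg = GenerationConfig(do_sample=True, temperature=1.0, num_beams=1)
next_tokens = pick_next_tokens(torch.randn(4, 32000), cfg)  # shape: (4,)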