[Feat] Inference RPC Server Support (#5705)

* RPC support source
* KV cache logical/physical disaggregation
* Sampler refactor
* Built-in colossalai launch
* Unit tests
* RPyC support (a minimal service sketch follows below)
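The hunk shown below only covers the sampler; the RPC server itself lives in the other changed files. As a rough, hypothetical illustration of the RPyC pattern this feature builds on (class and method names here are assumptions, not the commit's actual code):

import rpyc
from rpyc.utils.server import ThreadedServer


class InferenceService(rpyc.Service):
    # Hypothetical service; the commit's real server is not shown in this hunk.
    # RPyC exposes any method whose name starts with "exposed_" to clients.
    def exposed_generate(self, prompt: str) -> str:
        # A real implementation would forward the prompt to the inference engine.
        return f"echo: {prompt}"


if __name__ == "__main__":
    # Client side: conn = rpyc.connect("localhost", 18861)
    #              print(conn.root.generate("hello"))
    ThreadedServer(InferenceService, port=18861).start()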

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Runyu Lu
Date: 2024-05-14 10:00:55 +08:00
Committed by: GitHub
Parent: de4bf3dedf
Commit: 18d67d0e8e
15 changed files with 1032 additions and 63 deletions


@@ -1,6 +1,9 @@
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
+from transformers.generation import GenerationConfig
+
+from colossalai.inference.logit_processors import logit_processor
 
 
 def greedy_sample(
@@ -59,3 +62,47 @@ def beam_search_sample(
         results.append((next_token_ids, parent_ids))
 
     return results
+
+
+def _sample(probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig, is_prompt: bool = False):
+    if generation_config.num_beams == 1:
+        if generation_config.do_sample:
+            sample_tokens = multinomial_sample(generation_config, probs)
+        else:
+            sample_tokens = greedy_sample(generation_config, logprobs)
+    else:
+        sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=is_prompt)
+
+    return sample_tokens
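# The two single-beam branches above reduce to an argmax over the
# log-probabilities (greedy) versus a multinomial draw from the softmax
# distribution. A standalone sketch with assumed shapes, mirroring torch
# semantics rather than this file's exact helpers:
#
#     logits = torch.randn(4, 32000)                     # [batch, vocab]
#     probs = torch.softmax(logits, dim=-1)
#     logprobs = torch.log_softmax(logits, dim=-1)
#     greedy = torch.argmax(logprobs, dim=-1)            # greedy_sample path
#     sampled = torch.multinomial(probs, 1).squeeze(-1)  # multinomial_sample path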
+
+
+def search_tokens(
+    generation_config: GenerationConfig,
+    logits,
+    is_prompt: bool = False,
+    batch_token_ids: Optional[List[List[int]]] = None,
+):
+    """
+    Sample the next tokens for the running requests.
+    """
+    # NOTE: need to decide the granularity to process logits (sequence or batch)
+    config_dict = generation_config.to_dict()
+
+    # apply the penalty-style processors: repetition_penalty, no_repeat_ngram_size
+    for processor_type in ["repetition_penalty", "no_repeat_ngram_size"]:
+        if processor_type in config_dict and config_dict[processor_type] is not None:
+            logits = logit_processor(processor_type, logits, config_dict[processor_type], batch_token_ids)
+
+    # apply temperature, top_k, and top_p only when sampling is enabled
+    if generation_config.do_sample:
+        for processor_type in ["temperature", "top_k", "top_p"]:
+            if processor_type in config_dict and config_dict[processor_type] is not None:
+                logits = logit_processor(processor_type, logits, config_dict[processor_type])
+
+    # compute probabilities and log-probabilities over the vocabulary
+    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
+
+    # sample the next tokens and return them
+    sample_tokens = _sample(probs, logprobs, generation_config, is_prompt)
+    return sample_tokens
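Taken together, the new entry point wires the logit processors and the sampling dispatch into a single call. A hedged usage sketch, assuming the module path colossalai.inference.sampler and illustrative shapes and config values:

import torch
from transformers.generation import GenerationConfig

from colossalai.inference.sampler import search_tokens  # assumed module path

# Standard transformers generation parameters; values are illustrative.
generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.1,
)

logits = torch.randn(2, 32000)           # [batch_size, vocab_size], assumed shape
batch_token_ids = [[1, 42, 7], [1, 13]]  # per-sequence history for the penalties

next_tokens = search_tokens(
    generation_config, logits, is_prompt=False, batch_token_ids=batch_token_ids
)
print(next_tokens.shape)  # one sampled token id per sequence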