[Feat] Inference RPC Server Support (#5705)
* RPC support source
* KV cache logical/physical disaggregation
* Sampler refactor
* colossalai launch built in
* Unit tests
* RPyC support

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
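The commit message lists RPyC as the RPC transport ("RPyC support"). For readers new to that library, below is a minimal, self-contained sketch of the generic RPyC server/client pattern such a design builds on. The service name and exposed method are hypothetical illustrations, not ColossalAI's actual RPC server code.

# Minimal RPyC pattern (hypothetical names, not the ColossalAI implementation).
import rpyc
from rpyc.utils.server import ThreadedServer


class InferenceWorkerService(rpyc.Service):
    """Toy service: methods prefixed with `exposed_` become callable remotely."""

    def exposed_generate(self, prompt: str) -> str:
        # A real worker would run model inference here; this just echoes.
        return f"generated({prompt})"


if __name__ == "__main__":
    # Server side: blocks and serves requests on port 18861.
    ThreadedServer(InferenceWorkerService, port=18861).start()

    # Client side (run in another process):
    #   conn = rpyc.connect("localhost", 18861)
    #   print(conn.root.generate("hello"))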
@@ -7,10 +7,11 @@ from transformers.generation import GenerationConfig
 from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
-from colossalai.inference.kv_cache import KVCacheManager
+from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import RequestStatus, Sequence
 from colossalai.logging import get_dist_logger

 logger = get_dist_logger(__name__)

 __all__ = ["RunningList", "RequestHandler"]
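The new `RPCKVCacheManager` import connects to the "KV cache logical/physical disaggregation" item in the commit message: the scheduler side can track cache blocks purely by id while the worker side owns the physical tensors. Below is a toy sketch of that split under assumed, hypothetical class names; it is not the `RPCKVCacheManager` implementation.

# Toy sketch of a logical/physical KV cache split (hypothetical classes, illustration only).
from typing import Dict, List

import torch


class LogicalBlockTable:
    """Scheduler-side bookkeeping: tracks which block ids each sequence owns; holds no tensors."""

    def __init__(self, num_blocks: int) -> None:
        self.free_blocks: List[int] = list(range(num_blocks))
        self.seq_to_blocks: Dict[int, List[int]] = {}

    def allocate(self, seq_id: int, num_needed: int) -> List[int]:
        blocks = [self.free_blocks.pop() for _ in range(num_needed)]
        self.seq_to_blocks.setdefault(seq_id, []).extend(blocks)
        return blocks

    def free(self, seq_id: int) -> None:
        self.free_blocks.extend(self.seq_to_blocks.pop(seq_id, []))


class PhysicalKVCache:
    """Worker-side storage: owns the actual key/value tensors, indexed by block id."""

    def __init__(self, num_blocks: int, block_size: int, num_heads: int, head_dim: int) -> None:
        self.k_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim)
        self.v_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim)

    def write(self, block_id: int, slot: int, k: torch.Tensor, v: torch.Tensor) -> None:
        self.k_cache[block_id, slot] = k
        self.v_cache[block_id, slot] = v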
@@ -295,17 +296,6 @@ class RequestHandler:
         return None

-    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig):
-        if generation_config.num_beams == 1:
-            if generation_config.do_sample:
-                sample_tokens = multinomial_sample(generation_config, probs)
-            else:
-                sample_tokens = greedy_sample(generation_config, logprobs)
-        else:
-            sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)
-
-        return sample_tokens
-
     def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
         if (
             sequence.output_token_id[-1] == generation_config.eos_token_id
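The `_sample` helper removed here dispatches to `multinomial_sample`, `greedy_sample`, and `beam_search_sample`, which come from `colossalai.inference.sampler` (pulled in by the wildcard import above). As a rough idea of what the greedy and multinomial paths do, here is a hedged sketch with illustrative stand-in functions; these are not the library's actual implementations or signatures.

# Illustrative stand-ins for the sampler helpers (not the colossalai.inference.sampler code).
import torch


def greedy_sample_sketch(logprobs: torch.Tensor) -> torch.Tensor:
    # Pick the highest-probability token per sequence: [batch, vocab] -> [batch]
    return torch.argmax(logprobs, dim=-1)


def multinomial_sample_sketch(probs: torch.Tensor) -> torch.Tensor:
    # Draw one token per sequence from the (already filtered) distribution.
    return torch.multinomial(probs, num_samples=1).squeeze(-1)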
@@ -328,33 +318,6 @@ class RequestHandler:
     def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

-    def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
-        """
-        Sample tokens for finished requests.
-        """
-
-        # NOTE: need to decide the granularity to process logits (sequence or batch)
-        config_dict = generation_config.to_dict()
-        # process repetition_penalty, no_repeat_ngram_size
-        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-            if type in config_dict and config_dict[type] is not None:
-                logits = logit_processor(type, logits, config_dict[type], cur_batch)
-
-        # do logit processor
-        if generation_config.do_sample:
-            # process temperature, top_k, top_p
-            for type in ["temperature", "top_k", "top_p"]:
-                if type in config_dict and config_dict[type] is not None:
-                    logits = logit_processor(type, logits, config_dict[type])
-
-        # calculate probs
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-
-        # sample the next tokens
-        sample_tokens = self._sample(probs, logprobs, generation_config)
-        return sample_tokens
-
     def append_next_tokens(self, sample_tokens: torch.Tensor):
         assert sample_tokens.dim() == 1
         n_elements = sample_tokens.size(0)
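The removed `search_tokens` runs the logits through `logit_processor` for penalties and, when sampling is enabled, for temperature/top_k/top_p, then takes softmax and log-softmax before sampling. A minimal sketch of what a temperature and top-k step typically look like (illustration only, not the `colossalai.inference.logit_processors` code):

# Hedged sketch of temperature + top-k filtering (illustration, not the library code).
import torch


def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Scale logits; temperature < 1 sharpens, > 1 flattens the distribution.
    return logits / max(temperature, 1e-5)


def apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    # Mask everything outside the k largest logits per row.
    kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_value, float("-inf"))


# Usage mirroring the removed search_tokens flow:
#   logits = apply_temperature(apply_top_k(logits, top_k=50), temperature=0.7)
#   probs = torch.softmax(logits, dim=-1, dtype=torch.float)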
@@ -386,3 +349,53 @@ class RequestHandler:
         self.done_list.extend(finished_seqs)

         return finished_seqs
+
+
+class RPCRequestHandler(RequestHandler):
+    """
+    RPC Version of request handler
+    """
+
+    def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
+        self.inference_config = inference_config
+        self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
+        self.waiting_list: List[List] = [[], [], []]
+        self.done_list: List[Sequence] = []
+        self.dtype = inference_config.dtype
+        self.max_batch_size = inference_config.max_batch_size
+
+        # initialize cache
+        self._init_cache(model_config)
+
+        # initialize batch
+        torch.cuda.current_device()
+        kv_max_split_num = (
+            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
+        ) // inference_config.block_size
+        head_dim = model_config.hidden_size // model_config.num_attention_heads
+
+        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
+        # which may cause bugs and this issue should be fixed later.
+        self.running_bb = BatchBucket(
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
+            head_dim=head_dim,
+            max_batch_size=self.max_batch_size,
+            max_length=inference_config.max_input_len + inference_config.max_output_len,
+            block_size=inference_config.block_size,
+            kv_max_split_num=kv_max_split_num,
+            fd_interm_tensor=None,
+            dtype=self.dtype,
+        )
+        self.prefill_bb = BatchBucket(
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
+            head_dim=head_dim,
+            max_batch_size=self.max_batch_size,
+            max_length=inference_config.max_input_len + inference_config.max_output_len,
+            block_size=inference_config.block_size,
+            kv_max_split_num=kv_max_split_num,
+            fd_interm_tensor=None,
+            dtype=self.dtype,
+        )
+
+    def _init_cache(self, model_config):
+        self.cache_manager = RPCKVCacheManager(self.inference_config, model_config)
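For orientation, here is a hedged sketch of how an `RPCRequestHandler` might be constructed. The config values and model name are placeholders; only fields the diff above actually reads (`max_batch_size`, `max_input_len`, `max_output_len`, `block_size`, `dtype`, `tp_size`) are set, the import path is assumed from the file being patched, and running it requires a CUDA device plus a working ColossalAI install.

# Hedged construction sketch; all values below are placeholders.
from transformers import AutoConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.request_handler import RPCRequestHandler  # path assumed from the patched file

inference_config = InferenceConfig(
    max_batch_size=8,
    max_input_len=256,
    max_output_len=256,
    block_size=16,
    dtype="fp16",
    tp_size=1,
)
model_config = AutoConfig.from_pretrained("path-or-name-of-your-model")  # placeholder checkpoint

handler = RPCRequestHandler(inference_config, model_config)
print(handler.total_requests_in_batch_bucket())  # 0 until requests are scheduled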