[Feat] Inference RPC Server Support (#5705)

* RPC support source
* KV cache logical/physical disaggregation
* Sampler refactor
* Built-in colossalai launch
* Unit tests
* RPyC support (see the sketch after the file summary)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Runyu Lu
Date: 2024-05-14 10:00:55 +08:00
Committed by: GitHub
Parent: de4bf3dedf
Commit: 18d67d0e8e
15 changed files with 1032 additions and 63 deletions
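The RPC server lets the request-handling frontend and the model-executing engine run in separate processes, with RPyC as the transport. Below is a minimal sketch of such an endpoint using RPyC's standard Service/ThreadedServer API; the class name, method name, and port are illustrative assumptions, not identifiers introduced by this commit.

    import rpyc
    from rpyc.utils.server import ThreadedServer

    class InferenceService(rpyc.Service):
        # Methods prefixed with "exposed_" are callable by remote clients.
        def exposed_generate(self, prompt_token_ids: list) -> list:
            # A real server would run the model forward pass and sampling here;
            # this stub echoes the prompt so the sketch stays self-contained.
            return list(prompt_token_ids)

    if __name__ == "__main__":
        # Blocks and serves; each incoming connection is handled on a thread.
        ThreadedServer(InferenceService(), port=18861).start()

A client connects with rpyc.connect("localhost", 18861) and calls conn.root.generate([...]). Only plain Python data crosses the connection, which lines up with the diff below moving the logit processors from a BatchBucket argument to plain List[List[int]] token ids.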


@@ -1,10 +1,9 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
+from typing import List
 import torch
 import torch.nn.functional as F
-from colossalai.inference.batch_bucket import BatchBucket
 _LOGIT_PROCESSOR_MAP = {}
@@ -22,7 +21,7 @@ def register_logit_processor(process_type):
 @register_logit_processor("no_repeat_ngram_size")
-def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
+def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
     """
     enforces no repetition of n-grams to avoid repetitions of word sequences.
     """
@@ -31,7 +30,6 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
         raise ValueError(f"'ngram_size={ngram_size}' should be a strictly positive integer.")
     if ngram_size != 0:
-        batch_token_ids = batch.batch_token_ids
         batch_size = len(batch_token_ids)
         for batch_id in range(batch_size):
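The removed line shows that each processor previously fetched token ids from the BatchBucket itself; after this change the caller passes plain lists. For context, a hedged sketch of the n-gram banning step the loop above leads into, in the spirit of the HuggingFace implementation credited at the top of the file; the helper name is an assumption:

    from typing import Dict, List, Tuple

    def banned_ngram_tokens(token_ids: List[int], ngram_size: int) -> List[int]:
        # Return tokens that would complete an n-gram already in token_ids.
        generated: Dict[Tuple[int, ...], List[int]] = {}
        for i in range(len(token_ids) - ngram_size + 1):
            prefix = tuple(token_ids[i : i + ngram_size - 1])
            generated.setdefault(prefix, []).append(token_ids[i + ngram_size - 1])
        # The most recent (ngram_size - 1) tokens form the prefix to check.
        current_prefix = tuple(token_ids[len(token_ids) - ngram_size + 1 :])
        return generated.get(current_prefix, [])

Inside the per-sequence loop, the returned ids would then be masked, e.g. logits[batch_id, banned] = -float("inf").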
@@ -55,7 +53,7 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
 @register_logit_processor("repetition_penalty")
-def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket):
+def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
     """
     apply the penalty to the tokens present in the prompt.
     """
@@ -67,7 +65,6 @@ def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket):
     # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
     if penalty != 1.0:
-        batch_token_ids = batch.batch_token_ids
         for batch_id in range(len(batch_token_ids)):
             current_logit = logits[batch_id]
             current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device)
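The hunk ends mid-loop; for readability, here is a sketch of how the gathered tokens are typically penalized, mirroring the gather/where/scatter pattern of the HuggingFace processor this file adapts (an assumed reconstruction, not the exact lines of the commit):

    # Continues the loop above: pull the scores of tokens the sequence has
    # already produced, scale them (divide positive scores, multiply negative
    # ones, so their probability always drops), and write them back in place.
    current_score = torch.gather(current_logit, 0, current_token)
    current_score = torch.where(current_score < 0, current_score * penalty, current_score / penalty)
    logits[batch_id].scatter_(0, current_token, current_score)

At the call site, the sampler can now dispatch through the registry with pickle-friendly arguments, e.g. logits = _LOGIT_PROCESSOR_MAP["repetition_penalty"](logits, penalty, batch.batch_token_ids), keeping the BatchBucket on the engine side of the RPC boundary; the exact dispatch code is an assumption.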