Mirror of https://github.com/hpcaitech/ColossalAI.git
[Feat]Inference RPC Server Support (#5705)
* rpc support source
* kv cache logical/physical disaggregation
* sampler refactor
* colossalai launch built in
* Unittest
* Rpyc support

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
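The core of the logit-processor change in the diff below is swapping the `BatchBucket` argument for a plain `batch_token_ids: List[List[int]]`. Plain lists of token ids are easy to pass through an RPC layer such as rpyc, while a `BatchBucket` that owns device-side KV-cache state is not something you want to ship across the process boundary. A minimal rpyc sketch of that idea follows; the service and method names here are hypothetical, not ColossalAI's actual RPC interface.

# Minimal rpyc sketch, only to illustrate why plain List[List[int]] arguments
# are convenient across an RPC boundary. Names below are hypothetical.
from typing import List

import rpyc
from rpyc.utils.server import ThreadedServer


class InferenceWorkerService(rpyc.Service):
    def exposed_sample(self, batch_token_ids: List[List[int]]) -> List[int]:
        # The nested lists arrive as plain Python data; device-side state
        # (KV cache, logits) stays on the worker and never crosses the wire.
        return [len(ids) for ids in batch_token_ids]


if __name__ == "__main__":
    ThreadedServer(InferenceWorkerService, port=18861).start()

A client would then do something like conn = rpyc.connect("localhost", 18861) and call conn.root.sample([[1, 2, 3], [4, 5]]).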
@@ -1,10 +1,9 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
+from typing import List
 
 import torch
 import torch.nn.functional as F
 
-from colossalai.inference.batch_bucket import BatchBucket
-
 _LOGIT_PROCESSOR_MAP = {}
 
 
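The next hunk header references `register_logit_processor(process_type)` together with the module-level `_LOGIT_PROCESSOR_MAP` kept above. The decorator body is not part of this diff; a plausible sketch of the registration pattern it implies, assuming it simply maps a string key to the decorated function:

# Sketch of the registration pattern implied by _LOGIT_PROCESSOR_MAP and the
# @register_logit_processor(...) decorator; the real body is not shown in this diff.
_LOGIT_PROCESSOR_MAP = {}


def register_logit_processor(process_type: str):
    def register(func):
        _LOGIT_PROCESSOR_MAP[process_type] = func
        return func

    return register

With such a map in place, a dispatcher can look up _LOGIT_PROCESSOR_MAP["repetition_penalty"] and call it with the plain arguments introduced by this commit.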
@@ -22,7 +21,7 @@ def register_logit_processor(process_type):
 
 
 @register_logit_processor("no_repeat_ngram_size")
-def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
+def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
     """
     enforces no repetition of n-grams to avoid repetitions of word sequences.
     """
@@ -31,7 +30,6 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBuck
         raise ValueError(f"'temperature={ngram_size}' should be a strictly positive integer.")
 
     if ngram_size != 0:
-        batch_token_ids = batch.batch_token_ids
         batch_size = len(batch_token_ids)
 
         for batch_id in range(batch_size):
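After the refactor, `no_repeat_ngram_size_logit_process` iterates directly over `batch_token_ids` instead of pulling them out of a `BatchBucket`. Only part of the banning loop is visible in the hunk above; the usual rule it implements (collect every generated n-gram, then mask any token that would complete an n-gram whose (n-1)-token prefix matches the current tail) looks roughly like this standalone sketch, which is not copied from the file being diffed:

import torch
from typing import List


def ban_repeated_ngrams(logits: torch.Tensor, ngram_size: int, batch_token_ids: List[List[int]]) -> torch.Tensor:
    # Standalone sketch of the no-repeat-ngram rule on plain token-id lists.
    for batch_id, token_ids in enumerate(batch_token_ids):
        if len(token_ids) < ngram_size:
            continue
        # Map every (ngram_size - 1)-token prefix to the tokens that followed it.
        seen = {}
        for i in range(len(token_ids) - ngram_size + 1):
            prefix = tuple(token_ids[i : i + ngram_size - 1])
            seen.setdefault(prefix, set()).add(token_ids[i + ngram_size - 1])
        # Ban tokens that would complete an already-seen n-gram.
        tail = tuple(token_ids[len(token_ids) - ngram_size + 1 :])
        for banned_token in seen.get(tail, ()):
            logits[batch_id, banned_token] = -float("inf")
    return logits

Because only batch_token_ids is needed, the same call works whether the ids were produced locally or received over RPC.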
@@ -55,7 +53,7 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBuck
 
 
 @register_logit_processor("repetition_penalty")
-def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket):
+def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
     """
     apply the penalty to the tokens present in the prompt.
     """
@@ -67,7 +65,6 @@ def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket)
 
     # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
     if penalty != 1.0:
-        batch_token_ids = batch.batch_token_ids
        for batch_id in range(len(batch_token_ids)):
             current_logit = logits[batch_id]
             current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device)
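The hunk above stops right after gathering `current_token`; the rest of the loop is not shown. The usual HuggingFace-style repetition penalty at this point gathers the logits at previously seen token ids, divides positive values by the penalty and multiplies negative ones, then scatters the result back. A hedged sketch of that per-sequence step on the refactored arguments (an illustration, not the file's actual code):

import torch
from typing import List


def apply_repetition_penalty(logits: torch.Tensor, penalty: float, batch_token_ids: List[List[int]]) -> torch.Tensor:
    # Sketch of the standard repetition-penalty update; the tail of the loop is
    # not visible in the hunk above, so this is an illustration only.
    if penalty == 1.0:
        return logits
    for batch_id, token_ids in enumerate(batch_token_ids):
        current_logit = logits[batch_id]
        current_token = torch.tensor(token_ids, dtype=torch.long, device=logits.device)

        seen_logits = torch.gather(current_logit, 0, current_token)
        # Shrink positive logits and amplify negative ones for repeated tokens.
        seen_logits = torch.where(seen_logits < 0, seen_logits * penalty, seen_logits / penalty)
        current_logit.scatter_(0, current_token, seen_logits)
    return logits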