mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[Feat]Inference RPC Server Support (#5705)
* rpc support source
* kv cache logical/physical disaggregation
* sampler refactor
* colossalai launch built in
* unit tests
* rpyc support

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -2,11 +2,11 @@
Our config contains various options for inference optimization; it is a unified API that wraps all the configurations for inference.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from transformers.generation import GenerationConfig

from colossalai.inference.flash_decoding_utils import FDIntermTensors
@@ -30,8 +30,25 @@ _DEFAULT_PROMPT_TEMPLATES = {
}


class RPC_PARAM(ABC):
    """
    NOTE(lry89757) We use rpyc to transport parameters between the client and the server.
    rpyc only supports plain Python (POD) types as parameters, so data such as tensors or other sophisticated classes must be converted into plain objects before transport.
    Drawing on the logic of `__getstate__` / `__setstate__`, classes that will later be passed as RPC parameters inherit from this base class and override `to_rpc_param` and `from_rpc_param`. The client invokes `to_rpc_param` to serialize the parameters, and the server recovers them with `from_rpc_param`.
    """

    @abstractmethod
    def to_rpc_param(self):
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def from_rpc_param():
        raise NotImplementedError

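# --- Illustrative sketch, not part of this commit ---
# A minimal, hypothetical subclass showing the RPC_PARAM contract described
# above: the client flattens the object into plain Python data with
# `to_rpc_param`, rpyc ships that dict over the connection, and the server
# rebuilds the object with `from_rpc_param`. `DummyParam` and its `ids` field
# are made up for illustration only.
class DummyParam(RPC_PARAM):
    def __init__(self, ids: torch.Tensor):
        self.ids = ids

    def to_rpc_param(self) -> dict:
        # tensors become nested lists so rpyc can serialize them
        return {"ids": self.ids.tolist()}

    @staticmethod
    def from_rpc_param(rpc_dict: dict) -> "DummyParam":
        # the server side materializes the tensor again
        return DummyParam(ids=torch.tensor(rpc_dict["ids"], dtype=torch.int))

# client side: payload = DummyParam(torch.tensor([1, 2, 3])).to_rpc_param()
# server side: param = DummyParam.from_rpc_param(payload)
# --- end sketch ---
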
@dataclass
class InputMetaData:
class InputMetaData(RPC_PARAM):
    """The input info for a single step

    Args:
@@ -48,6 +65,7 @@ class InputMetaData:
        dtype (torch.dtype, optional): The computation dtype of the tensor. Defaults to torch.float32.
        use_spec_dec (bool): Indicate whether to use speculative decoding.
        num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True.
        batch_token_ids (List[List[int]], optional): input_token_ids + output_token_ids of the current batch. Only used for `repetition_penalty` and `no_repeat_ngram_size` in the sampling process.
    """

    block_tables: torch.Tensor = None
@@ -63,6 +81,54 @@ class InputMetaData:
    dtype: torch.dtype = torch.float32
    use_spec_dec: bool = False
    num_tokens_to_verify: int = 0
    batch_token_ids: Optional[
        List[List[int]]
    ] = None  # for `repetition_penalty` and `no_repeat_ngram_size` in the sampling process

    def to_rpc_param(self) -> Dict[str, Any]:
        return {
            "block_tables": self.block_tables.tolist(),
            "sequence_lengths": self.sequence_lengths.tolist(),
            "batch_size": self.batch_size,
            "is_prompts": self.is_prompts,
            "use_cuda_kernel": self.use_cuda_kernel,
            "use_cuda_graph": self.use_cuda_graph,
            "kv_seq_len": self.kv_seq_len,
            "head_dim": self.head_dim,
            "high_precision": self.high_precision,
            "dtype": str(self.dtype).split(".")[-1],
            "use_spec_dec": self.use_spec_dec,
            "num_tokens_to_verify": self.num_tokens_to_verify,
            "batch_token_ids": self.batch_token_ids,
        }

    @staticmethod
    def from_rpc_param(rpc_dict: Dict[str, Any]) -> "InputMetaData":
        """
        We deliberately index `rpc_dict` directly instead of using `dict.get`, so that a missing or misspelled RPC parameter raises an error immediately rather than being silently replaced by a default.
        """
        from colossalai.accelerator import get_accelerator

        dtype = getattr(torch, rpc_dict["dtype"])
        return InputMetaData(
            block_tables=torch.tensor(
                rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device()
            ),
            sequence_lengths=torch.tensor(
                rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
            ),
            batch_size=rpc_dict["batch_size"],
            is_prompts=rpc_dict["is_prompts"],
            use_cuda_kernel=rpc_dict["use_cuda_kernel"],
            use_cuda_graph=rpc_dict["use_cuda_graph"],
            kv_seq_len=rpc_dict["kv_seq_len"],
            head_dim=rpc_dict["head_dim"],
            high_precision=rpc_dict["high_precision"],
            dtype=dtype,
            use_spec_dec=rpc_dict["use_spec_dec"],
            num_tokens_to_verify=rpc_dict["num_tokens_to_verify"],
            batch_token_ids=rpc_dict["batch_token_ids"],
        )

    def __repr__(self) -> str:
        return (
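# --- Illustrative sketch, not part of this commit ---
# How the dtype survives the RPC boundary in the methods above: torch dtypes
# are not plain Python objects, so `to_rpc_param` sends only the suffix of
# `str(torch.float16)` ("float16") and `from_rpc_param` turns it back into a
# dtype via `getattr(torch, ...)`. Direct indexing (rather than `dict.get`)
# means a missing key fails loudly with a KeyError. Variable names below are
# illustrative only.
wire_dtype = str(torch.float16).split(".")[-1]   # -> "float16"
restored = getattr(torch, wire_dtype)            # -> torch.float16
assert restored is torch.float16

rpc_dict = {"dtype": wire_dtype}
rpc_dict["dtype"]            # ok
# rpc_dict["head_dim"]       # would raise KeyError, surfacing the mismatch
# rpc_dict.get("head_dim")   # would silently return None instead
# --- end sketch ---
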
@@ -80,7 +146,7 @@ class InputMetaData:


@dataclass
class InferenceConfig:
class InferenceConfig(RPC_PARAM):
    """The inference configuration.

    Args:
@@ -193,10 +259,6 @@ class InferenceConfig:
        if self.dtype == torch.float32:
            self.high_precision = False

        # check distributed
        assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
            self.tp_size * self.pp_size == dist.get_world_size()
        ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
        # check prompt template
        if self.prompt_template is None:
            return
@@ -226,6 +288,43 @@ class InferenceConfig:

        return GenerationConfig.from_dict(meta_config)

    def to_rpc_param(self) -> dict:
        kwargs = {
            "dtype": str(self.dtype).split(".")[-1],
            "max_n_spec_tokens": self.max_n_spec_tokens,
            "max_batch_size": self.max_batch_size,
            "max_input_len": self.max_input_len,
            "max_output_len": self.max_output_len,
            "tp_size": self.tp_size,
            "pp_size": self.pp_size,
            "pad_input": self.pad_input,
            "early_stopping": self.early_stopping,
            "do_sample": self.do_sample,
            "beam_width": self.beam_width,
            "kv_cache_dtype": str(self.kv_cache_dtype).split(".")[-1],
        }
        return kwargs

    @staticmethod
    def from_rpc_param(rpc_dict: dict) -> "InferenceConfig":
        """
        We deliberately index `rpc_dict` directly instead of using `dict.get`, so that a missing or misspelled RPC parameter raises an error immediately rather than being silently replaced by a default.
        """
        return InferenceConfig(
            dtype=getattr(torch, rpc_dict["dtype"]),
            max_n_spec_tokens=rpc_dict["max_n_spec_tokens"],
            max_batch_size=rpc_dict["max_batch_size"],
            max_input_len=rpc_dict["max_input_len"],
            max_output_len=rpc_dict["max_output_len"],
            tp_size=rpc_dict["tp_size"],
            pp_size=rpc_dict["pp_size"],
            pad_input=rpc_dict["pad_input"],
            early_stopping=rpc_dict["early_stopping"],
            do_sample=rpc_dict["do_sample"],
            beam_width=rpc_dict["beam_width"],
            kv_cache_dtype=getattr(torch, rpc_dict["kv_cache_dtype"], None),
        )

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
        # Get the list of attributes of this dataclass.
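
# --- Illustrative sketch, not part of this commit ---
# A possible end-to-end use of InferenceConfig.to_rpc_param / from_rpc_param
# shown above: the client serializes its config to a plain dict, the dict
# travels over the rpyc connection, and the server rebuilds a config from it.
# Only the fields listed in `to_rpc_param` reach the server-side copy; the
# variable names and constructor arguments are illustrative assumptions.
client_config = InferenceConfig(max_batch_size=8, dtype=torch.float16)
wire_payload = client_config.to_rpc_param()          # plain dict, rpyc-friendly
server_config = InferenceConfig.from_rpc_param(wire_payload)
assert server_config.max_batch_size == 8 and server_config.dtype is torch.float16
# --- end sketch ---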