[Feat] Inference RPC Server Support (#5705)

* RPC support source
* KV cache logical/physical disaggregation
* Sampler refactor
* Built-in colossalai launch
* Unit tests
* rpyc support

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Runyu Lu
2024-05-14 10:00:55 +08:00
committed by GitHub
parent de4bf3dedf
commit 18d67d0e8e
15 changed files with 1032 additions and 63 deletions


@@ -2,11 +2,11 @@
Our config contains various options for inference optimization; it is a unified API that wraps all the configurations for inference.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union
import torch
import torch.distributed as dist
from transformers.generation import GenerationConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
@@ -30,8 +30,25 @@ _DEFAULT_PROMPT_TEMPLATES = {
}
class RPC_PARAM(ABC):
"""
NOTE(lry89757) We use rpyc to transport params between the client and the server.
rpyc only supports plain (`POD`) Python types as params, so data such as tensors or other sophisticated classes must be converted before crossing the wire.
Drawing on the logic of `__setstate__`/`__getstate__`, classes that will later serve as rpc params inherit this base class and override `to_rpc_param` and `from_rpc_param`: the client invokes `to_rpc_param` to send the params, and the server recovers them with `from_rpc_param`.
"""
@abstractmethod
def to_rpc_param(self):
raise NotImplementedError
@staticmethod
@abstractmethod
def from_rpc_param():
raise NotImplementedError
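# Illustrative sketch (not part of this diff): how an RPC_PARAM is meant to cross the rpyc
# boundary. The client flattens the object to a POD dict with `to_rpc_param`; the server
# rebuilds it with `from_rpc_param`. The service and method names below (`InferenceService`,
# `exec_request`) and the port are hypothetical placeholders, not the API added by this PR.

import rpyc

class InferenceService(rpyc.Service):
    def exposed_exec_request(self, rpc_dict: dict):
        # Only POD types arrive here; tensors are rebuilt on the server's device.
        input_meta = InputMetaData.from_rpc_param(rpc_dict)
        ...

# Client side (conceptually): only POD types go over the wire.
#   conn = rpyc.connect("localhost", 18861)
#   conn.root.exec_request(input_meta.to_rpc_param())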
@dataclass
class InputMetaData:
class InputMetaData(RPC_PARAM):
"""The input info for a single step
Args:
@@ -48,6 +65,7 @@ class InputMetaData:
dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32.
use_spec_dec (bool): Indicate whether to use speculative decoding.
num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True.
batch_token_ids (List[List[int]], optional): input_token_ids + output_token_ids of the current batch. Only used for `repetition_penalty` and `no_repeat_ngram_size` in the sampling process.
"""
block_tables: torch.Tensor = None
@@ -63,6 +81,54 @@ class InputMetaData:
dtype: torch.dtype = torch.float32
use_spec_dec: bool = False
num_tokens_to_verify: int = 0
batch_token_ids: Optional[
List[List[int]]
] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
def to_rpc_param(self) -> Dict[str, Any]:
return {
"block_tables": self.block_tables.tolist(),
"sequence_lengths": self.sequence_lengths.tolist(),
"batch_size": self.batch_size,
"is_prompts": self.is_prompts,
"use_cuda_kernel": self.use_cuda_kernel,
"use_cuda_graph": self.use_cuda_graph,
"kv_seq_len": self.kv_seq_len,
"head_dim": self.head_dim,
"high_precision": self.high_precision,
"dtype": str(self.dtype).split(".")[-1],
"use_spec_dec": self.use_spec_dec,
"num_tokens_to_verify": self.num_tokens_to_verify,
"batch_token_ids": self.batch_token_ids,
}
@staticmethod
def from_rpc_param(rpc_dict: Dict[str, Any]) -> "InputMetaData":
"""
We intentionally avoid `dict.get` here so that a missing or misnamed rpc param raises a KeyError instead of failing silently.
"""
from colossalai.accelerator import get_accelerator
dtype = getattr(torch, rpc_dict["dtype"])
return InputMetaData(
block_tables=torch.tensor(
rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device()
),
sequence_lengths=torch.tensor(
rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
),
batch_size=rpc_dict["batch_size"],
is_prompts=rpc_dict["is_prompts"],
use_cuda_kernel=rpc_dict["use_cuda_kernel"],
use_cuda_graph=rpc_dict["use_cuda_graph"],
kv_seq_len=rpc_dict["kv_seq_len"],
head_dim=rpc_dict["head_dim"],
high_precision=rpc_dict["high_precision"],
dtype=dtype,
use_spec_dec=rpc_dict["use_spec_dec"],
num_tokens_to_verify=rpc_dict["num_tokens_to_verify"],
batch_token_ids=rpc_dict["batch_token_ids"],
)
def __repr__(self) -> str:
return (
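# Illustrative round trip (not part of this diff): the whole InputMetaData collapses to POD
# values and is rebuilt on the other side. The dtype crosses the wire as a bare string
# ("float16") and is restored with getattr(torch, ...); tensors travel as nested lists.

meta = InputMetaData(
    block_tables=torch.tensor([[0, 1]], dtype=torch.int),
    sequence_lengths=torch.tensor([2], dtype=torch.int),
    batch_size=1,
    dtype=torch.float16,
)
payload = meta.to_rpc_param()   # e.g. {'block_tables': [[0, 1]], 'dtype': 'float16', ...}
restored = InputMetaData.from_rpc_param(payload)
assert restored.dtype is torch.float16  # getattr(torch, "float16") inverts str(dtype).split(".")[-1]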
@@ -80,7 +146,7 @@ class InputMetaData:
@dataclass
class InferenceConfig:
class InferenceConfig(RPC_PARAM):
"""The inference configuration.
Args:
@@ -193,10 +259,6 @@ class InferenceConfig:
if self.dtype == torch.float32:
self.high_precision = False
# check distributed
assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
# check prompt template
if self.prompt_template is None:
return
@@ -226,6 +288,43 @@ class InferenceConfig:
return GenerationConfig.from_dict(meta_config)
def to_rpc_param(self) -> dict:
kwargs = {
"dtype": str(self.dtype).split(".")[-1],
"max_n_spec_tokens": self.max_n_spec_tokens,
"max_batch_size": self.max_batch_size,
"max_input_len": self.max_input_len,
"max_output_len": self.max_output_len,
"tp_size": self.tp_size,
"pp_size": self.pp_size,
"pad_input": self.pad_input,
"early_stopping": self.early_stopping,
"do_sample": self.do_sample,
"beam_width": self.beam_width,
"kv_cache_dtype": str(self.kv_cache_dtype).split(".")[-1],
}
return kwargs
@staticmethod
def from_rpc_param(rpc_dict: dict) -> "InferenceConfig":
"""
We intentionally avoid `dict.get` here so that a missing or misnamed rpc param raises a KeyError instead of failing silently.
"""
return InferenceConfig(
dtype=getattr(torch, rpc_dict["dtype"]),
max_n_spec_tokens=rpc_dict["max_n_spec_tokens"],
max_batch_size=rpc_dict["max_batch_size"],
max_input_len=rpc_dict["max_input_len"],
max_output_len=rpc_dict["max_output_len"],
tp_size=rpc_dict["tp_size"],
pp_size=rpc_dict["pp_size"],
pad_input=rpc_dict["pad_input"],
early_stopping=rpc_dict["early_stopping"],
do_sample=rpc_dict["do_sample"],
beam_width=rpc_dict["beam_width"],
kv_cache_dtype=getattr(torch, rpc_dict["kv_cache_dtype"], None),
)
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
# Get the list of attributes of this dataclass.
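# Illustrative sketch (not part of this diff): the engine config only needs to cross the
# process boundary once, at server launch, so it is flattened to plain kwargs the same way.
# The `init_model` call below is a hypothetical placeholder for however the server ingests it.

config = InferenceConfig(dtype=torch.float16, max_batch_size=8)
payload = config.to_rpc_param()                     # POD-only kwargs, safe for rpyc
# conn.root.init_model(payload)                     # client -> server (hypothetical)
restored = InferenceConfig.from_rpc_param(payload)  # server-side reconstruction
assert restored.max_batch_size == 8 and restored.dtype is torch.float16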