[Feat] Inference RPC Server Support (#5705)

* RPC support source
* KV cache logical/physical disaggregation
* Sampler refactor
* Built-in colossalai launch
* Unit tests
* RPyC support

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Authored by Runyu Lu on 2024-05-14 10:00:55 +08:00, committed by GitHub
parent de4bf3dedf
commit 18d67d0e8e
15 changed files with 1032 additions and 63 deletions

View File

@@ -21,6 +21,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
@@ -424,7 +425,7 @@ class InferenceEngine:
# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
@@ -472,7 +473,7 @@ class InferenceEngine:
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
@@ -689,6 +690,13 @@ class InferenceEngine:
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
batch_token_ids = None
config_dict = self.generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
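# (token-id histories are only materialized when one of these penalty-based processors is enabled)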
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
batch_token_ids = batch.batch_token_ids
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
@@ -708,6 +716,7 @@ class InferenceEngine:
dtype=batch.dtype,
use_spec_dec=batch.use_spec_dec,
num_tokens_to_verify=batch.num_tokens_to_verify,
batch_token_ids=batch_token_ids,
)
return input_ids, output_tensor, input_meta_data
@@ -738,7 +747,9 @@ class InferenceEngine:
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
next_tokens = search_tokens(
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()

View File

@@ -7,10 +7,11 @@ from transformers.generation import GenerationConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
__all__ = ["RunningList", "RequestHandler"]
@@ -295,17 +296,6 @@ class RequestHandler:
return None
def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig):
if generation_config.num_beams == 1:
if generation_config.do_sample:
sample_tokens = multinomial_sample(generation_config, probs)
else:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)
return sample_tokens
def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
if (
sequence.output_token_id[-1] == generation_config.eos_token_id
@@ -328,33 +318,6 @@ class RequestHandler:
def total_requests_in_batch_bucket(self) -> int:
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
"""
Sample tokens for finished requests.
"""
# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type], cur_batch)
# do logit processor
if generation_config.do_sample:
# process temperature, top_k, top_p
for type in ["temperature", "top_k", "top_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])
# calculate probs
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
return sample_tokens
def append_next_tokens(self, sample_tokens: torch.Tensor):
assert sample_tokens.dim() == 1
n_elements = sample_tokens.size(0)
@@ -386,3 +349,53 @@ class RequestHandler:
self.done_list.extend(finished_seqs)
return finished_seqs
class RPCRequestHandler(RequestHandler):
"""
RPC version of the request handler.
"""
def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
self.inference_config = inference_config
self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
self.waiting_list: List[List] = [[], [], []]
self.done_list: List[Sequence] = []
self.dtype = inference_config.dtype
self.max_batch_size = inference_config.max_batch_size
# initialize cache
self._init_cache(model_config)
# initialize batch
torch.cuda.current_device()
kv_max_split_num = (
inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
) // inference_config.block_size
head_dim = model_config.hidden_size // model_config.num_attention_heads
# TODO: In the continuous batching scenario, the batch size may exceed max_batch_size,
# which may cause bugs; this should be fixed later.
self.running_bb = BatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=None,
dtype=self.dtype,
)
self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=None,
dtype=self.dtype,
)
def _init_cache(self, model_config):
self.cache_manager = RPCKVCacheManager(self.inference_config, model_config)
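The sampler refactor above removes `RequestHandler._sample`/`search_tokens` in favor of a stateless `search_tokens` in `colossalai.inference.sampler`. The actual `sampler.py` diff is not part of this excerpt; what follows is a minimal sketch of what the relocated function presumably looks like, reconstructed from the removed method and the new call sites (an `is_prompt` flag and `batch_token_ids` replace the `BatchBucket` argument), not the committed implementation.

# Hypothetical reconstruction; multinomial_sample, greedy_sample, and beam_search_sample
# are the existing helpers in colossalai/inference/sampler.py previously used by RequestHandler._sample.
import torch
from transformers.generation import GenerationConfig
from colossalai.inference.logit_processors import logit_processor

def search_tokens(generation_config: GenerationConfig, logits, is_prompt: bool = False, batch_token_ids=None):
    config_dict = generation_config.to_dict()
    # Penalty-based processors need the per-sequence token histories.
    for type in ["repetition_penalty", "no_repeat_ngram_size"]:
        if type in config_dict and config_dict[type] is not None:
            logits = logit_processor(type, logits, config_dict[type], batch_token_ids)
    # Temperature / top_k / top_p are only applied when sampling is enabled.
    if generation_config.do_sample:
        for type in ["temperature", "top_k", "top_p"]:
            if type in config_dict and config_dict[type] is not None:
                logits = logit_processor(type, logits, config_dict[type])
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
    # Sample the next tokens (beam search only when num_beams > 1).
    if generation_config.num_beams == 1:
        if generation_config.do_sample:
            return multinomial_sample(generation_config, probs)
        return greedy_sample(generation_config, logprobs)
    return beam_search_sample(generation_config, logprobs, is_prompt=is_prompt)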

View File

@@ -0,0 +1,291 @@
import asyncio
from itertools import count
from time import sleep
from typing import List, Tuple, Union
import rpyc
import torch
import torch.nn as nn
from rpyc.utils.server import ThreadedServer
from torch import multiprocessing as mp
from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.executor.rpc_worker import rpcWorkerService
from colossalai.inference.utils import find_available_ports
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy
from .engine import InferenceEngine
from .request_handler import RPCRequestHandler
__all__ = ["RPCInferenceEngine"]
def run_server(host, port, event: mp.Event = None):
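# Each worker process hosts a blocking RPyC ThreadedServer around rpcWorkerService;
# `event` is set just before `server.start()` so the parent process knows it can connect.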
server = ThreadedServer(
rpcWorkerService, port=port, protocol_config={"allow_public_attrs": True, "allow_all_attrs": True}
)
if event:
event.set()
server.start()
class RPCInferenceEngine(InferenceEngine):
"""
InferenceEngine which manages the inference process.
NOTE: This `RPCInferenceEngine` is designed for multi-card/online serving.
The original `InferenceEngine` targets single-card, offline use, though it also supports multi-card offline inference.
Args:
model_or_path (nn.Module or str): Path to or nn.Module of the model. Passing an `nn.Module` is not supported yet.
tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use for inference.
inference_config (InferenceConfig): Configuration options related to inference.
verbose (bool): Whether to log the generation process.
model_policy (Policy): The policy used to shard the model with Shardformer. Determined from the model type if not provided.
"""
def __init__(
self,
model_or_path: Union[nn.Module, str],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
verbose: bool = False,
model_policy: Policy = None,
) -> None:
"""
If you pass a real model loaded via transformers, initialization will take quite a long time.
Currently we don't support passing the model as an nn.Module.
"""
torch.multiprocessing.set_start_method("spawn", force=True)
self.inference_config = inference_config
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.verbose = verbose
self.logger = get_dist_logger(__name__)
try:
if isinstance(model_or_path, str):
self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
elif isinstance(model_or_path, nn.Module):
self.logger.error(
f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"
)
# self.model_config = model_or_path.config
else:
self.logger.error(
f"An exception occurred during loading model Config: Please pass right param for {__class__.__name__}\n"
)
except Exception as e:
self.logger.error(
f"An exception occurred during loading model Config: {e}, The path should be transformers-like\n"
)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.tp_size = inference_config.tp_size
self.events = [mp.Event() for _ in range(self.tp_size)]
# This operation will init the dist env and models
self.workers: List[rpcWorkerService] = []
self.init_workers()
asyncio.run(self.init_model(model_or_path, model_policy))
# init the scheduler and logic block manager
self.request_handler = self.init_scheduler(self.inference_config, self.model_config)
# init the physical cache
alloc_shape = self.request_handler.cache_manager.get_physical_cache_shape()
self.init_device_cache(alloc_shape)
self.use_cuda_graph = self.inference_config.use_cuda_graph
self.high_precision = inference_config.high_precision
self.dtype = inference_config.dtype
# Model and related attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.counter = count()
self._verify_args()
self.logger.info("engine init over ")
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
def init_workers(self):
rpc_ports = find_available_ports(self.tp_size)
self.worker_processes = []
# mp.set_start_method('spawn')
for event, rpc_port in zip(self.events, rpc_ports):
p = mp.Process(target=run_server, args=("localhost", rpc_port, event))
p.start()
self.worker_processes.append(p)
self.logger.info(f"Starting RPC Worker on localhost:{rpc_port}...")
# Wait for all servers to start
for event in self.events:
event.wait()
event.clear()
sleep(0.05)
self.logger.info(f"init rpc server done.")
for rpc_port in rpc_ports:
try:
conn = rpyc.connect(
"localhost",
rpc_port,
config={"allow_pickle": True, "allow_public_attrs": True, "allow_all_attrs": True},
)
self.workers.append(conn.root)
except Exception as e:
raise ConnectionError(f"Failed to connect to the RPC worker on port {rpc_port}: {e}")
self.logger.info(f"Build RPC Connection Success! Begin to load model...")
asyncio.run(self.init_worker_env())
self.logger.info(f"init dist env over")
async def async_parallel_wrapper(self, f, *args, **kwargs):
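# Dispatch the RPC call asynchronously and wait for its result in a worker thread,
# so multiple remote calls can be awaited concurrently via asyncio.gather.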
async_res = rpyc.async_(f)(*args, **kwargs)
await asyncio.to_thread(async_res.wait)
assert async_res.ready
return async_res.value
async def init_worker_env(self):
assert len(self.workers) == self.tp_size, "init workers first"
dist_group_port = find_available_ports(1)[0]
init_tasks = [
self.async_parallel_wrapper(
worker.init_dist_env, rank, self.inference_config.tp_size, "127.0.0.1", dist_group_port
)
for rank, worker in enumerate(self.workers)
]
await asyncio.gather(*init_tasks)
async def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
assert len(self.workers) == self.tp_size, "init workers first"
inference_config_param = self.inference_config.to_rpc_param()
model_path = model_or_path
model_policy_param = model_policy.to_rpc_param() if model_policy else None
init_tasks = [
self.async_parallel_wrapper(worker.init_model, inference_config_param, model_path, model_policy_param)
for rank, worker in enumerate(self.workers)
]
await asyncio.gather(*init_tasks)
def init_scheduler(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> RPCRequestHandler:
return RPCRequestHandler(inference_config, model_config)
async def _init_device_cache(self, alloc_shape: Tuple[int, int, int, int]):
assert len(self.workers) == self.tp_size, "init workers first"
init_tasks = [self.async_parallel_wrapper(worker.init_cache, alloc_shape) for worker in self.workers]
await asyncio.gather(*init_tasks)
def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
asyncio.run(self._init_device_cache(alloc_shape))
def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]:
input_ids = batch.get_1D_inputs()
sequence_lengths = batch.get_sequence_lengths()
if batch.is_prompts:
n_tokens = sequence_lengths.sum().item()
else:
n_tokens = batch.current_batch_size
if batch.use_spec_dec:
n_tokens = batch.num_tokens_to_verify + 1
assert n_tokens == input_ids.size(0)
n_tokens = n_tokens * batch.current_batch_size
batch_token_ids = None
config_dict = self.generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
batch_token_ids = batch.batch_token_ids
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
use_cuda_graph = True
input_meta_data = InputMetaData(
block_tables=batch.get_block_table_tensor(),
sequence_lengths=sequence_lengths,
fd_inter_tensor=None,
batch_size=batch.current_batch_size,
is_prompts=batch.is_prompts,
use_cuda_kernel=self.inference_config.use_cuda_kernel,
use_cuda_graph=use_cuda_graph,
high_precision=self.high_precision,
kv_seq_len=sequence_lengths.max().item(),
head_dim=batch.head_dim,
dtype=batch.dtype,
use_spec_dec=batch.use_spec_dec,
num_tokens_to_verify=batch.num_tokens_to_verify,
batch_token_ids=batch_token_ids,
)
return input_ids.tolist(), input_meta_data
async def step_(self, input_token_ids, input_meta_data: InputMetaData):
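# Run the forward pass on all TP workers over RPC; every rank returns the sampled tokens,
# so rank 0's result is used.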
assert len(self.workers) == self.tp_size, "init workers first"
init_tasks = [
self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param())
for worker in self.workers
]
ret = await asyncio.gather(*init_tasks)
return ret[0]
def step(self) -> List[str]:
batch = self.request_handler.schedule()
input_token_ids, input_meta_data = self.prepare_input(batch)
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data))
# update the request_handler
next_tokens = torch.tensor(next_tokens, dtype=torch.int)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()
return finished_sequences
def kill_workers(self):
"""
NOTE: There is currently no good way to invoke `kill_workers` implicitly, so it is called from `__del__`.
"""
assert len(self.workers) != 0
for proc in self.worker_processes:
proc.kill()
proc.join()
self.logger.info(f"worker killed, serving end")
def __del__(self):
self.kill_workers()
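For context, a minimal usage sketch of the new engine follows. It is hypothetical: the checkpoint path and prompt are placeholders, the module path `colossalai.inference.core.rpc_engine` is inferred from this diff, and `generate` is inherited from the parent `InferenceEngine` and not shown here.

from transformers import AutoTokenizer
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.rpc_engine import RPCInferenceEngine  # module path assumed from this diff

model_path = "meta-llama/Llama-2-7b-hf"  # placeholder: any transformers-style checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
inference_config = InferenceConfig(tp_size=2, max_batch_size=8, max_input_len=256, max_output_len=128)

# Spawns tp_size RPC worker processes, initializes the distributed env on each worker,
# shards and loads the model, and allocates the physical KV cache remotely.
engine = RPCInferenceEngine(model_path, tokenizer, inference_config, verbose=True)

# `generate` is inherited from InferenceEngine; each `step()` schedules a batch,
# broadcasts the forward pass to the workers over RPC, and appends the sampled tokens.
outputs = engine.generate(prompts=["Introduce some landmarks in Beijing."])
print(outputs)

engine.kill_workers()  # terminate the worker processes explicitly (also invoked from __del__)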