[Feat]Inference RPC Server Support (#5705)

* rpc support source
* kv cache logical/physical disaggregation
* sampler refactor
* colossalai launch built in
* Unitest
* Rpyc support

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
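As a quick orientation, here is a minimal, hypothetical usage sketch of the new RPCInferenceEngine, based only on the constructor, step() and kill_workers() shown in the diff below; the checkpoint path, the tp_size value, and the request-submission step are placeholders, and prompt enqueueing goes through the inherited InferenceEngine API, which is not part of this diff.

# Hypothetical usage sketch; names marked as placeholders are assumptions, not part of the commit.
from transformers import AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.rpc_engine import RPCInferenceEngine

model_path = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint path
tokenizer = AutoTokenizer.from_pretrained(model_path)
inference_config = InferenceConfig(tp_size=2)  # assumed: two tensor-parallel RPC workers

engine = RPCInferenceEngine(model_path, tokenizer, inference_config, verbose=True)
# ... enqueue requests via the inherited InferenceEngine API (not shown in this diff) ...
finished = engine.step()  # schedules a batch and runs the forward pass on all workers over rpyc
engine.kill_workers()     # explicitly stop the worker processes when done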
@@ -21,6 +21,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
@@ -424,7 +425,7 @@ class InferenceEngine:

        # 2. Prefill main model (Verifier) - fill past kv cache for main model
        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
        next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
        # append new inputs to the batch, temporarily
        batch.append_batch_tokens(next_tokens)
        self.request_handler.allocate_batch_spec_dec(batch, 1)
@@ -472,7 +473,7 @@ class InferenceEngine:
            input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
            logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

            next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
            next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)

            # 5. Compare and process the results
            diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
@@ -689,6 +690,13 @@ class InferenceEngine:
            (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
        )

        batch_token_ids = None
        config_dict = self.generation_config.to_dict()
        # process repetition_penalty, no_repeat_ngram_size
        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
            if type in config_dict and config_dict[type] is not None:
                batch_token_ids = batch.batch_token_ids

        # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
        use_cuda_graph = False
        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
@@ -708,6 +716,7 @@ class InferenceEngine:
            dtype=batch.dtype,
            use_spec_dec=batch.use_spec_dec,
            num_tokens_to_verify=batch.num_tokens_to_verify,
            batch_token_ids=batch_token_ids,
        )

        return input_ids, output_tensor, input_meta_data
@@ -738,7 +747,9 @@ class InferenceEngine:
        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
        if self.inference_config.pad_input:
            logits = logits[:, -1, :]
        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
        next_tokens = search_tokens(
            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
        )
        self.request_handler.append_next_tokens(next_tokens)
        finished_sequences = self.request_handler.update()
@@ -7,10 +7,11 @@ from transformers.generation import GenerationConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)

__all__ = ["RunningList", "RequestHandler"]
@@ -295,17 +296,6 @@ class RequestHandler:

        return None

    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig):
        if generation_config.num_beams == 1:
            if generation_config.do_sample:
                sample_tokens = multinomial_sample(generation_config, probs)
            else:
                sample_tokens = greedy_sample(generation_config, logprobs)
        else:
            sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)

        return sample_tokens

    def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
        if (
            sequence.output_token_id[-1] == generation_config.eos_token_id
@@ -328,33 +318,6 @@ class RequestHandler:
    def total_requests_in_batch_bucket(self) -> int:
        return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

    def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
        """
        Sample tokens for finished requests.
        """

        # NOTE: need to decide the granularity to process logits (sequence or batch)
        config_dict = generation_config.to_dict()
        # process repetition_penalty, no_repeat_ngram_size
        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
            if type in config_dict and config_dict[type] is not None:
                logits = logit_processor(type, logits, config_dict[type], cur_batch)

        # do logit processor
        if generation_config.do_sample:
            # process temperature, top_k, top_p
            for type in ["temperature", "top_k", "top_p"]:
                if type in config_dict and config_dict[type] is not None:
                    logits = logit_processor(type, logits, config_dict[type])

        # calculate probs
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # sample the next tokens
        sample_tokens = self._sample(probs, logprobs, generation_config)
        return sample_tokens

    def append_next_tokens(self, sample_tokens: torch.Tensor):
        assert sample_tokens.dim() == 1
        n_elements = sample_tokens.size(0)
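The sampling logic removed from RequestHandler above now lives in colossalai.inference.sampler and is called as a plain function from the engine (see the engine hunks earlier). For reference, a simplified, hypothetical reconstruction of that relocated helper follows; it omits the logit-processor passes (repetition_penalty, no_repeat_ngram_size, temperature/top_k/top_p) that the removed method applied, so the real signature and internals may differ.

# Simplified reconstruction of the relocated sampler helper; illustrative only.
import torch
from transformers.generation import GenerationConfig


def search_tokens_sketch(generation_config: GenerationConfig, logits: torch.Tensor) -> torch.Tensor:
    """Pick one next-token id per sequence from a [batch_size, vocab_size] logits tensor."""
    if generation_config.do_sample:
        # multinomial sampling over the softmax distribution (temperature/top-k/top-p omitted)
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        return torch.multinomial(probs, num_samples=1).squeeze(1)
    # greedy decoding: highest-logit token per row
    return torch.argmax(logits, dim=-1)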
@@ -386,3 +349,53 @@ class RequestHandler:
        self.done_list.extend(finished_seqs)

        return finished_seqs


class RPCRequestHandler(RequestHandler):
    """
    RPC Version of request handler
    """

    def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
        self.inference_config = inference_config
        self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
        self.waiting_list: List[List] = [[], [], []]
        self.done_list: List[Sequence] = []
        self.dtype = inference_config.dtype
        self.max_batch_size = inference_config.max_batch_size

        # initialize cache
        self._init_cache(model_config)

        # initialize batch
        torch.cuda.current_device()
        kv_max_split_num = (
            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
        ) // inference_config.block_size
        head_dim = model_config.hidden_size // model_config.num_attention_heads

        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
        # which may cause bugs and this issue should be fixed later.
        self.running_bb = BatchBucket(
            num_heads=model_config.num_attention_heads // inference_config.tp_size,
            head_dim=head_dim,
            max_batch_size=self.max_batch_size,
            max_length=inference_config.max_input_len + inference_config.max_output_len,
            block_size=inference_config.block_size,
            kv_max_split_num=kv_max_split_num,
            fd_interm_tensor=None,
            dtype=self.dtype,
        )
        self.prefill_bb = BatchBucket(
            num_heads=model_config.num_attention_heads // inference_config.tp_size,
            head_dim=head_dim,
            max_batch_size=self.max_batch_size,
            max_length=inference_config.max_input_len + inference_config.max_output_len,
            block_size=inference_config.block_size,
            kv_max_split_num=kv_max_split_num,
            fd_interm_tensor=None,
            dtype=self.dtype,
        )

    def _init_cache(self, model_config):
        self.cache_manager = RPCKVCacheManager(self.inference_config, model_config)
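The RPCKVCacheManager above is the scheduler-side half of the "kv cache logical/physical disaggregation" named in the commit message: the request handler only keeps logical block bookkeeping, while the physical KV tensors are allocated by the RPC workers. A hedged sketch of the hand-off, using only calls visible in this diff (get_physical_cache_shape on the manager, init_device_cache on the engine, worker.init_cache underneath); the engine argument is a hypothetical RPCInferenceEngine instance.

# Illustrative only: driver holds logical bookkeeping, workers hold physical KV tensors.
def init_physical_kv_cache(engine) -> None:
    # Driver side: ask the logical cache manager what shape the physical cache must have.
    alloc_shape = engine.request_handler.cache_manager.get_physical_cache_shape()
    # Worker side: every rpcWorkerService allocates that shape on its own GPU over rpyc.
    engine.init_device_cache(alloc_shape)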
colossalai/inference/core/rpc_engine.py (new file, 291 lines)
@@ -0,0 +1,291 @@
import asyncio
from itertools import count
from time import sleep
from typing import List, Tuple, Union

import rpyc
import torch
import torch.nn as nn
from rpyc.utils.server import ThreadedServer
from torch import multiprocessing as mp
from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.executor.rpc_worker import rpcWorkerService
from colossalai.inference.utils import find_available_ports
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy

from .engine import InferenceEngine
from .request_handler import RPCRequestHandler

__all__ = ["RPCInferenceEngine"]


def run_server(host, port, event: mp.Event = None):
    server = ThreadedServer(
        rpcWorkerService, port=port, protocol_config={"allow_public_attrs": True, "allow_all_attrs": True}
    )
    if event:
        event.set()
    server.start()


class RPCInferenceEngine(InferenceEngine):

    """
    InferenceEngine which manages the inference process..

    NOTE This `RPCInferenceEngine` is designed for multiple-card/online serving.
    Original `InferenceEngine` is designed for single card and offline service, though it supports multi-card offline inference.

    Args:
        model_or_path (nn.Module or str): Path or nn.Module of this model, Currently we don't support `nn.Module` Format
        tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
        verbose (bool): Determine whether or not to log the generation process.
        model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
    """

    def __init__(
        self,
        model_or_path: Union[nn.Module, str],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: InferenceConfig,
        verbose: bool = False,
        model_policy: Policy = None,
    ) -> None:
        """
        If you input a real model loaded by transformers, the init will take quite a long time
        Currently we don't support model(nn.Module) format as the param.
        """

        torch.multiprocessing.set_start_method("spawn", force=True)

        self.inference_config = inference_config
        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)

        try:
            if isinstance(model_or_path, str):
                self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
            elif isinstance(model_or_path, nn.Module):
                self.logger.error(
                    f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"
                )
                # self.model_config = model_or_path.config
            else:
                self.logger.error(
                    f"An exception occurred during loading model Config: Please pass right param for {__class__.__name__}\n"
                )
        except Exception as e:
            self.logger.error(
                f"An exception occurred during loading model Config: {e}, The path should be transformers-like\n"
            )
        self.generation_config = inference_config.to_generation_config(self.model_config)

        self.tp_size = inference_config.tp_size
        self.events = [mp.Event() for _ in range(self.tp_size)]

        # This operation will init the dist env and models
        self.workers: List[rpcWorkerService] = []
        self.init_workers()

        asyncio.run(self.init_model(model_or_path, model_policy))

        # init the scheduler and logic block manager
        self.request_handler = self.init_scheduler(self.inference_config, self.model_config)

        # init the physical cache
        alloc_shape = self.request_handler.cache_manager.get_physical_cache_shape()
        self.init_device_cache(alloc_shape)

        self.use_cuda_graph = self.inference_config.use_cuda_graph
        self.high_precision = inference_config.high_precision
        self.dtype = inference_config.dtype

        # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
        self.use_spec_dec = False
        self.drafter_model = None
        self.drafter = None
        self.use_glide = False
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens

        self.counter = count()
        self._verify_args()

        self.logger.info("engine init over ")

    def _verify_args(self) -> None:
        """Verify the input args"""
        if not isinstance(self.inference_config, InferenceConfig):
            raise TypeError("Invalid type of inference config provided.")
        if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
            raise TypeError(
                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
            )

    def init_workers(self):
        rpc_ports = find_available_ports(self.tp_size)
        self.worker_processes = []
        # mp.set_start_method('spawn')
        for event, rpc_port in zip(self.events, rpc_ports):
            p = mp.Process(target=run_server, args=("localhost", rpc_port, event))
            p.start()
            self.worker_processes.append(p)
            self.logger.info(f"Starting RPC Worker on localhost:{rpc_port}...")

        # Wait for all servers to start
        for event in self.events:
            event.wait()
            event.clear()

        sleep(0.05)

        self.logger.info(f"init rpc server done.")

        for rpc_port in rpc_ports:
            try:
                conn = rpyc.connect(
                    "localhost",
                    rpc_port,
                    config={"allow_pickle": True, "allow_public_attrs": True, "allow_all_attrs": True},
                )
                self.workers.append(conn.root)
            except:
                raise Exception("conn error!")
        self.logger.info(f"Build RPC Connection Success! Begin to load model...")
        asyncio.run(self.init_worker_env())
        self.logger.info(f"init dist env over")

    async def async_parallel_wrapper(self, f, *args, **kwargs):
        async_res = rpyc.async_(f)(*args, **kwargs)
        await asyncio.to_thread(async_res.wait)
        assert async_res.ready
        return async_res.value

    async def init_worker_env(self):
        assert len(self.workers) == self.tp_size, "init workers first"

        dist_group_port = find_available_ports(1)[0]
        init_tasks = [
            self.async_parallel_wrapper(
                worker.init_dist_env, rank, self.inference_config.tp_size, "127.0.0.1", dist_group_port
            )
            for rank, worker in enumerate(self.workers)
        ]

        await asyncio.gather(*init_tasks)

    async def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
        assert len(self.workers) == self.tp_size, "init workers first"

        inference_config_param = self.inference_config.to_rpc_param()
        model_path = model_or_path
        model_policy_param = model_policy.to_rpc_param() if model_policy else None

        init_tasks = [
            self.async_parallel_wrapper(worker.init_model, inference_config_param, model_path, model_policy_param)
            for rank, worker in enumerate(self.workers)
        ]

        await asyncio.gather(*init_tasks)

    def init_scheduler(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> RPCRequestHandler:
        return RPCRequestHandler(inference_config, model_config)

    async def _init_device_cache(self, alloc_shape: Tuple[int, int, int, int]):
        assert len(self.workers) == self.tp_size, "init workers first"

        init_tasks = [self.async_parallel_wrapper(worker.init_cache, alloc_shape) for worker in self.workers]

        await asyncio.gather(*init_tasks)

    def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
        asyncio.run(self._init_device_cache(alloc_shape))

    def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]:
        input_ids = batch.get_1D_inputs()
        sequence_lengths = batch.get_sequence_lengths()

        if batch.is_prompts:
            n_tokens = sequence_lengths.sum().item()
        else:
            n_tokens = batch.current_batch_size
            if batch.use_spec_dec:
                n_tokens = batch.num_tokens_to_verify + 1
                assert n_tokens == input_ids.size(0)
                n_tokens = n_tokens * batch.current_batch_size

        batch_token_ids = None
        config_dict = self.generation_config.to_dict()
        # process repetition_penalty, no_repeat_ngram_size
        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
            if type in config_dict and config_dict[type] is not None:
                batch_token_ids = batch.batch_token_ids

        # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
        use_cuda_graph = False
        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
            use_cuda_graph = True

        input_meta_data = InputMetaData(
            block_tables=batch.get_block_table_tensor(),
            sequence_lengths=sequence_lengths,
            fd_inter_tensor=None,
            batch_size=batch.current_batch_size,
            is_prompts=batch.is_prompts,
            use_cuda_kernel=self.inference_config.use_cuda_kernel,
            use_cuda_graph=use_cuda_graph,
            high_precision=self.high_precision,
            kv_seq_len=sequence_lengths.max().item(),
            head_dim=batch.head_dim,
            dtype=batch.dtype,
            use_spec_dec=batch.use_spec_dec,
            num_tokens_to_verify=batch.num_tokens_to_verify,
            batch_token_ids=batch_token_ids,
        )

        return input_ids.tolist(), input_meta_data

    async def step_(self, input_token_ids, input_meta_data: InputMetaData):
        assert len(self.workers) == self.tp_size, "init workers first"

        init_tasks = [
            self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param())
            for worker in self.workers
        ]
        ret = await asyncio.gather(*init_tasks)

        return ret[0]

    def step(self) -> List[str]:
        batch = self.request_handler.schedule()

        input_token_ids, input_meta_data = self.prepare_input(batch)
        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
        next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data))

        # update the request_handler
        next_tokens = torch.tensor(next_tokens, dtype=torch.int)
        self.request_handler.append_next_tokens(next_tokens)
        finished_sequences = self.request_handler.update()
        return finished_sequences

    def kill_workers(self):
        """
        I don't find a good way to implicit invoke self.kill_workers
        """
        assert len(self.workers) != 0
        for proc in self.worker_processes:
            proc.kill()
            proc.join()
        self.logger.info(f"worker killed, serving end")

    def __del__(self):
        self.kill_workers()