Runyu Lu 2025-04-29 11:26:27 +00:00 committed by GitHub
commit fab3b58efc
6 changed files with 288 additions and 59 deletions

View File

@@ -558,3 +558,68 @@ class BatchBucket:
     def __repr__(self) -> str:
         return f"(sequences_dict={self._sequences_dict}, is_prompts={self.is_prompts})"
+
+
+class RPCBatchBucket(BatchBucket):
+    def __init__(self, *args, **argv):
+        self.is_rpc = True
+        self.device = "cpu"
+        super().__init__(*args, **argv)
+
+    # For compatibility
+    def get_1D_inputs(self) -> List[int]:
+        assert len(self._sequences_dict) > 0, "No sequence in the batch"
+        first_seq = next(iter(self._sequences_dict.values()))  # not exactly the first sequence
+        if first_seq.output_len == 0:
+            # Assume prefill stage
+            assert all(
+                seq.output_len == 0 for seq in self._sequences_dict.values()
+            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
+            out_li = []
+            seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
+            for seq_id in seq_ids:
+                seq: Sequence = self._sequences_dict[seq_id]
+                out_li.extend(seq.input_token_id)
+            return out_li
+        else:
+            # Assume decoding stage
+            if self.use_spec_dec:
+                # For speculative decoding: the number of tokens to be verified in parallel,
+                # plus the token accepted in the last step
+                return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
+            assert all(
+                seq.output_len > 0 for seq in self._sequences_dict.values()
+            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
+            assert self.is_compact, "BatchBucket is not compact"
+            out = [0] * self.current_batch_size
+            for seq_id, index_in_b in self._sequences_indexes.items():
+                seq: Sequence = self._sequences_dict[seq_id]
+                out[index_in_b] = seq.output_token_id[-1]
+            return out
+
+    # For compatibility
+    def get_sequence_lengths(self) -> List[int]:
+        assert self.is_compact  # Debug usage
+        sequence_lengths = self.seq_lengths[: self.current_batch_size]
+        return sequence_lengths
+
+    def get_1D_inputs_spec_dec(self, n: int) -> List[int]:
+        # Used for main-model verification in the **decoding stage**:
+        # `n` is the number of tokens to be verified, so the last `n` tokens
+        # of each sequence are prepared as the inputs
+        assert len(self._sequences_dict) > 0, "No sequence in the batch"
+        assert all(
+            seq.output_len >= n for seq in self._sequences_dict.values()
+        ), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
+        out_li = []
+        seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
+        for seq_id in seq_ids:
+            seq: Sequence = self._sequences_dict[seq_id]
+            out_li.extend(seq.output_token_id[-n:])
+        return out_li
+
+    # For compatibility
+    def get_block_table_tensor(self) -> torch.Tensor:
+        assert self.is_compact  # Debug usage
+        block_table = self.block_tables[: self.current_batch_size]
+        return block_table
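
A minimal, standalone sketch (hypothetical data, not the library API) of the ordering logic that get_1D_inputs relies on: sequence ids are sorted by their slot index in the batch, so the flattened prompt tokens line up with the batch layout the worker expects. Returning plain Python lists (and keeping device = "cpu") also keeps the bucket cheap to serialize across the RPC boundary.

sequences_indexes = {"seq-b": 1, "seq-a": 0}               # seq_id -> index in the batch
input_tokens = {"seq-a": [101, 7, 9], "seq-b": [101, 42]}  # seq_id -> prompt token ids

seq_ids = sorted(sequences_indexes, key=lambda sid: sequences_indexes[sid])
flat = []
for sid in seq_ids:
    flat.extend(input_tokens[sid])

print(seq_ids)  # ['seq-a', 'seq-b']
print(flat)     # [101, 7, 9, 101, 42]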

View File

@@ -89,8 +89,14 @@ class InputMetaData(RPC_PARAM):
     def to_rpc_param(self) -> Dict[str, any]:
         return {
-            "block_tables": self.block_tables.tolist(),
-            "sequence_lengths": self.sequence_lengths.tolist(),
+            "block_tables": self.block_tables,
+            # "block_tables": self.block_tables.tolist()
+            # if isinstance(self.block_tables, torch.Tensor)
+            # else self.block_tables,
+            "sequence_lengths": self.sequence_lengths,
+            # "sequence_lengths": self.sequence_lengths.tolist()
+            # if isinstance(self.block_tables, torch.Tensor)
+            # else self.sequence_lengths,
             "batch_size": self.batch_size,
             "is_prompts": self.is_prompts,
             "use_cuda_kernel": self.use_cuda_kernel,
@@ -112,12 +118,17 @@ class InputMetaData(RPC_PARAM):
         from colossalai.accelerator import get_accelerator

         dtype = getattr(torch, rpc_dict["dtype"])
+        device = get_accelerator().get_current_device()
         return InputMetaData(
-            block_tables=torch.tensor(
-                rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device()
+            block_tables=(
+                torch.tensor(rpc_dict["block_tables"], dtype=torch.int, device=device)
+                if isinstance(rpc_dict["block_tables"], list)
+                else rpc_dict["block_tables"].to(device)
             ),
-            sequence_lengths=torch.tensor(
-                rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
+            sequence_lengths=(
+                torch.tensor(rpc_dict["sequence_lengths"], dtype=torch.int, device=device)
+                if isinstance(rpc_dict["sequence_lengths"], list)
+                else rpc_dict["sequence_lengths"].to(device)
             ),
             batch_size=rpc_dict["batch_size"],
             is_prompts=rpc_dict["is_prompts"],
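
A minimal sketch of the normalization that from_rpc_param now performs, assuming only that torch is available (the helper name to_device_tensor is illustrative): an RPC parameter may arrive either as a plain list (when it was serialized) or as a tensor (when the .tolist() round trip was skipped), and both are coerced onto the target device.

import torch

def to_device_tensor(value, device="cpu"):
    # Lists arrive when the caller serialized the field; tensors arrive when it did not.
    if isinstance(value, list):
        return torch.tensor(value, dtype=torch.int, device=device)
    return value.to(device)

print(to_device_tensor([1, 2, 3]))                # built from a list
print(to_device_tensor(torch.tensor([4, 5, 6])))  # moved/kept as a tensor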

View File

@@ -78,7 +78,6 @@ class InferenceEngine:
         Args:
             request_ids (List[int], optional): The request ID. Defaults to None.
-            prompts (Union[List[str], optional): Input prompts. Defaults to None.
         """
         assert self.engine is not None, "Please init Engine first"

View File

@@ -4,7 +4,7 @@ import torch
 from transformers.configuration_utils import PretrainedConfig
 from transformers.generation import GenerationConfig

-from colossalai.inference.batch_bucket import BatchBucket
+from colossalai.inference.batch_bucket import BatchBucket, RPCBatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
@@ -427,7 +427,7 @@ class RPCRequestHandler(RequestHandler):
         # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
         # which may cause bugs and this issue should be fixed later.
-        self.running_bb = BatchBucket(
+        self.running_bb = RPCBatchBucket(
             num_heads=model_config.num_attention_heads // inference_config.tp_size,
             head_dim=head_dim,
             max_batch_size=self.max_batch_size,
@@ -437,7 +437,7 @@ class RPCRequestHandler(RequestHandler):
             fd_interm_tensor=None,
             dtype=self.dtype,
         )
-        self.prefill_bb = BatchBucket(
+        self.prefill_bb = RPCBatchBucket(
             num_heads=model_config.num_attention_heads // inference_config.tp_size,
             head_dim=head_dim,
             max_batch_size=self.max_batch_size,

View File

@@ -1,4 +1,5 @@
 import asyncio
+import pickle
 from itertools import count
 from time import sleep
 from typing import List, Tuple, Union
@@ -11,7 +12,7 @@ from torch import multiprocessing as mp
 from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
 from transformers.configuration_utils import PretrainedConfig

-from colossalai.inference.batch_bucket import BatchBucket
+from colossalai.inference.batch_bucket import RPCBatchBucket
 from colossalai.inference.config import InferenceConfig, InputMetaData
 from colossalai.inference.executor.rpc_worker import rpcWorkerService
 from colossalai.inference.utils import find_available_ports
@@ -120,6 +121,9 @@ class RPCInferenceEngine(InferenceEngine):
         self.counter = count()
         self._verify_args()

+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
         self.logger.info("engine init over ")

     def _verify_args(self) -> None:
@@ -162,8 +166,16 @@ class RPCInferenceEngine(InferenceEngine):
             raise Exception("conn error!")
         self.logger.info(f"Build RPC Connection Success! Begin to load model...")
         asyncio.run(self.init_worker_env())
+        self._init_worker_forward()
         self.logger.info(f"init dist env over")

+    def _init_worker_forward(self):
+        """
+        Build async wrappers for the workers' forward call, since it will be invoked many times.
+        """
+        assert len(self.workers) == self.tp_size, "init workers first"
+        self.worker_forwards = [rpyc.async_(worker.execute_model_forward) for worker in self.workers]
+
     async def async_parallel_wrapper(self, f, *args, **kwargs):
         async_res = rpyc.async_(f)(*args, **kwargs)
         await asyncio.to_thread(async_res.wait)
@@ -210,7 +222,8 @@ class RPCInferenceEngine(InferenceEngine):
     def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
         asyncio.run(self._init_device_cache(alloc_shape))

-    def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]:
+    def prepare_input(self, batch: RPCBatchBucket) -> Tuple[List[int], InputMetaData]:
+        assert batch.is_rpc, "the batch must be RPCBatchBucket"
         input_ids = batch.get_1D_inputs()
         sequence_lengths = batch.get_sequence_lengths()
@@ -220,7 +233,7 @@ class RPCInferenceEngine(InferenceEngine):
             n_tokens = batch.current_batch_size
             if batch.use_spec_dec:
                 n_tokens = batch.num_tokens_to_verify + 1
-                assert n_tokens == input_ids.size(0)
+                assert n_tokens == len(input_ids)
                 n_tokens = n_tokens * batch.current_batch_size

         batch_token_ids = None
@@ -252,40 +265,60 @@ class RPCInferenceEngine(InferenceEngine):
             batch_token_ids=batch_token_ids,
         )

-        return input_ids.tolist(), input_meta_data
+        return input_ids, input_meta_data

-    async def step_(self, input_token_ids, input_meta_data: InputMetaData):
+    async def async_parallel_forward(self, async_f, *args, **kwargs):
+        async_res = async_f(*args, **kwargs)
+        await asyncio.to_thread(async_res.wait)
+        assert async_res.ready
+        return async_res.value
+
+    async def step_async(self, input_token_ids, input_meta_data: InputMetaData):
         assert len(self.workers) == self.tp_size, "init workers first"

-        init_tasks = [
-            self.async_parallel_wrapper(
-                worker.execute_model_forward,
-                input_token_ids,
-                input_meta_data.to_rpc_param(),
-                self.generation_config_dict,
-            )
-            for worker in self.workers
-        ]
+        init_tasks = []
+        for rank, async_forward in enumerate(self.worker_forwards):
+            if rank == 0:
+                init_tasks.append(
+                    self.async_parallel_forward(
+                        async_forward,
+                        pickle.dumps(input_token_ids),
+                        pickle.dumps(input_meta_data.to_rpc_param()),
+                        pickle.dumps(self.generation_config_dict),
+                    )
+                )
+            else:
+                init_tasks.append(
+                    self.async_parallel_forward(
+                        async_forward,
+                        None,
+                        None,
+                        None,
+                    )
+                )
         ret = await asyncio.gather(*init_tasks)

         return ret[0]

     def step(self) -> List[str]:
-        batch = self.request_handler.schedule()
-
-        input_token_ids, input_meta_data = self.prepare_input(batch)
-        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
-        next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data))
+        with self.t_prepare:
+            batch = self.request_handler.schedule()
+
+            input_token_ids, input_meta_data = self.prepare_input(batch)
+
+        with self.t_exe:
+            # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
+            next_tokens = self.loop.run_until_complete(self.step_async(input_token_ids, input_meta_data))

         # update the request_handler
+        next_tokens = torch.tensor(next_tokens, dtype=torch.int)
         self.request_handler.append_next_tokens(next_tokens)
         finished_sequences = self.request_handler.update()
         return finished_sequences

     def kill_workers(self):
         """
-        I don't find a good way to implicit invoke self.kill_workers
+        NOTE(@lry89757) Haven't found a good way to implicitly invoke self.kill_workers
         """
         assert len(self.workers) != 0
         for proc in self.worker_processes:
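
A self-contained sketch of the pattern behind async_parallel_forward and step_async, with a stand-in FakeAsyncResult in place of rpyc's AsyncResult so the snippet runs without an RPC server: arguments are pickled once, the blocking wait() is pushed onto a worker thread with asyncio.to_thread, and several calls are overlapped with asyncio.gather.

import asyncio
import pickle
import time

class FakeAsyncResult:
    # Stand-in for rpyc.AsyncResult: wait() blocks, then value/ready are set.
    def __init__(self, payload):
        self._payload = payload
        self.ready = False
        self.value = None

    def wait(self):
        time.sleep(0.1)  # pretend the remote forward takes a while
        self.value = pickle.loads(self._payload)
        self.ready = True

def fake_async_forward(payload):
    # Stands in for rpyc.async_(worker.execute_model_forward)
    return FakeAsyncResult(payload)

async def parallel_forward(async_f, payload):
    async_res = async_f(payload)
    await asyncio.to_thread(async_res.wait)  # keep the event loop free while waiting
    assert async_res.ready
    return async_res.value

async def main():
    payloads = [pickle.dumps([i, i + 1]) for i in range(4)]
    results = await asyncio.gather(*(parallel_forward(fake_async_forward, p) for p in payloads))
    print(results)

asyncio.run(main())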

View File

@@ -1,4 +1,6 @@
-from typing import List, Tuple, Union
+import pickle
+from contextlib import nullcontext
+from typing import List, Optional, Tuple, Union

 import rpyc
 import torch
@@ -51,6 +53,25 @@ class rpcWorkerService(rpyc.Service):
     def exposed_init_dist_env(self, rank, world_size, master_address, master_port):
         logger.info(f"init process group for rank {rank}")
         colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address)
+        self.rank = rank
+        self.profiling = False
+        self.profiler = (
+            torch.profiler.profile(
+                record_shapes=True,
+                with_stack=True,
+                with_modules=True,
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                # schedule=torch.profiler.schedule(wait=0, repeat=1, active=1),
+                # on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./tb_log_{args.batch_size}_" + args.mode),
+            )
+            if self.profiling
+            else nullcontext()
+        )
         logger.info(f"init process group done for rank {rank}")

     def exposed_init_model(
@@ -98,38 +119,53 @@ class rpcWorkerService(rpyc.Service):
         logger.info("physical cache init over")

     def exposed_execute_model_forward(
-        self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict
+        self,
+        input_token_ids_param: Optional[List[int]] = None,
+        input_meta_data_param: Optional[dict] = None,
+        generation_config_param: Optional[dict] = None,
     ):
-        # prepare the data for model forward
-        input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
-        input_meta_data.fd_inter_tensor = self.fd_inter_tensor
-        if input_meta_data.is_prompts:
-            n_tokens = input_meta_data.sequence_lengths.sum().item()
-        else:
-            n_tokens = input_meta_data.batch_size
-        input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device)
-
-        # execute the model
-        logits = self.model(
-            input_token_ids,
-            self.output_tensor[:n_tokens],
-            input_meta_data,
-            self.k_cache,
-            self.v_cache,
-        )
-
-        # sampler
-        if self.inference_config.pad_input:
-            logits = logits[:, -1, :]
-        next_tokens = search_tokens(
-            generation_config_param,
-            logits,
-            input_meta_data.is_prompts,
-            input_meta_data.batch_token_ids,
-        )
-
-        # return the tokens generated to scheduler
-        return next_tokens.tolist()
+        with self.profiler:
+            # prepare the data for model forward
+            input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers(
+                input_token_ids_param=input_token_ids_param,
+                input_meta_data_param=input_meta_data_param,
+                generation_config_param=generation_config_param,
+            )
+
+            if input_meta_data.is_prompts:
+                n_tokens = input_meta_data.sequence_lengths.sum().item()
+            else:
+                n_tokens = input_meta_data.batch_size
+
+            # execute the model
+            logits = self.model(
+                input_token_ids,
+                self.output_tensor[:n_tokens],
+                input_meta_data,
+                self.k_cache,
+                self.v_cache,
+            )
+
+            if self.profiling:
+                self.profiler.step()
+                self.record()
+
+            if self.rank == 0:
+                # sampler
+                if self.inference_config.pad_input:
+                    logits = logits[:, -1, :]
+                next_tokens = search_tokens(
+                    generation_config,
+                    logits,
+                    input_meta_data.is_prompts,
+                    input_meta_data.batch_token_ids,
+                )
+
+                # return the generated tokens to the scheduler;
+                # only rank 0 needs to pass the data back,
+                # to reduce the overhead of rpc param passing
+                return next_tokens.cpu()

     def _init_output_tensor(self):
         alloc_shape = (
@@ -166,6 +202,85 @@ class rpcWorkerService(rpyc.Service):
         self.fd_inter_tensor = fd_inter_tensor

+    def _broadcast_param_to_all_workers(
+        self,
+        input_token_ids_param: Optional[List[int]] = None,
+        input_meta_data_param: Optional[dict] = None,
+        generation_config_param: Optional[dict] = None,
+    ):
+        if self.rank == 0:
+            input_token_ids_param = pickle.loads(input_token_ids_param)
+            input_meta_data_param = pickle.loads(input_meta_data_param)
+            generation_config_param = pickle.loads(generation_config_param)
+
+            input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
+            input_meta_data.fd_inter_tensor = self.fd_inter_tensor
+            input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device)
+            generation_config = generation_config_param
+
+            if dist.get_world_size() > 1:
+                broadcast_list = {}
+                for k, v in input_meta_data_param.items():
+                    if not isinstance(v, torch.Tensor):
+                        broadcast_list[k] = v
+
+                # Pass the tensor shapes and dtypes in advance so that the other
+                # workers can prepare empty tensors and receive the tensors asynchronously
+                broadcast_list["block_tables"] = (
+                    input_meta_data.block_tables.size(),
+                    input_meta_data.block_tables.dtype,
+                )
+                broadcast_list["sequence_lengths"] = (
+                    input_meta_data.sequence_lengths.size(),
+                    input_meta_data.sequence_lengths.dtype,
+                )
+                broadcast_list["input_token_ids"] = (input_token_ids.size(), input_token_ids.dtype)
+
+                # Generation Config Param
+                broadcast_list["generation_config"] = generation_config_param
+
+                # send the metadata and the tensor shapes
+                torch.distributed.broadcast_object_list([broadcast_list], src=self.rank)
+
+                # send the real tensors
+                torch.distributed.broadcast(input_meta_data.block_tables, src=self.rank)
+                torch.distributed.broadcast(input_meta_data.sequence_lengths, src=self.rank)
+                torch.distributed.broadcast(input_token_ids, src=self.rank)
+        else:
+            assert input_meta_data_param is None, "Input Must Be None"
+
+            # recv the metadata
+            recv_list = [None]
+            torch.distributed.broadcast_object_list(recv_list, src=0)
+            input_meta_data_param = recv_list[0]
+
+            generation_config = input_meta_data_param["generation_config"]
+
+            blocktable_shape, blocktable_type = input_meta_data_param["block_tables"]
+            blocktables = torch.empty(blocktable_shape, dtype=blocktable_type, device=self.device)
+
+            sequence_lengths_shape, sequence_lengths_type = input_meta_data_param["sequence_lengths"]
+            sequence_lengths = torch.empty(sequence_lengths_shape, dtype=sequence_lengths_type, device=self.device)
+
+            input_token_ids_shape, input_token_ids_type = input_meta_data_param["input_token_ids"]
+            input_token_ids = torch.empty(input_token_ids_shape, dtype=input_token_ids_type, device=self.device)
+
+            # recv the real tensors
+            async1 = torch.distributed.broadcast(blocktables, src=0, async_op=True)
+            async2 = torch.distributed.broadcast(sequence_lengths, src=0, async_op=True)
+            async3 = torch.distributed.broadcast(input_token_ids, src=0, async_op=True)
+
+            input_meta_data_param["sequence_lengths"] = sequence_lengths
+            input_meta_data_param["block_tables"] = blocktables
+
+            input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
+            input_meta_data.fd_inter_tensor = self.fd_inter_tensor
+
+            async1.wait()
+            async2.wait()
+            async3.wait()
+
+        return input_token_ids, input_meta_data, generation_config
+
     def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
         """
         Shard model or/and Load weight
@@ -306,3 +421,9 @@ class rpcWorkerService(rpyc.Service):
         logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}")
         return data.item()
+
+    def record(self):
+        if self.profiling:
+            file = "/home/lurunyu/projects/ColossalAI/test_trace_rpc.json"
+            self.profiler.export_chrome_trace(file)
+            logger.info(f"trace has been saved into {file}")