commit fab3b58efc
Runyu Lu, 2025-04-29 11:26:27 +00:00 (committed by GitHub)
6 changed files with 288 additions and 59 deletions

View File

@@ -558,3 +558,68 @@ class BatchBucket:
def __repr__(self) -> str:
return f"(sequences_dict={self._sequences_dict}, is_prompts={self.is_prompts})"
class RPCBatchBucket(BatchBucket):
def __init__(self, *args, **argv):
self.is_rpc = True
self.device = "cpu"
super().__init__(*args, **argv)
# For compatibility
def get_1D_inputs(self) -> List[int]:
assert len(self._sequences_dict) > 0, "No sequence in the batch"
first_seq = next(iter(self._sequences_dict.values())) # not exactly the first sequence
if first_seq.output_len == 0:
# Assume prefill stage
assert all(
seq.output_len == 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
out_li = []
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
out_li.extend(seq.input_token_id)
return out_li
else:
# Assume decoding stage
if self.use_spec_dec:
# For Speculative Decoding
# the number of tokens to be verified in parallel plus the correct token in the last step
return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
assert all(
seq.output_len > 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
assert self.is_compact, "BatchBucket is not compact"
out = [0] * self.current_batch_size
for seq_id, index_in_b in self._sequences_indexes.items():
seq: Sequence = self._sequences_dict[seq_id]
out[index_in_b] = seq.output_token_id[-1]
return out
# For compatibility
def get_sequence_lengths(self) -> List[int]:
assert self.is_compact # Debug usage
sequence_lengths = self.seq_lengths[: self.current_batch_size]
return sequence_lengths
def get_1D_inputs_spec_dec(self, n: int) -> List[int]:
# Used for main model verification in **Decoding Stage**
# `n` is the number of tokens to be verified,
# so the last `n` tokens of each sequence are prepared as the inputs
assert len(self._sequences_dict) > 0, "No sequence in the batch"
assert all(
seq.output_len >= n for seq in self._sequences_dict.values()
), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
out_li = []
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
out_li.extend(seq.output_token_id[-n:])
return out_li
# For compatibility
def get_block_table_tensor(self) -> torch.Tensor:
assert self.is_compact # Debug usage
block_table = self.block_tables[: self.current_batch_size]
return block_table
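The point of the new subclass is that the scheduler-side bucket now lives on the CPU and hands back plain Python lists instead of device tensors, so each step's inputs can be pickled and shipped over RPC without a host-device copy. A minimal sketch of that contract, using illustrative stub classes rather than the real `BatchBucket`:

```python
import pickle

# Illustrative stubs only; the real classes live in colossalai.inference.batch_bucket.
class _BucketStub:
    def __init__(self, token_ids):
        self._token_ids = token_ids

class _RPCBucketStub(_BucketStub):
    device = "cpu"  # batch bookkeeping stays on the host

    def get_1D_inputs(self):
        # Already a flat List[int], so it is cheap to pickle and send via rpyc.
        return list(self._token_ids)

batch = _RPCBucketStub([101, 7592, 2088, 102])
payload = pickle.dumps(batch.get_1D_inputs())  # what rank 0 sends to the workers
assert pickle.loads(payload) == [101, 7592, 2088, 102]
```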

View File

@@ -89,8 +89,14 @@ class InputMetaData(RPC_PARAM):
def to_rpc_param(self) -> Dict[str, any]:
return {
"block_tables": self.block_tables.tolist(),
"sequence_lengths": self.sequence_lengths.tolist(),
"block_tables": self.block_tables,
# "block_tables": self.block_tables.tolist()
# if isinstance(self.block_tables, torch.Tensor)
# else self.block_tables,
"sequence_lengths": self.sequence_lengths,
# "sequence_lengths": self.sequence_lengths.tolist()
# if isinstance(self.block_tables, torch.Tensor)
# else self.sequence_lengths,
"batch_size": self.batch_size,
"is_prompts": self.is_prompts,
"use_cuda_kernel": self.use_cuda_kernel,
@@ -112,12 +118,17 @@ class InputMetaData(RPC_PARAM):
from colossalai.accelerator import get_accelerator
dtype = getattr(torch, rpc_dict["dtype"])
device = get_accelerator().get_current_device()
return InputMetaData(
block_tables=torch.tensor(
rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device()
block_tables=(
torch.tensor(rpc_dict["block_tables"], dtype=torch.int, device=device)
if isinstance(rpc_dict["block_tables"], list)
else rpc_dict["block_tables"].to(device)
),
sequence_lengths=torch.tensor(
rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
sequence_lengths=(
torch.tensor(rpc_dict["sequence_lengths"], dtype=torch.int, device=device)
if isinstance(rpc_dict["sequence_lengths"], list)
else rpc_dict["sequence_lengths"].to(device)
),
batch_size=rpc_dict["batch_size"],
is_prompts=rpc_dict["is_prompts"],
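The `from_rpc_param` change reduces to one pattern: accept either the plain list that arrived in the RPC payload or a tensor that a non-zero rank has already rebuilt from a broadcast, and land both on the current device. A small sketch of that branch, with `_to_device_int_tensor` as an illustrative helper name that does not exist in the codebase:

```python
from typing import Union

import torch


def _to_device_int_tensor(value: Union[list, torch.Tensor], device) -> torch.Tensor:
    # Lists come straight out of the pickled RPC payload; tensors come from
    # a local torch.distributed broadcast and only need to be moved.
    if isinstance(value, list):
        return torch.tensor(value, dtype=torch.int, device=device)
    return value.to(device)


# Both call sites end up with an int tensor on the target device.
block_tables = _to_device_int_tensor([[0, 1, -1], [2, 3, -1]], device="cpu")
sequence_lengths = _to_device_int_tensor(torch.tensor([2, 2], dtype=torch.int), device="cpu")
```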

View File

@@ -78,7 +78,6 @@ class InferenceEngine:
Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
"""
assert self.engine is not None, "Please init Engine first"

View File

@@ -4,7 +4,7 @@ import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.batch_bucket import BatchBucket, RPCBatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
@@ -427,7 +427,7 @@ class RPCRequestHandler(RequestHandler):
# TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs and this issue should be fixed later.
self.running_bb = BatchBucket(
self.running_bb = RPCBatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
@@ -437,7 +437,7 @@ class RPCRequestHandler(RequestHandler):
fd_interm_tensor=None,
dtype=self.dtype,
)
self.prefill_bb = BatchBucket(
self.prefill_bb = RPCBatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
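The swap to `RPCBatchBucket` here pairs with the `assert batch.is_rpc` guard that `RPCInferenceEngine.prepare_input` adds in the next file: only the RPC-aware bucket advertises `is_rpc`. A tiny, purely illustrative check of that guard using stub classes:

```python
# Stubs for illustration; the real classes are in colossalai.inference.batch_bucket.
class _BatchBucketStub:
    is_rpc = False


class _RPCBatchBucketStub(_BatchBucketStub):
    is_rpc = True


def _prepare_input_guard(batch):
    # Same guard as in RPCInferenceEngine.prepare_input.
    assert batch.is_rpc, "the batch must be RPCBatchBucket"


_prepare_input_guard(_RPCBatchBucketStub())   # passes
# _prepare_input_guard(_BatchBucketStub())    # would raise AssertionError
```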

View File

@@ -1,4 +1,5 @@
import asyncio
import pickle
from itertools import count
from time import sleep
from typing import List, Tuple, Union
@@ -11,7 +12,7 @@ from torch import multiprocessing as mp
from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.batch_bucket import RPCBatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.executor.rpc_worker import rpcWorkerService
from colossalai.inference.utils import find_available_ports
@@ -120,6 +121,9 @@ class RPCInferenceEngine(InferenceEngine):
self.counter = count()
self._verify_args()
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.logger.info("engine init over ")
def _verify_args(self) -> None:
@@ -162,8 +166,16 @@ class RPCInferenceEngine(InferenceEngine):
raise Exception("conn error!")
self.logger.info(f"Build RPC Connection Success! Begin to load model...")
asyncio.run(self.init_worker_env())
self._init_worker_forward()
self.logger.info(f"init dist env over")
def _init_worker_forward(self):
"""
Pre-build async wrappers for the workers' forward call, since it is invoked on every step.
"""
assert len(self.workers) == self.tp_size, "init workers first"
self.worker_forwards = [rpyc.async_(worker.execute_model_forward) for worker in self.workers]
async def async_parallel_wrapper(self, f, *args, **kwargs):
async_res = rpyc.async_(f)(*args, **kwargs)
await asyncio.to_thread(async_res.wait)
@@ -210,7 +222,8 @@
def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
asyncio.run(self._init_device_cache(alloc_shape))
def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]:
def prepare_input(self, batch: RPCBatchBucket) -> Tuple[List[int], InputMetaData]:
assert batch.is_rpc, "the batch must be RPCBatchBucket"
input_ids = batch.get_1D_inputs()
sequence_lengths = batch.get_sequence_lengths()
@@ -220,7 +233,7 @@
n_tokens = batch.current_batch_size
if batch.use_spec_dec:
n_tokens = batch.num_tokens_to_verify + 1
assert n_tokens == input_ids.size(0)
assert n_tokens == len(input_ids)
n_tokens = n_tokens * batch.current_batch_size
batch_token_ids = None
@@ -252,40 +265,60 @@
batch_token_ids=batch_token_ids,
)
return input_ids.tolist(), input_meta_data
return input_ids, input_meta_data
async def step_(self, input_token_ids, input_meta_data: InputMetaData):
async def async_parallel_forward(self, async_f, *args, **kwargs):
async_res = async_f(*args, **kwargs)
await asyncio.to_thread(async_res.wait)
assert async_res.ready
return async_res.value
async def step_async(self, input_token_ids, input_meta_data: InputMetaData):
assert len(self.workers) == self.tp_size, "init workers first"
init_tasks = [
self.async_parallel_wrapper(
worker.execute_model_forward,
input_token_ids,
input_meta_data.to_rpc_param(),
self.generation_config_dict,
)
for worker in self.workers
]
init_tasks = []
for rank, async_forward in enumerate(self.worker_forwards):
if rank == 0:
init_tasks.append(
self.async_parallel_forward(
async_forward,
pickle.dumps(input_token_ids),
pickle.dumps(input_meta_data.to_rpc_param()),
pickle.dumps(self.generation_config_dict),
)
)
else:
init_tasks.append(
self.async_parallel_forward(
async_forward,
None,
None,
None,
)
)
ret = await asyncio.gather(*init_tasks)
return ret[0]
def step(self) -> List[str]:
batch = self.request_handler.schedule()
with self.t_prepare:
batch = self.request_handler.schedule()
input_token_ids, input_meta_data = self.prepare_input(batch)
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data))
input_token_ids, input_meta_data = self.prepare_input(batch)
with self.t_exe:
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
next_tokens = self.loop.run_until_complete(self.step_async(input_token_ids, input_meta_data))
# update the request_handler
next_tokens = torch.tensor(next_tokens, dtype=torch.int)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()
return finished_sequences
def kill_workers(self):
"""
I don't find a good way to implicit invoke self.kill_workers
NOTE(@lry89757) Haven't found a good way to implicitly invoke self.kill_workers
"""
assert len(self.workers) != 0
for proc in self.worker_processes:

View File

@@ -1,4 +1,6 @@
from typing import List, Tuple, Union
import pickle
from contextlib import nullcontext
from typing import List, Optional, Tuple, Union
import rpyc
import torch
@@ -51,6 +53,25 @@ class rpcWorkerService(rpyc.Service):
def exposed_init_dist_env(self, rank, world_size, master_address, master_port):
logger.info(f"init process group for rank {rank}")
colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address)
self.rank = rank
self.profiling = False
self.profiler = (
torch.profiler.profile(
record_shapes=True,
with_stack=True,
with_modules=True,
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
# schedule=torch.profiler.schedule(wait=0, repeat=1, active=1),
# on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./tb_log_{args.batch_size}_" + args.mode),
)
if self.profiling
else nullcontext()
)
logger.info(f"init process group done for rank {rank}")
def exposed_init_model(
@@ -98,38 +119,53 @@
logger.info("physical cache init over")
def exposed_execute_model_forward(
self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict
self,
input_token_ids_param: Optional[List[int]] = None,
input_meta_data_param: Optional[dict] = None,
generation_config_param: Optional[dict] = None,
):
# prepare the data for model forward
input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
input_meta_data.fd_inter_tensor = self.fd_inter_tensor
if input_meta_data.is_prompts:
n_tokens = input_meta_data.sequence_lengths.sum().item()
else:
n_tokens = input_meta_data.batch_size
input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device)
with self.profiler:
# prepare the data for model forward
input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers(
input_token_ids_param=input_token_ids_param,
input_meta_data_param=input_meta_data_param,
generation_config_param=generation_config_param,
)
# execute the model
logits = self.model(
input_token_ids,
self.output_tensor[:n_tokens],
input_meta_data,
self.k_cache,
self.v_cache,
)
if input_meta_data.is_prompts:
n_tokens = input_meta_data.sequence_lengths.sum().item()
else:
n_tokens = input_meta_data.batch_size
# sampler
if self.inference_config.pad_input:
logits = logits[:, -1, :]
next_tokens = search_tokens(
generation_config_param,
logits,
input_meta_data.is_prompts,
input_meta_data.batch_token_ids,
)
# execute the model
logits = self.model(
input_token_ids,
self.output_tensor[:n_tokens],
input_meta_data,
self.k_cache,
self.v_cache,
)
# return the tokens generated to scheduler
return next_tokens.tolist()
if self.profiling:
self.profiler.step()
self.record()
if self.rank == 0:
# sampler
if self.inference_config.pad_input:
logits = logits[:, -1, :]
next_tokens = search_tokens(
generation_config,
logits,
input_meta_data.is_prompts,
input_meta_data.batch_token_ids,
)
# return the tokens generated to scheduler
# only rank 0 needs to pass the data back
# to reduce the overhead of RPC param passing
return next_tokens.cpu()
def _init_output_tensor(self):
alloc_shape = (
@@ -166,6 +202,85 @@ class rpcWorkerService(rpyc.Service):
self.fd_inter_tensor = fd_inter_tensor
def _broadcast_param_to_all_workers(
self,
input_token_ids_param: Optional[List[int]] = None,
input_meta_data_param: Optional[dict] = None,
generation_config_param: Optional[dict] = None,
):
if self.rank == 0:
input_token_ids_param = pickle.loads(input_token_ids_param)
input_meta_data_param = pickle.loads(input_meta_data_param)
generation_config_param = pickle.loads(generation_config_param)
input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
input_meta_data.fd_inter_tensor = self.fd_inter_tensor
input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device)
generation_config = generation_config_param
if dist.get_world_size() > 1:
broadcast_list = {}
for k, v in input_meta_data_param.items():
if not isinstance(v, torch.Tensor):
broadcast_list[k] = v
# Pass the tensor shapes and dtypes in advance so other workers
# can pre-allocate empty tensors and receive the real tensors asynchronously
broadcast_list["block_tables"] = (
input_meta_data.block_tables.size(),
input_meta_data.block_tables.dtype,
)
broadcast_list["sequence_lengths"] = (
input_meta_data.sequence_lengths.size(),
input_meta_data.sequence_lengths.dtype,
)
broadcast_list["input_token_ids"] = (input_token_ids.size(), input_token_ids.dtype)
# Generation Config Param
broadcast_list["generation_config"] = generation_config_param
# send some meta data and some tensor shape
torch.distributed.broadcast_object_list([broadcast_list], src=self.rank)
# send the real tensor
torch.distributed.broadcast(input_meta_data.block_tables, src=self.rank)
torch.distributed.broadcast(input_meta_data.sequence_lengths, src=self.rank)
torch.distributed.broadcast(input_token_ids, src=self.rank)
else:
assert input_meta_data_param is None, "Input Must Be None"
# recv the meta data
recv_list = [None]
torch.distributed.broadcast_object_list(recv_list, src=0)
input_meta_data_param = recv_list[0]
generation_config = input_meta_data_param["generation_config"]
blocktable_shape, blocktable_type = input_meta_data_param["block_tables"]
blocktables = torch.empty(blocktable_shape, dtype=blocktable_type, device=self.device)
sequence_lengths_shape, sequence_lengths_type = input_meta_data_param["sequence_lengths"]
sequence_lengths = torch.empty(sequence_lengths_shape, dtype=sequence_lengths_type, device=self.device)
input_token_ids_shape, input_token_ids_type = input_meta_data_param["input_token_ids"]
input_token_ids = torch.empty(input_token_ids_shape, dtype=input_token_ids_type, device=self.device)
# recv the real tensor
async1 = torch.distributed.broadcast(blocktables, src=0, async_op=True)
async2 = torch.distributed.broadcast(sequence_lengths, src=0, async_op=True)
async3 = torch.distributed.broadcast(input_token_ids, src=0, async_op=True)
input_meta_data_param["sequence_lengths"] = sequence_lengths
input_meta_data_param["block_tables"] = blocktables
input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
input_meta_data.fd_inter_tensor = self.fd_inter_tensor
async1.wait()
async2.wait()
async3.wait()
return input_token_ids, input_meta_data, generation_config
def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
"""
Shard model or/and Load weight
@@ -306,3 +421,9 @@ class rpcWorkerService(rpyc.Service):
logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}")
return data.item()
def record(self):
if self.profiling:
file = "/home/lurunyu/projects/ColossalAI/test_trace_rpc.json"
self.profiler.export_chrome_trace(file)
logger.info(f"trace has been saved into {file}")