mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-07 04:18:55 +00:00)

commit 55a5dd9dcd
parent bd38fe6b91

    dist runtime opt source

This commit is contained in:
@@ -521,3 +521,67 @@ class BatchBucket:
 
     def __repr__(self) -> str:
         return f"(sequences_dict={self._sequences_dict}, is_prompts={self.is_prompts})"
+
+
+class RPCBatchBucket(BatchBucket):
+    def __init__(self, *args, **argv):
+        self.is_rpc = True
+        super().__init__(*args, **argv)
+
+    # For compatibility
+    def get_1D_inputs(self) -> List[int]:
+        assert len(self._sequences_dict) > 0, "No sequence in the batch"
+        first_seq = next(iter(self._sequences_dict.values()))  # not exactly the first sequence
+        if first_seq.output_len == 0:
+            # Assume prefill stage
+            assert all(
+                seq.output_len == 0 for seq in self._sequences_dict.values()
+            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
+            out_li = []
+            seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
+            for seq_id in seq_ids:
+                seq: Sequence = self._sequences_dict[seq_id]
+                out_li.extend(seq.input_token_id)
+            return out_li
+        else:
+            # Assume decoding stage
+            if self.use_spec_dec:
+                # For speculative decoding:
+                # the number of tokens to be verified in parallel plus the correct token in the last step
+                return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
+            assert all(
+                seq.output_len > 0 for seq in self._sequences_dict.values()
+            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
+            assert self.is_compact, "BatchBucket is not compact"
+            out = [0] * self.current_batch_size
+            for seq_id, index_in_b in self._sequences_indexes.items():
+                seq: Sequence = self._sequences_dict[seq_id]
+                out[index_in_b] = seq.output_token_id[-1]
+            return out
+
+    # For compatibility
+    def get_sequence_lengths(self) -> List[int]:
+        assert self.is_compact  # Debug usage
+        sequence_lengths = self.seq_lengths[: self.current_batch_size]
+        return sequence_lengths
+
+    def get_1D_inputs_spec_dec(self, n: int) -> List[int]:
+        # Used for main model verification in the **decoding stage**:
+        # `n` is the number of tokens to be verified,
+        # so the last `n` tokens of each sequence are prepared as the inputs.
+        assert len(self._sequences_dict) > 0, "No sequence in the batch"
+        assert all(
+            seq.output_len >= n for seq in self._sequences_dict.values()
+        ), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
+        out_li = []
+        seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
+        for seq_id in seq_ids:
+            seq: Sequence = self._sequences_dict[seq_id]
+            out_li.extend(seq.output_token_id[-n:])
+        return out_li
+
+    # For compatibility
+    def get_block_table_tensor(self) -> torch.Tensor:
+        assert self.is_compact  # Debug usage
+        block_table = self.block_tables[: self.current_batch_size]
+        return block_table
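The overrides above deliberately return plain Python lists instead of tensors so the batch contents can travel over the rpyc connection to the workers (see `prepare_input` below, which now forwards `get_1D_inputs()` unchanged). A minimal standalone sketch of that flattening behaviour, using a hypothetical MockSeq stand-in rather than the real Sequence class:

from dataclasses import dataclass, field
from typing import List


@dataclass
class MockSeq:  # hypothetical stand-in for the Sequence objects used above
    input_token_id: List[int]
    output_token_id: List[int] = field(default_factory=list)

    @property
    def output_len(self) -> int:
        return len(self.output_token_id)


def flatten_like_get_1D_inputs(seqs: List[MockSeq]) -> List[int]:
    if all(s.output_len == 0 for s in seqs):
        # prefill: concatenate every prompt token of every sequence
        out: List[int] = []
        for s in seqs:
            out.extend(s.input_token_id)
        return out
    # decoding: one token per sequence, the most recently generated one
    return [s.output_token_id[-1] for s in seqs]


print(flatten_like_get_1D_inputs([MockSeq([1, 2, 3]), MockSeq([4, 5])]))         # [1, 2, 3, 4, 5]
print(flatten_like_get_1D_inputs([MockSeq([1, 2], [7, 8]), MockSeq([4], [9])]))  # [8, 9]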
@@ -87,8 +87,12 @@ class InputMetaData(RPC_PARAM):
 
     def to_rpc_param(self) -> Dict[str, any]:
         return {
-            "block_tables": self.block_tables.tolist(),
-            "sequence_lengths": self.sequence_lengths.tolist(),
+            "block_tables": self.block_tables.tolist()
+            if isinstance(self.block_tables, torch.Tensor)
+            else self.block_tables,
+            "sequence_lengths": self.sequence_lengths.tolist()
+            if isinstance(self.sequence_lengths, torch.Tensor)
+            else self.sequence_lengths,
             "batch_size": self.batch_size,
             "is_prompts": self.is_prompts,
             "use_cuda_kernel": self.use_cuda_kernel,
@@ -113,10 +117,14 @@ class InputMetaData(RPC_PARAM):
         return InputMetaData(
             block_tables=torch.tensor(
                 rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device()
-            ),
+            )
+            if isinstance(rpc_dict["block_tables"], list)
+            else rpc_dict["block_tables"],
             sequence_lengths=torch.tensor(
                 rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
-            ),
+            )
+            if isinstance(rpc_dict["sequence_lengths"], list)
+            else rpc_dict["sequence_lengths"],
             batch_size=rpc_dict["batch_size"],
             is_prompts=rpc_dict["is_prompts"],
             use_cuda_kernel=rpc_dict["use_cuda_kernel"],
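For reference, a minimal sketch of the guarded conversion introduced in the two hunks above. The names `to_wire`/`from_wire` are hypothetical; the point is only that values already in the target form pass through untouched, so the same payload can cross the serialization boundary more than once without error:

from typing import List, Union

import torch


def to_wire(value: Union[torch.Tensor, List[int]]) -> List[int]:
    # only call .tolist() when the field actually holds a tensor
    return value.tolist() if isinstance(value, torch.Tensor) else value


def from_wire(value: Union[torch.Tensor, List[int]], device: str = "cpu") -> torch.Tensor:
    # only rebuild a tensor when the field arrived as a plain list
    return torch.tensor(value, dtype=torch.int, device=device) if isinstance(value, list) else value


seq_lens = torch.tensor([5, 3], dtype=torch.int)
payload = {"sequence_lengths": to_wire(seq_lens)}   # pickle/RPC friendly
restored = from_wire(payload["sequence_lengths"])   # back to a tensor
assert torch.equal(restored, seq_lens)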
@@ -4,7 +4,7 @@ import torch
 from transformers.configuration_utils import PretrainedConfig
 from transformers.generation import GenerationConfig
 
-from colossalai.inference.batch_bucket import BatchBucket
+from colossalai.inference.batch_bucket import BatchBucket, RPCBatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
@@ -376,7 +376,7 @@ class RPCRequestHandler(RequestHandler):
 
         # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
         # which may cause bugs and this issue should be fixed later.
-        self.running_bb = BatchBucket(
+        self.running_bb = RPCBatchBucket(
             num_heads=model_config.num_attention_heads // inference_config.tp_size,
             head_dim=head_dim,
             max_batch_size=self.max_batch_size,
@@ -386,7 +386,7 @@ class RPCRequestHandler(RequestHandler):
             fd_interm_tensor=None,
             dtype=self.dtype,
         )
-        self.prefill_bb = BatchBucket(
+        self.prefill_bb = RPCBatchBucket(
             num_heads=model_config.num_attention_heads // inference_config.tp_size,
             head_dim=head_dim,
             max_batch_size=self.max_batch_size,
@@ -11,7 +11,7 @@ from torch import multiprocessing as mp
 from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
 from transformers.configuration_utils import PretrainedConfig
 
-from colossalai.inference.batch_bucket import BatchBucket
+from colossalai.inference.batch_bucket import RPCBatchBucket
 from colossalai.inference.config import InferenceConfig, InputMetaData
 from colossalai.inference.executor.rpc_worker import rpcWorkerService
 from colossalai.inference.utils import find_available_ports
@@ -161,8 +161,16 @@ class RPCInferenceEngine(InferenceEngine):
             raise Exception("conn error!")
         self.logger.info(f"Build RPC Connection Success! Begin to load model...")
         asyncio.run(self.init_worker_env())
+        self._init_worker_forward()
         self.logger.info(f"init dist env over")
 
+    def _init_worker_forward(self):
+        """
+        Async wrappers for forward, because it will be invoked many times.
+        """
+        assert len(self.workers) == self.tp_size, "init workers first"
+        self.worker_forwards = [rpyc.async_(worker.execute_model_forward) for worker in self.workers]
+
     async def async_parallel_wrapper(self, f, *args, **kwargs):
         async_res = rpyc.async_(f)(*args, **kwargs)
         await asyncio.to_thread(async_res.wait)
@@ -209,7 +217,8 @@ class RPCInferenceEngine(InferenceEngine):
     def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
         asyncio.run(self._init_device_cache(alloc_shape))
 
-    def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]:
+    def prepare_input(self, batch: RPCBatchBucket) -> Tuple[List[int], InputMetaData]:
+        assert batch.is_rpc, "the batch must be RPCBatchBucket"
         input_ids = batch.get_1D_inputs()
         sequence_lengths = batch.get_sequence_lengths()
 
@@ -219,7 +228,7 @@ class RPCInferenceEngine(InferenceEngine):
             n_tokens = batch.current_batch_size
             if batch.use_spec_dec:
                 n_tokens = batch.num_tokens_to_verify + 1
-                assert n_tokens == input_ids.size(0)
+                assert n_tokens == len(input_ids)
                 n_tokens = n_tokens * batch.current_batch_size
 
         batch_token_ids = None
@@ -251,20 +260,38 @@ class RPCInferenceEngine(InferenceEngine):
             batch_token_ids=batch_token_ids,
         )
 
-        return input_ids.tolist(), input_meta_data
+        return input_ids, input_meta_data
+
+    async def async_parallel_forward(self, async_f, *args, **kwargs):
+        async_res = async_f(*args, **kwargs)
+        await asyncio.to_thread(async_res.wait)
+        assert async_res.ready
+        return async_res.value
 
     async def step_(self, input_token_ids, input_meta_data: InputMetaData):
         assert len(self.workers) == self.tp_size, "init workers first"
 
-        init_tasks = [
-            self.async_parallel_wrapper(
-                worker.execute_model_forward,
-                input_token_ids,
-                input_meta_data.to_rpc_param(),
-                self.generation_config_dict,
-            )
-            for worker in self.workers
-        ]
+        init_tasks = []
+        for rank, async_forward in enumerate(self.worker_forwards):
+            if rank == 0:
+                init_tasks.append(
+                    self.async_parallel_forward(
+                        async_forward,
+                        input_token_ids,
+                        input_meta_data.to_rpc_param(),
+                        self.generation_config_dict,
+                    )
+                )
+            else:
+                init_tasks.append(
+                    self.async_parallel_forward(
+                        async_forward,
+                        None,
+                        None,
+                        None,
+                    )
+                )
 
         ret = await asyncio.gather(*init_tasks)
 
         return ret[0]
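A standalone sketch of the dispatch pattern `step_` now follows, with hypothetical coroutine stubs standing in for the rpyc async results: only rank 0 receives the real inputs over RPC, the remaining ranks are called with `None` and pick the data up from the collective broadcast inside the worker, and only rank 0's return value is kept.

import asyncio
from typing import List, Optional


async def fake_worker_forward(rank: int, token_ids: Optional[List[int]]) -> Optional[List[int]]:
    if rank == 0:
        # rank 0 got the real payload and is the only rank that returns tokens
        return [t + 1 for t in token_ids]
    # non-zero ranks would receive the payload via torch.distributed instead
    assert token_ids is None
    return None


async def step(token_ids: List[int], world_size: int) -> List[int]:
    tasks = [
        fake_worker_forward(rank, token_ids if rank == 0 else None)
        for rank in range(world_size)
    ]
    results = await asyncio.gather(*tasks)
    return results[0]  # only rank 0's result matters


print(asyncio.run(step([1, 2, 3], world_size=2)))  # [2, 3, 4]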
@@ -277,7 +304,6 @@ class RPCInferenceEngine(InferenceEngine):
         next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data))
 
         # update the request_handler
-        next_tokens = torch.tensor(next_tokens, dtype=torch.int)
         self.request_handler.append_next_tokens(next_tokens)
         finished_sequences = self.request_handler.update()
         return finished_sequences
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import rpyc
 import torch
@@ -18,7 +18,7 @@ from colossalai.inference.modeling.policy import (
     model_policy_map,
 )
 from colossalai.inference.sampler import search_tokens
-from colossalai.inference.utils import get_model_size
+from colossalai.inference.utils import Timer, get_model_size
 from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -51,6 +51,12 @@ class rpcWorkerService(rpyc.Service):
     def exposed_init_dist_env(self, rank, world_size, master_address, master_port):
         logger.info(f"init process group for rank {rank}")
         colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address)
+        self.rank = rank
+
+        # profiling only, remove later
+        self.t_prepare = Timer("[Timer] prepare the data")
+        self.t_exe = Timer("[Timer] execute the model forward")
+        self.t_sampler = Timer("[Timer] sampler time")
         logger.info(f"init process group done for rank {rank}")
 
     def exposed_init_model(
@@ -98,38 +104,50 @@ class rpcWorkerService(rpyc.Service):
         logger.info("physical cache init over")
 
     def exposed_execute_model_forward(
-        self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict
+        self,
+        input_token_ids_param: Optional[List[int]] = None,
+        input_meta_data_param: Optional[dict] = None,
+        generation_config_param: Optional[dict] = None,
     ):
         # prepare the data for model forward
-        input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
-        input_meta_data.fd_inter_tensor = self.fd_inter_tensor
+        with self.t_prepare:
+            input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers(
+                input_token_ids_param=input_token_ids_param,
+                input_meta_data_param=input_meta_data_param,
+                generation_config_param=generation_config_param,
+            )
+
         if input_meta_data.is_prompts:
             n_tokens = input_meta_data.sequence_lengths.sum().item()
         else:
             n_tokens = input_meta_data.batch_size
-        input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device)
 
         # execute the model
-        logits = self.model(
-            input_token_ids,
-            self.output_tensor[:n_tokens],
-            input_meta_data,
-            self.k_cache,
-            self.v_cache,
-        )
+        with self.t_exe:
+            logits = self.model(
+                input_token_ids,
+                self.output_tensor[:n_tokens],
+                input_meta_data,
+                self.k_cache,
+                self.v_cache,
+            )
 
-        # sampler
-        if self.inference_config.pad_input:
-            logits = logits[:, -1, :]
-        next_tokens = search_tokens(
-            generation_config_param,
-            logits,
-            input_meta_data.is_prompts,
-            input_meta_data.batch_token_ids,
-        )
+        if self.rank == 0:
+            with self.t_sampler:
+                # sampler
+                if self.inference_config.pad_input:
+                    logits = logits[:, -1, :]
+                next_tokens = search_tokens(
+                    generation_config,
+                    logits,
+                    input_meta_data.is_prompts,
+                    input_meta_data.batch_token_ids,
+                )
 
-        # return the tokens generated to scheduler
-        return next_tokens.tolist()
+            # return the tokens generated to the scheduler;
+            # only rank 0 needs to pass the data back,
+            # to reduce the overhead of rpc param passing
+            return next_tokens.cpu()
 
     def _init_output_tensor(self):
         alloc_shape = (
@@ -166,6 +184,84 @@ class rpcWorkerService(rpyc.Service):
 
         self.fd_inter_tensor = fd_inter_tensor
 
+    def _broadcast_param_to_all_workers(
+        self,
+        input_token_ids_param: Optional[List[int]] = None,
+        input_meta_data_param: Optional[dict] = None,
+        generation_config_param: Optional[dict] = None,
+    ):
+        if self.rank == 0:
+            input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
+            input_meta_data.fd_inter_tensor = self.fd_inter_tensor
+            input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device)
+            generation_config = generation_config_param
+
+            if dist.get_world_size() > 1:
+                broadcast_list = {}
+                for k, v in input_meta_data_param.items():
+                    if not isinstance(v, List):
+                        broadcast_list[k] = v
+
+                # Pass the tensor shape and type in advance for
+                # other workers to prepare the empty tensor and async transport tensors
+                broadcast_list["block_tables"] = (
+                    input_meta_data.block_tables.size(),
+                    input_meta_data.block_tables.dtype,
+                )
+                broadcast_list["sequence_lengths"] = (
+                    input_meta_data.sequence_lengths.size(),
+                    input_meta_data.sequence_lengths.dtype,
+                )
+                broadcast_list["input_token_ids"] = (input_token_ids.size(), input_token_ids.dtype)
+
+                # Generation Config Param
+                broadcast_list["generation_config"] = generation_config_param
+
+                # send some meta data and some tensor shapes
+                torch.distributed.broadcast_object_list([broadcast_list], src=self.rank)
+
+                # send the real tensors
+                torch.distributed.broadcast(input_meta_data.block_tables, src=self.rank)
+                torch.distributed.broadcast(input_meta_data.sequence_lengths, src=self.rank)
+                torch.distributed.broadcast(input_token_ids, src=self.rank)
+
+        else:
+            assert input_meta_data_param is None, "Input Must Be None"
+
+            # recv the meta data
+            recv_list = [None]
+            torch.distributed.broadcast_object_list(recv_list, src=0)
+            input_meta_data_param = recv_list[0]
+
+            generation_config = input_meta_data_param["generation_config"]
+
+            blocktable_shape, blocktable_type = input_meta_data_param["block_tables"]
+            blocktables = torch.empty(blocktable_shape, dtype=blocktable_type, device=self.device)
+            sequence_lengths_shape, sequence_lengths_type = input_meta_data_param["sequence_lengths"]
+            sequence_lengths = torch.empty(sequence_lengths_shape, dtype=sequence_lengths_type, device=self.device)
+            input_token_ids_shape, input_token_ids_type = input_meta_data_param["input_token_ids"]
+            input_token_ids = torch.empty(input_token_ids_shape, dtype=input_token_ids_type, device=self.device)
+
+            # recv the real tensors
+            async1 = torch.distributed.broadcast(blocktables, src=0, async_op=True)
+            async2 = torch.distributed.broadcast(sequence_lengths, src=0, async_op=True)
+            async3 = torch.distributed.broadcast(input_token_ids, src=0, async_op=True)
+
+            input_meta_data_param["sequence_lengths"] = sequence_lengths
+            input_meta_data_param["blocktables"] = blocktables
+
+            input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
+            input_meta_data.fd_inter_tensor = self.fd_inter_tensor
+
+            async1.wait()
+            async2.wait()
+            async3.wait()
+
+            input_meta_data.block_tables = blocktables
+            input_meta_data.sequence_lengths = sequence_lengths
+
+        return input_token_ids, input_meta_data, generation_config
+
     def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
         """
         Shard model or/and Load weight
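A self-contained, CPU-only sketch of the two-phase handoff implemented by `_broadcast_param_to_all_workers`, using the gloo backend, a fixed localhost port, and a made-up tensor so it can run outside the engine: rank 0 first broadcasts a small Python object carrying shapes and dtypes, then every rank allocates matching empty tensors and receives the actual payload via `broadcast`.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    # assumed local rendezvous settings for the sketch only
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        block_tables = torch.arange(6, dtype=torch.int).reshape(2, 3)
        meta = [{"block_tables": (block_tables.size(), block_tables.dtype)}]
        dist.broadcast_object_list(meta, src=0)   # phase 1: shapes and dtypes
        dist.broadcast(block_tables, src=0)       # phase 2: the tensor itself
    else:
        meta = [None]
        dist.broadcast_object_list(meta, src=0)
        shape, dtype = meta[0]["block_tables"]
        block_tables = torch.empty(shape, dtype=dtype)
        dist.broadcast(block_tables, src=0)

    print(f"rank {rank} got {block_tables.tolist()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)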
@@ -304,3 +400,11 @@
         logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}")
 
         return data.item()
+
+    def __del__(self):
+        """
+        profiling only, remove later
+        """
+        del self.t_prepare
+        del self.t_exe
+        del self.t_sampler
|
@ -113,3 +113,70 @@ def find_available_ports(num: int):
|
|||||||
print(f"An OS error occurred: {e}")
|
print(f"An OS error occurred: {e}")
|
||||||
raise RuntimeError("Error finding available ports")
|
raise RuntimeError("Error finding available ports")
|
||||||
return free_ports
|
return free_ports
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
below just for profiling temporarily, will removed before merge
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def timer(name=""):
|
||||||
|
# (@lry89757) will remove later
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
end_time = time.time()
|
||||||
|
elapsed_time = end_time - start_time
|
||||||
|
print(f"{name} took {elapsed_time:.6f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
class Timer:
|
||||||
|
# (@lry89757) will remove later
|
||||||
|
def __init__(self, name=""):
|
||||||
|
print(f"init timer, {name}")
|
||||||
|
self.name = name
|
||||||
|
self.times = []
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start_time = time.time()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
end_time = time.time()
|
||||||
|
elapsed_time = end_time - self.start_time
|
||||||
|
self.times.append(elapsed_time)
|
||||||
|
print(f"{self.name} took {elapsed_time:.6f} seconds")
|
||||||
|
self.print_info()
|
||||||
|
|
||||||
|
def print_info(self):
|
||||||
|
average_prefill_time = self.times[0]
|
||||||
|
print(f"{self.name} prefill average time: {average_prefill_time:.6f} seconds")
|
||||||
|
if len(self.times) > 1:
|
||||||
|
average_decoding_time = sum(self.times[1:]) / len(self.times[1:])
|
||||||
|
print(f"{self.name} decoding average time: {average_decoding_time:.6f} seconds")
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if self.times:
|
||||||
|
average_prefill_time = self.times[0]
|
||||||
|
print(f"{self.name} prefill average time: {average_prefill_time:.6f} seconds")
|
||||||
|
if len(self.times) > 1:
|
||||||
|
average_decoding_time = sum(self.times[1:]) / len(self.times[1:])
|
||||||
|
print(f"{self.name} decoding average time: {average_decoding_time:.6f} seconds")
|
||||||
|
else:
|
||||||
|
print(f"{self.name} no timings recorded")
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def async_timer(name=""):
|
||||||
|
# (@lry89757) will remove later
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
end_time = time.time()
|
||||||
|
elapsed_time = end_time - start_time
|
||||||
|
print(f"{name} took {elapsed_time:.6f} seconds")
|
||||||
|
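Hypothetical usage of the `Timer` helper added above (importable from `colossalai.inference.utils` once this change lands): each `with` block records one measurement, the first being reported as the prefill time and the rest averaged as the decoding time.

import time

from colossalai.inference.utils import Timer

t_exe = Timer("[Timer] execute the model forward")

for _ in range(3):
    with t_exe:
        time.sleep(0.01)  # stand-in for a model forward pass
# each __exit__ prints the step time plus the running prefill/decoding averages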