tmp save for profiling

Runyu Lu 2024-06-05 04:02:22 +00:00
parent 55a5dd9dcd
commit 509f3a62ab
6 changed files with 205 additions and 70 deletions

View File

@@ -526,6 +526,7 @@ class BatchBucket:
 class RPCBatchBucket(BatchBucket):
     def __init__(self, *args, **argv):
         self.is_rpc = True
+        self.device = "cpu"
         super().__init__(*args, **argv)

     # For compatibility

View File

@@ -87,12 +87,14 @@ class InputMetaData(RPC_PARAM):
     def to_rpc_param(self) -> Dict[str, any]:
         return {
-            "block_tables": self.block_tables.tolist()
-            if isinstance(self.block_tables, torch.Tensor)
-            else self.block_tables,
-            "sequence_lengths": self.sequence_lengths.tolist()
-            if isinstance(self.block_tables, torch.Tensor)
-            else self.sequence_lengths,
+            "block_tables": self.block_tables,
+            # "block_tables": self.block_tables.tolist()
+            # if isinstance(self.block_tables, torch.Tensor)
+            # else self.block_tables,
+            "sequence_lengths": self.sequence_lengths,
+            # "sequence_lengths": self.sequence_lengths.tolist()
+            # if isinstance(self.block_tables, torch.Tensor)
+            # else self.sequence_lengths,
             "batch_size": self.batch_size,
             "is_prompts": self.is_prompts,
             "use_cuda_kernel": self.use_cuda_kernel,
@@ -114,17 +116,14 @@ class InputMetaData(RPC_PARAM):
         from colossalai.accelerator import get_accelerator

         dtype = getattr(torch, rpc_dict["dtype"])
+        device = get_accelerator().get_current_device()

         return InputMetaData(
-            block_tables=torch.tensor(
-                rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device()
-            )
+            block_tables=torch.tensor(rpc_dict["block_tables"], dtype=torch.int, device=device)
             if isinstance(rpc_dict["block_tables"], list)
-            else rpc_dict["block_tables"],
-            sequence_lengths=torch.tensor(
-                rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
-            )
+            else rpc_dict["block_tables"].to(device),
+            sequence_lengths=torch.tensor(rpc_dict["sequence_lengths"], dtype=torch.int, device=device)
             if isinstance(rpc_dict["sequence_lengths"], list)
-            else rpc_dict["sequence_lengths"],
+            else rpc_dict["sequence_lengths"].to(device),
             batch_size=rpc_dict["batch_size"],
             is_prompts=rpc_dict["is_prompts"],
             use_cuda_kernel=rpc_dict["use_cuda_kernel"],
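Note on the change above: to_rpc_param now keeps block_tables and sequence_lengths as tensors instead of converting them to Python lists, and from_rpc_param either builds a tensor from a list or moves an existing tensor to the accelerator. A minimal, runnable sketch of that branch follows; the CPU fallback and the helper name restore_block_tables are illustrative assumptions, not ColossalAI APIs.

from typing import List, Union

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def restore_block_tables(value: Union[List[List[int]], torch.Tensor]) -> torch.Tensor:
    # A list arriving over RPC is materialized directly on the target device;
    # a tensor is just moved there, mirroring the branch in from_rpc_param.
    if isinstance(value, list):
        return torch.tensor(value, dtype=torch.int, device=device)
    return value.to(device)


print(restore_block_tables([[0, 1], [2, 3]]).device)
print(restore_block_tables(torch.zeros(2, 2, dtype=torch.int)).device)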

View File

@@ -1,4 +1,5 @@
 import time
+from contextlib import nullcontext
 from itertools import count
 from typing import Dict, List, Optional, Tuple, Type, Union
@@ -24,7 +25,7 @@ from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size
+from colossalai.inference.utils import Timer, get_model_size
 from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -103,6 +104,30 @@ class InferenceEngine:
         self.use_glide = False
         self.n_spec_tokens = self.inference_config.max_n_spec_tokens

+        # profiling only, remove later
+        self.timing = False
+        self.t_prepare = Timer("[Timer] prepare the data 1") if self.timing else nullcontext()
+        self.t_exe = Timer("[Timer] execute the model forward") if self.timing else nullcontext()
+        self.t_sampler = Timer("[Timer] sampler time") if self.timing else nullcontext()
+
+        self.profiling = False
+        self.profiler = (
+            torch.profiler.profile(
+                record_shapes=True,
+                with_stack=True,
+                with_modules=True,
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                # schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
+                # on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./tb_log_{args.batch_size}_" + args.mode),
+            )
+            if self.profiling
+            else nullcontext()
+        )
+
         self._verify_args()

     def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
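The timers and the profiler above share one toggle pattern: when the flag is off, the `with` blocks stay in place but wrap a nullcontext(), so the hot path needs no separate profiled and unprofiled branches. Below is a minimal, runnable sketch of that pattern; StopwatchTimer and TIMING are illustrative stand-ins, not ColossalAI APIs.

import time
from contextlib import nullcontext


class StopwatchTimer:
    """Tiny stand-in for the Timer utility: records elapsed seconds per `with` block."""

    def __init__(self, name: str):
        self.name = name
        self.times = []

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.times.append(time.perf_counter() - self._start)


TIMING = True  # flip to False and the same `with` statements become no-ops
t_step = StopwatchTimer("step") if TIMING else nullcontext()

for _ in range(3):
    with t_step:
        sum(i * i for i in range(100_000))  # stand-in for a model forward

if TIMING:
    print(f"avg step: {sum(t_step.times) / len(t_step.times):.6f}s")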
@@ -517,6 +542,7 @@ class InferenceEngine:
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         return_token_ids: bool = False,
         generation_config: Optional[GenerationConfig] = None,
+        step_list: Optional[List[int]] = None,
     ) -> List[str]:
         """
         Executing the inference step.
@@ -559,7 +585,11 @@ class InferenceEngine:
             output_seqs_list += self.steps_spec_dec()
         else:
             while self.request_handler.check_unfinished_seqs():
+                a = time.perf_counter()
                 output_seqs_list += self.step()
+                b = time.perf_counter()
+                if isinstance(step_list, list):
+                    step_list.append(b - a)

         output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
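Each appended value is the wall-clock duration of one decoding step in seconds, i.e. a float, so a List[float] annotation would describe the new step_list parameter more precisely than List[int]. A hypothetical caller-side use, assuming an already-initialized InferenceEngine:

# `engine` is assumed to be an InferenceEngine built as in ColossalAI's inference examples.
step_times = []  # will receive one float (seconds) per decoding step
outputs = engine.generate(prompts=["An example prompt"], step_list=step_times)
print(f"{len(step_times)} steps, avg {sum(step_times) / len(step_times):.4f}s per step")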
@@ -574,6 +604,19 @@ class InferenceEngine:
         else:
             return output_str

+    def __del__(self):
+        if self.timing:
+            del self.t_prepare
+            del self.t_exe
+            del self.t_sampler
+        self.record()
+
+    def record(self):
+        if self.profiling:
+            file = "/home/lurunyu/projects/ColossalAI/test_trace_non_rpc.json"
+            self.profiler.export_chrome_trace(file)
+            self.logger.info(f"trace has been saved into {file}")
+
     @property
     def has_prompt_template(self) -> bool:
         """ """
@@ -741,6 +784,8 @@ class InferenceEngine:
             List[str]: Decoded finished sequences generated by one step.
         """

+        with self.profiler:
+            with self.t_prepare:
                 batch = self.request_handler.schedule()
                 input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
@@ -750,12 +795,18 @@ class InferenceEngine:
             else:
                 model_executable = self.model

+            with self.t_exe:
                 # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
                 logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

+            with self.t_sampler:
                 if self.inference_config.pad_input:
                     logits = logits[:, -1, :]
-        next_tokens = search_tokens(
-            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
-        )
+                next_tokens = search_tokens(
+                    self.generation_config,
+                    logits,
+                    input_meta_data.is_prompts,
+                    batch_token_ids=input_meta_data.batch_token_ids,
+                )
+
         self.request_handler.append_next_tokens(next_tokens)
         finished_sequences = self.request_handler.update()

View File

@@ -1,4 +1,7 @@
 import asyncio
+import concurrent
+import pickle
+from contextlib import nullcontext
 from itertools import count
 from time import sleep
 from typing import List, Tuple, Union
@@ -14,7 +17,7 @@ from transformers.configuration_utils import PretrainedConfig
 from colossalai.inference.batch_bucket import RPCBatchBucket
 from colossalai.inference.config import InferenceConfig, InputMetaData
 from colossalai.inference.executor.rpc_worker import rpcWorkerService
-from colossalai.inference.utils import find_available_ports
+from colossalai.inference.utils import Timer, find_available_ports
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer.policies.base_policy import Policy
@@ -119,8 +122,21 @@ class RPCInferenceEngine(InferenceEngine):
         self.counter = count()

         self._verify_args()

+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
+        self.timer = False
+        self.t_prepare = Timer("[Timer] prepare the data 2") if self.timer else nullcontext()
+        self.t_exe = Timer("[Timer] execute rpc worker") if self.timer else nullcontext()
+        # self.t_sampler = Timer("[Timer] sampler time")
+
         self.logger.info("engine init over ")

+    def __del__(self):
+        if self.timer:
+            del self.t_prepare
+            del self.t_exe
+
     def _verify_args(self) -> None:
         """Verify the input args"""
         if not isinstance(self.inference_config, InferenceConfig):
@@ -268,7 +284,7 @@ class RPCInferenceEngine(InferenceEngine):
         assert async_res.ready
         return async_res.value

-    async def step_(self, input_token_ids, input_meta_data: InputMetaData):
+    async def step_async(self, input_token_ids, input_meta_data: InputMetaData):
         assert len(self.workers) == self.tp_size, "init workers first"

         init_tasks = []
@@ -277,9 +293,9 @@ class RPCInferenceEngine(InferenceEngine):
                 init_tasks.append(
                     self.async_parallel_forward(
                         async_forward,
-                        input_token_ids,
-                        input_meta_data.to_rpc_param(),
-                        self.generation_config_dict,
+                        pickle.dumps(input_token_ids),
+                        pickle.dumps(input_meta_data.to_rpc_param()),
+                        pickle.dumps(self.generation_config_dict),
                     )
                 )
             else:
@@ -296,12 +312,45 @@ class RPCInferenceEngine(InferenceEngine):
         return ret[0]

+    def step_(self, input_token_ids, input_meta_data: InputMetaData):
+        assert len(self.workers) == self.tp_size, "init workers first"
+
+        init_tasks = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=len(self.workers)) as executor:
+            for rank, worker in enumerate(self.workers):
+                if rank == 0:
+                    init_tasks.append(
+                        executor.submit(
+                            worker.execute_model_forward,
+                            pickle.dumps(input_token_ids),
+                            pickle.dumps(input_meta_data.to_rpc_param()),
+                            pickle.dumps(self.generation_config_dict),
+                        )
+                    )
+                else:
+                    init_tasks.append(
+                        executor.submit(
+                            worker.execute_model_forward,
+                            None,
+                            None,
+                            None,
+                        )
+                    )
+            concurrent.futures.wait(init_tasks)
+
+        results = [future.result() for future in init_tasks]
+
+        return results[0]
+
     def step(self) -> List[str]:
+        with self.t_prepare:
             batch = self.request_handler.schedule()
             input_token_ids, input_meta_data = self.prepare_input(batch)

+        with self.t_exe:
             # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
-        next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data))
+            next_tokens = self.loop.run_until_complete(self.step_async(input_token_ids, input_meta_data))
+
+        # with self.t_exe:
+        #     next_tokens = self.step_(input_token_ids, input_meta_data)

         # update the request_handler
         self.request_handler.append_next_tokens(next_tokens)
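Two details worth noting in this file. First, a bare `import concurrent` is not guaranteed to expose `concurrent.futures`; the submodule should be imported explicitly, as in the sketch below. Second, replacing `asyncio.run(...)` with `self.loop.run_until_complete(...)` reuses one event loop for the engine's lifetime instead of creating and tearing down a loop on every step. The following is a hedged, runnable sketch of the thread-pool fan-out used by the alternative `step_()` path; `fake_worker` is an illustrative stand-in for the rpyc worker proxies.

import concurrent.futures
import time
from typing import Optional


def fake_worker(rank: int, payload: Optional[bytes]) -> str:
    # Stand-in for a blocking rpyc call such as worker.execute_model_forward.
    time.sleep(0.01)
    return f"rank {rank} received {len(payload) if payload else 0} bytes"


payload = b"x" * 1024  # e.g. pickle.dumps(input_meta_data.to_rpc_param())
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = [
        executor.submit(fake_worker, rank, payload if rank == 0 else None)
        for rank in range(4)
    ]
    concurrent.futures.wait(futures)

results = [f.result() for f in futures]
print(results[0])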

View File

@@ -1,3 +1,5 @@
+import pickle
+from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union

 import rpyc
@@ -54,9 +56,29 @@ class rpcWorkerService(rpyc.Service):
         self.rank = rank

         # profiling only, remove later
-        self.t_prepare = Timer("[Timer] prepare the data")
-        self.t_exe = Timer("[Timer] execute the model forward")
-        self.t_sampler = Timer("[Timer] sampler time")
+        self.timing = False
+        self.t_prepare = Timer("[Timer] prepare the data 1") if self.timing else nullcontext()
+        self.t_exe = Timer("[Timer] execute the model forward") if self.timing else nullcontext()
+        self.t_sampler = Timer("[Timer] sampler time") if self.timing else nullcontext()
+
+        self.profiling = False
+        self.profiler = (
+            torch.profiler.profile(
+                record_shapes=True,
+                with_stack=True,
+                with_modules=True,
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                # schedule=torch.profiler.schedule(wait=0, repeat=1, active=1),
+                # on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./tb_log_{args.batch_size}_" + args.mode),
+            )
+            if self.profiling
+            else nullcontext()
+        )

         logger.info(f"init process group done for rank {rank}")

     def exposed_init_model(
@@ -109,6 +131,7 @@ class rpcWorkerService(rpyc.Service):
         input_meta_data_param: Optional[dict] = None,
         generation_config_param: Optional[dict] = None,
     ):
+        with self.profiler:
             # prepare the data for model forward
             with self.t_prepare:
                 input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers(
@@ -132,6 +155,11 @@ class rpcWorkerService(rpyc.Service):
                     self.v_cache,
                 )

+            if self.profiling:
+                self.profiler.step()
+                self.record()
+
             if self.rank == 0:
                 with self.t_sampler:
                     # sampler
@@ -191,6 +219,10 @@ class rpcWorkerService(rpyc.Service):
         generation_config_param: Optional[dict] = None,
     ):
         if self.rank == 0:
+            input_token_ids_param = pickle.loads(input_token_ids_param)
+            input_meta_data_param = pickle.loads(input_meta_data_param)
+            generation_config_param = pickle.loads(generation_config_param)
+
             input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
             input_meta_data.fd_inter_tensor = self.fd_inter_tensor
             input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device)
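Pickling the arguments on the caller and unpickling them on the rank-0 worker replaces rpyc's default behavior of passing network references, where every attribute access on the remote side can become another round trip. A minimal sketch of that round trip; the payload contents are illustrative:

import pickle

# Caller side: one bytes blob instead of a live proxied object.
rpc_param = {"batch_size": 8, "is_prompts": True, "sequence_lengths": [5, 7, 3]}
wire_bytes = pickle.dumps(rpc_param)

# Rank-0 worker side: restore a plain local dict before building InputMetaData.
restored = pickle.loads(wire_bytes)
assert restored == rpc_param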
@@ -199,7 +231,7 @@ class rpcWorkerService(rpyc.Service):
             if dist.get_world_size() > 1:
                 broadcast_list = {}
                 for k, v in input_meta_data_param.items():
-                    if not isinstance(v, List):
+                    if not isinstance(v, torch.Tensor):
                         broadcast_list[k] = v

                 # Pass the tensor shape and type in advance for
@@ -248,7 +280,7 @@ class rpcWorkerService(rpyc.Service):
             async3 = torch.distributed.broadcast(input_token_ids, src=0, async_op=True)

             input_meta_data_param["sequence_lengths"] = sequence_lengths
-            input_meta_data_param["blocktables"] = blocktables
+            input_meta_data_param["block_tables"] = blocktables

             input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
             input_meta_data.fd_inter_tensor = self.fd_inter_tensor
@@ -257,9 +289,6 @@ class rpcWorkerService(rpyc.Service):
             async2.wait()
             async3.wait()

-            input_meta_data.block_tables = blocktables
-            input_meta_data.sequence_lengths = sequence_lengths
-
         return input_token_ids, input_meta_data, generation_config

     def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
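With the key corrected to "block_tables", from_rpc_param now receives both tensors under their proper names, which is why the later manual re-assignments can be dropped. The broadcast itself follows the usual pattern of issuing the tensor broadcasts asynchronously and calling wait() before use; a hedged sketch of that pattern, assuming an already-initialized process group and illustrative shapes:

import torch
import torch.distributed as dist


def broadcast_block_tables(rank: int, shape=(8, 64), device="cuda") -> torch.Tensor:
    # Rank 0 holds the real data; every other rank allocates a buffer of the
    # agreed shape, then all ranks join the same (asynchronous) broadcast.
    if rank == 0:
        block_tables = torch.randint(0, 512, shape, dtype=torch.int, device=device)
    else:
        block_tables = torch.empty(shape, dtype=torch.int, device=device)
    work = dist.broadcast(block_tables, src=0, async_op=True)
    # ...other broadcasts (e.g. sequence_lengths, input_token_ids) could be issued here...
    work.wait()  # make sure the tensor is valid before it is read
    return block_tables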
@@ -408,3 +437,9 @@ class rpcWorkerService(rpyc.Service):
         del self.t_prepare
         del self.t_exe
         del self.t_sampler
+
+    def record(self):
+        if self.profiling:
+            file = "/home/lurunyu/projects/ColossalAI/test_trace_rpc.json"
+            self.profiler.export_chrome_trace(file)
+            logger.info(f"trace has been saved into {file}")

View File

@@ -149,8 +149,8 @@ class Timer:
         end_time = time.time()
         elapsed_time = end_time - self.start_time
         self.times.append(elapsed_time)
-        print(f"{self.name} took {elapsed_time:.6f} seconds")
-        self.print_info()
+        # print(f"{self.name} took {elapsed_time:.6f} seconds")
+        # self.print_info()

     def print_info(self):
         average_prefill_time = self.times[0]
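Silencing the per-call print leaves the reporting to print_info, which treats times[0] as the prefill step and the rest as decode steps. A small illustrative sketch of that kind of report; the numbers are made up:

times = [0.3124, 0.0181, 0.0179, 0.0183, 0.0180]  # per-step seconds; times[0] is the prefill step
prefill_time = times[0]
decode_times = times[1:]
avg_decode = sum(decode_times) / len(decode_times) if decode_times else float("nan")
print(f"prefill: {prefill_time:.4f}s, avg decode: {avg_decode:.4f}s over {len(decode_times)} steps")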