From 01ca9b813308e22d27ba87323f744a80c5668575 Mon Sep 17 00:00:00 2001
From: Runyu Lu
Date: Tue, 30 Jul 2024 07:48:39 +0000
Subject: [PATCH] remove temporary profiling timers

---
 colossalai/inference/core/rpc_engine.py     | 56 ++++----------------
 colossalai/inference/executor/rpc_worker.py | 60 ++++++++-------------
 colossalai/inference/utils.py               | 67 ---------------------
 3 files changed, 27 insertions(+), 156 deletions(-)

diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py
index 1fb27e6c8..4677418a3 100644
--- a/colossalai/inference/core/rpc_engine.py
+++ b/colossalai/inference/core/rpc_engine.py
@@ -1,7 +1,5 @@
 import asyncio
-import concurrent
 import pickle
-from contextlib import nullcontext
 from itertools import count
 from time import sleep
 from typing import List, Tuple, Union
@@ -17,7 +15,7 @@ from transformers.configuration_utils import PretrainedConfig
 from colossalai.inference.batch_bucket import RPCBatchBucket
 from colossalai.inference.config import InferenceConfig, InputMetaData
 from colossalai.inference.executor.rpc_worker import rpcWorkerService
-from colossalai.inference.utils import Timer, find_available_ports
+from colossalai.inference.utils import find_available_ports
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer.policies.base_policy import Policy
@@ -126,18 +124,8 @@ class RPCInferenceEngine(InferenceEngine):
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.loop)
 
-        self.timer = False
-        self.t_prepare = Timer("[Timer] prepare the data 2") if self.timer else nullcontext()
-        self.t_exe = Timer("[Timer] execute rpc worker") if self.timer else nullcontext()
-        # self.t_sampler = Timer("[Timer] sampler time")
-
         self.logger.info("engine init over ")
 
-    def __del__(self):
-        if self.timer:
-            del self.t_prepare
-            del self.t_exe
-
     def _verify_args(self) -> None:
         """Verify the input args"""
         if not isinstance(self.inference_config, InferenceConfig):
@@ -313,50 +301,18 @@ class RPCInferenceEngine(InferenceEngine):
 
         return ret[0]
 
-    def step_(self, input_token_ids, input_meta_data: InputMetaData):
-        assert len(self.workers) == self.tp_size, "init workers first"
-        init_tasks = []
-        with concurrent.futures.ThreadPoolExecutor(max_workers=len(self.workers)) as executor:
-            for rank, worker in enumerate(self.workers):
-                if rank == 0:
-                    init_tasks.append(
-                        executor.submit(
-                            worker.execute_model_forward,
-                            pickle.dumps(input_token_ids),
-                            pickle.dumps(input_meta_data.to_rpc_param()),
-                            pickle.dumps(self.generation_config_dict),
-                        )
-                    )
-                else:
-                    init_tasks.append(
-                        executor.submit(
-                            worker.execute_model_forward,
-                            None,
-                            None,
-                            None,
-                        )
-                    )
-
-        concurrent.futures.wait(init_tasks)
-        results = [future.result() for future in init_tasks]
-        return results[0]
-
     def step(self) -> List[str]:
-        with self.t_prepare:
-            batch = self.request_handler.schedule()
-            input_token_ids, input_meta_data = self.prepare_input(batch)
+        batch = self.request_handler.schedule()
+        input_token_ids, input_meta_data = self.prepare_input(batch)
 
         if self.use_cuda_graph:
             input_meta_data.use_cuda_graph = True
 
-        with self.t_exe:
-            # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
-            next_tokens = self.loop.run_until_complete(self.step_async(input_token_ids, input_meta_data))
-        # with self.t_exe:
-        #     next_tokens = self.step_(input_token_ids, input_meta_data)
+        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
+        next_tokens = self.loop.run_until_complete(self.step_async(input_token_ids, input_meta_data))
 
         # update the request_handler
         self.request_handler.append_next_tokens(next_tokens)
         finished_sequences = self.request_handler.update()
 
         return finished_sequences
@@ -363,7 +319,7 @@ class RPCInferenceEngine(InferenceEngine):
 
     def kill_workers(self):
         """
-        I don't find a good way to implicit invoke self.kill_workers
+        NOTE(@lry89757) Haven't found a good way to invoke self.kill_workers implicitly.
         """
         assert len(self.workers) != 0
         for proc in self.worker_processes:
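The idiom being unwound above: every timed region was guarded with `Timer(...) if self.timer else nullcontext()`, so the `with` blocks stayed valid whether timing was on or off. A minimal sketch of that pattern, reduced from the `Timer` helper this patch deletes from `colossalai/inference/utils.py` (the `ENABLE_TIMING` flag is a stand-in for the removed `self.timer` attribute):

```python
import time
from contextlib import nullcontext


class Timer:
    """Cut-down version of the helper removed from utils.py."""

    def __init__(self, name: str = ""):
        self.name = name
        self.times = []

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.times.append(time.time() - self.start_time)


ENABLE_TIMING = False  # stand-in for the removed `self.timer` flag
t_exe = Timer("[Timer] execute rpc worker") if ENABLE_TIMING else nullcontext()

# Valid in both branches: nullcontext() is a no-op context manager. This is
# also why deleting the Timer attributes requires unwrapping the
# `with self.t_prepare:` / `with self.t_exe:` blocks in step(), as the hunk
# above does.
with t_exe:
    pass  # timed region, e.g. the RPC model forward
```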
diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py
index 4e84ec8f0..85f5758ae 100644
--- a/colossalai/inference/executor/rpc_worker.py
+++ b/colossalai/inference/executor/rpc_worker.py
@@ -55,13 +55,6 @@ class rpcWorkerService(rpyc.Service):
         colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address)
         self.rank = rank
 
-        # profiling only, remove later
-        self.timing = False
-
-        self.t_prepare = Timer("[Timer] prepare the data 1") if self.timing else nullcontext()
-        self.t_exe = Timer("[Timer] execute the model forward") if self.timing else nullcontext()
-        self.t_sampler = Timer("[Timer] sampler time") if self.timing else nullcontext()
-
         self.profiling = False
         self.profiler = (
             torch.profiler.profile(
@@ -133,12 +126,11 @@ class rpcWorkerService(rpyc.Service):
     ):
         with self.profiler:
             # prepare the data for model forward
-            with self.t_prepare:
-                input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers(
-                    input_token_ids_param=input_token_ids_param,
-                    input_meta_data_param=input_meta_data_param,
-                    generation_config_param=generation_config_param,
-                )
+            input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers(
+                input_token_ids_param=input_token_ids_param,
+                input_meta_data_param=input_meta_data_param,
+                generation_config_param=generation_config_param,
+            )
 
             if input_meta_data.is_prompts:
                 n_tokens = input_meta_data.sequence_lengths.sum().item()
@@ -146,14 +138,13 @@ class rpcWorkerService(rpyc.Service):
                 n_tokens = input_meta_data.batch_size
 
             # execute the model
-            with self.t_exe:
-                logits = self.model(
-                    input_token_ids,
-                    self.output_tensor[:n_tokens],
-                    input_meta_data,
-                    self.k_cache,
-                    self.v_cache,
-                )
+            logits = self.model(
+                input_token_ids,
+                self.output_tensor[:n_tokens],
+                input_meta_data,
+                self.k_cache,
+                self.v_cache,
+            )
 
             if self.profiling:
                 self.profiler.step()
@@ -161,16 +152,15 @@ class rpcWorkerService(rpyc.Service):
             self.record()
 
         if self.rank == 0:
-            with self.t_sampler:
-                # sampler
-                if self.inference_config.pad_input:
-                    logits = logits[:, -1, :]
-                next_tokens = search_tokens(
-                    generation_config,
-                    logits,
-                    input_meta_data.is_prompts,
-                    input_meta_data.batch_token_ids,
-                )
+            # sampler
+            if self.inference_config.pad_input:
+                logits = logits[:, -1, :]
+            next_tokens = search_tokens(
+                generation_config,
+                logits,
+                input_meta_data.is_prompts,
+                input_meta_data.batch_token_ids,
+            )
 
             # return the tokens generated to scheduler
             # only rank 0 need to pass the data back
@@ -432,14 +422,6 @@ class rpcWorkerService(rpyc.Service):
 
         return data.item()
 
-    def __del__(self):
-        """
-        profiling only, remove later
-        """
-        del self.t_prepare
-        del self.t_exe
-        del self.t_sampler
-
     def record(self):
         if self.profiling:
             file = "/home/lurunyu/projects/ColossalAI/test_trace_rpc.json"
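The dedented `_broadcast_param_to_all_workers(...)` call above reflects the worker's fan-out protocol: the engine sends the pickled inputs only to rank 0 (the other ranks' `execute_model_forward` is invoked with `None`, as the deleted `step_` fallback made explicit), so the payload has to be re-broadcast inside the tensor-parallel group. A hedged sketch of that pattern with stock `torch.distributed`; the helper's real implementation in `rpc_worker.py` may differ in detail:

```python
import pickle

import torch.distributed as dist


def broadcast_from_rank0(param_bytes, rank: int):
    """Rank 0 unpickles the RPC payload; every other rank contributes None.

    broadcast_object_list pickles the object internally and leaves all ranks
    holding the same value, which is what a rank-0-only RPC argument needs.
    """
    payload = [pickle.loads(param_bytes)] if rank == 0 else [None]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
```

On the engine side this corresponds to the `pickle.dumps(input_meta_data.to_rpc_param())` arguments visible in the deleted `step_` method.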
diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py
index c7ff9a6a0..d0851e362 100644
--- a/colossalai/inference/utils.py
+++ b/colossalai/inference/utils.py
@@ -194,70 +194,3 @@ def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
     """
     else:
         return ModelType.UNKNOWN
-
-
-"""
-below just for profiling temporarily, will removed before merge
-"""
-import time
-from contextlib import asynccontextmanager, contextmanager
-
-
-@contextmanager
-def timer(name=""):
-    # (@lry89757) will remove later
-    start_time = time.time()
-    try:
-        yield
-    finally:
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        print(f"{name} took {elapsed_time:.6f} seconds")
-
-
-class Timer:
-    # (@lry89757) will remove later
-    def __init__(self, name=""):
-        print(f"init timer, {name}")
-        self.name = name
-        self.times = []
-
-    def __enter__(self):
-        self.start_time = time.time()
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        end_time = time.time()
-        elapsed_time = end_time - self.start_time
-        self.times.append(elapsed_time)
-        # print(f"{self.name} took {elapsed_time:.6f} seconds")
-        # self.print_info()
-
-    def print_info(self):
-        average_prefill_time = self.times[0]
-        print(f"{self.name} prefill average time: {average_prefill_time:.6f} seconds")
-        if len(self.times) > 1:
-            average_decoding_time = sum(self.times[1:]) / len(self.times[1:])
-            print(f"{self.name} decoding average time: {average_decoding_time:.6f} seconds")
-
-    def __del__(self):
-        if self.times:
-            average_prefill_time = self.times[0]
-            print(f"{self.name} prefill average time: {average_prefill_time:.6f} seconds")
-            if len(self.times) > 1:
-                average_decoding_time = sum(self.times[1:]) / len(self.times[1:])
-                print(f"{self.name} decoding average time: {average_decoding_time:.6f} seconds")
-        else:
-            print(f"{self.name} no timings recorded")
-
-
-@asynccontextmanager
-async def async_timer(name=""):
-    # (@lry89757) will remove later
-    start_time = time.time()
-    try:
-        yield
-    finally:
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        print(f"{name} took {elapsed_time:.6f} seconds")
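The deleted helpers all measured intervals with `time.time()`, which tracks wall-clock time and can jump if the system clock is adjusted. Should lightweight timing ever be reintroduced, a sketch along these lines (standard library only, using the monotonic `time.perf_counter()`) would be a sounder starting point:

```python
import time
from contextlib import contextmanager


@contextmanager
def timer(name: str = ""):
    # perf_counter() is monotonic, so the measured interval cannot go
    # negative the way a time.time() delta can after a clock adjustment.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{name} took {time.perf_counter() - start:.6f} seconds")


# Example usage:
# with timer("rpc model forward"):
#     engine.step()
```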