remove timer

Runyu Lu 2024-07-30 07:48:39 +00:00
parent 6dcc127c42
commit 01ca9b8133
3 changed files with 23 additions and 150 deletions

View File

@@ -1,7 +1,5 @@
 import asyncio
-import concurrent
 import pickle
-from contextlib import nullcontext
 from itertools import count
 from time import sleep
 from typing import List, Tuple, Union
@@ -17,7 +15,7 @@ from transformers.configuration_utils import PretrainedConfig
 from colossalai.inference.batch_bucket import RPCBatchBucket
 from colossalai.inference.config import InferenceConfig, InputMetaData
 from colossalai.inference.executor.rpc_worker import rpcWorkerService
-from colossalai.inference.utils import Timer, find_available_ports
+from colossalai.inference.utils import find_available_ports
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer.policies.base_policy import Policy
@@ -126,18 +124,8 @@ class RPCInferenceEngine(InferenceEngine):
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.loop)
 
-        self.timer = False
-        self.t_prepare = Timer("[Timer] prepare the data 2") if self.timer else nullcontext()
-        self.t_exe = Timer("[Timer] execute rpc worker") if self.timer else nullcontext()
-        # self.t_sampler = Timer("[Timer] sampler time")
-
         self.logger.info("engine init over ")
 
-    def __del__(self):
-        if self.timer:
-            del self.t_prepare
-            del self.t_exe
-
     def _verify_args(self) -> None:
         """Verify the input args"""
         if not isinstance(self.inference_config, InferenceConfig):
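The lines removed here follow a flag-gated context-manager pattern: when the timing flag is off, contextlib.nullcontext() stands in for the timer so the surrounding with-blocks stay valid but do nothing. A minimal, self-contained sketch of that pattern (SimpleTimer and enable_timing are illustrative names, not part of the codebase):

import time
from contextlib import nullcontext


class SimpleTimer:
    # toy stand-in for the removed Timer utility: prints how long a block took
    def __init__(self, name=""):
        self.name = name

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print(f"{self.name} took {time.perf_counter() - self.start:.6f} seconds")


enable_timing = False  # mirrors the removed self.timer = False flag
t_prepare = SimpleTimer("prepare") if enable_timing else nullcontext()

with t_prepare:  # a no-op context when timing is disabled
    sum(range(1_000_000))  # the section that would be timed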
@@ -313,34 +301,6 @@ class RPCInferenceEngine(InferenceEngine):
 
         return ret[0]
 
-    def step_(self, input_token_ids, input_meta_data: InputMetaData):
-        assert len(self.workers) == self.tp_size, "init workers first"
-
-        init_tasks = []
-        with concurrent.futures.ThreadPoolExecutor(max_workers=len(self.workers)) as executor:
-            for rank, worker in enumerate(self.workers):
-                if rank == 0:
-                    init_tasks.append(
-                        executor.submit(
-                            worker.execute_model_forward,
-                            pickle.dumps(input_token_ids),
-                            pickle.dumps(input_meta_data.to_rpc_param()),
-                            pickle.dumps(self.generation_config_dict),
-                        )
-                    )
-                else:
-                    init_tasks.append(
-                        executor.submit(
-                            worker.execute_model_forward,
-                            None,
-                            None,
-                            None,
-                        )
-                    )
-        concurrent.futures.wait(init_tasks)
-        results = [future.result() for future in init_tasks]
-        return results[0]
-
     def step(self) -> List[str]:
         with self.t_prepare:
             batch = self.request_handler.schedule()
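For reference, the deleted step_ fanned the forward call out to all RPC workers through a thread pool: one blocking submit per worker (only rank 0 receives the pickled inputs), wait on every future, then keep rank 0's result. A minimal sketch of that fan-out with a dummy in-process function standing in for the rpyc call:

import concurrent.futures


def execute_model_forward(rank, payload):
    # dummy stand-in for worker.execute_model_forward(...) over rpyc
    return f"rank {rank} processed {payload!r}"


workers = list(range(4))  # pretend tp_size == 4

with concurrent.futures.ThreadPoolExecutor(max_workers=len(workers)) as executor:
    tasks = [
        executor.submit(execute_model_forward, rank, "tokens" if rank == 0 else None)
        for rank in workers
    ]
    concurrent.futures.wait(tasks)
    results = [task.result() for task in tasks]

print(results[0])  # as in step_, only rank 0's return value is used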
@@ -350,8 +310,6 @@ class RPCInferenceEngine(InferenceEngine):
         with self.t_exe:
             # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
             next_tokens = self.loop.run_until_complete(self.step_async(input_token_ids, input_meta_data))
-        # with self.t_exe:
-        #     next_tokens = self.step_(input_token_ids, input_meta_data)
 
         # update the request_handler
         self.request_handler.append_next_tokens(next_tokens)
@@ -360,7 +318,7 @@ class RPCInferenceEngine(InferenceEngine):
 
     def kill_workers(self):
         """
-        I don't find a good way to implicit invoke self.kill_workers
+        NOTE(@lry89757) Don't find a good way to implicit invoke self.kill_workers
         """
         assert len(self.workers) != 0
         for proc in self.worker_processes:
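The path that survives in step() drives the async RPC calls from synchronous code: the engine builds its own event loop in __init__ and then blocks on loop.run_until_complete(...) for each step. A minimal sketch of that pattern; fake_step_async is a placeholder coroutine, not the engine's real step_async:

import asyncio


async def fake_step_async(input_token_ids):
    # placeholder for step_async, which awaits the RPC workers
    await asyncio.sleep(0.01)
    return [token + 1 for token in input_token_ids]


loop = asyncio.new_event_loop()  # mirrors the engine's __init__
asyncio.set_event_loop(loop)

# a synchronous step(): block until the coroutine finishes, reusing the same loop each call
next_tokens = loop.run_until_complete(fake_step_async([1, 2, 3]))
print(next_tokens)

loop.close()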

View File

@@ -55,13 +55,6 @@ class rpcWorkerService(rpyc.Service):
         colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address)
         self.rank = rank
 
-        # profiling only, remove later
-        self.timing = False
-        self.t_prepare = Timer("[Timer] prepare the data 1") if self.timing else nullcontext()
-        self.t_exe = Timer("[Timer] execute the model forward") if self.timing else nullcontext()
-        self.t_sampler = Timer("[Timer] sampler time") if self.timing else nullcontext()
-
         self.profiling = False
         self.profiler = (
             torch.profiler.profile(
@@ -133,12 +126,11 @@ class rpcWorkerService(rpyc.Service):
     ):
         with self.profiler:
             # prepare the data for model forward
-            with self.t_prepare:
-                input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers(
-                    input_token_ids_param=input_token_ids_param,
-                    input_meta_data_param=input_meta_data_param,
-                    generation_config_param=generation_config_param,
-                )
+            input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers(
+                input_token_ids_param=input_token_ids_param,
+                input_meta_data_param=input_meta_data_param,
+                generation_config_param=generation_config_param,
+            )
 
             if input_meta_data.is_prompts:
                 n_tokens = input_meta_data.sequence_lengths.sum().item()
@@ -146,14 +138,13 @@ class rpcWorkerService(rpyc.Service):
                 n_tokens = input_meta_data.batch_size
 
             # execute the model
-            with self.t_exe:
-                logits = self.model(
-                    input_token_ids,
-                    self.output_tensor[:n_tokens],
-                    input_meta_data,
-                    self.k_cache,
-                    self.v_cache,
-                )
+            logits = self.model(
+                input_token_ids,
+                self.output_tensor[:n_tokens],
+                input_meta_data,
+                self.k_cache,
+                self.v_cache,
+            )
 
             if self.profiling:
                 self.profiler.step()
@@ -161,16 +152,15 @@ class rpcWorkerService(rpyc.Service):
                 self.record()
 
             if self.rank == 0:
-                with self.t_sampler:
-                    # sampler
-                    if self.inference_config.pad_input:
-                        logits = logits[:, -1, :]
-                    next_tokens = search_tokens(
-                        generation_config,
-                        logits,
-                        input_meta_data.is_prompts,
-                        input_meta_data.batch_token_ids,
-                    )
+                # sampler
+                if self.inference_config.pad_input:
+                    logits = logits[:, -1, :]
+                next_tokens = search_tokens(
+                    generation_config,
+                    logits,
+                    input_meta_data.is_prompts,
+                    input_meta_data.batch_token_ids,
+                )
 
                 # return the tokens generated to scheduler
                 # only rank 0 need to pass the data back
@@ -432,14 +422,6 @@ class rpcWorkerService(rpyc.Service):
 
         return data.item()
 
-    def __del__(self):
-        """
-        profiling only, remove later
-        """
-        del self.t_prepare
-        del self.t_exe
-        del self.t_sampler
-
     def record(self):
         if self.profiling:
             file = "/home/lurunyu/projects/ColossalAI/test_trace_rpc.json"

View File

@@ -194,70 +194,3 @@ def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
             """
     else:
         return ModelType.UNKNOWN
-
-
-"""
-below just for profiling temporarily, will removed before merge
-"""
-import time
-from contextlib import asynccontextmanager, contextmanager
-
-
-@contextmanager
-def timer(name=""):
-    # (@lry89757) will remove later
-    start_time = time.time()
-    try:
-        yield
-    finally:
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        print(f"{name} took {elapsed_time:.6f} seconds")
-
-
-class Timer:
-    # (@lry89757) will remove later
-    def __init__(self, name=""):
-        print(f"init timer, {name}")
-        self.name = name
-        self.times = []
-
-    def __enter__(self):
-        self.start_time = time.time()
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        end_time = time.time()
-        elapsed_time = end_time - self.start_time
-        self.times.append(elapsed_time)
-        # print(f"{self.name} took {elapsed_time:.6f} seconds")
-        # self.print_info()
-
-    def print_info(self):
-        average_prefill_time = self.times[0]
-        print(f"{self.name} prefill average time: {average_prefill_time:.6f} seconds")
-        if len(self.times) > 1:
-            average_decoding_time = sum(self.times[1:]) / len(self.times[1:])
-            print(f"{self.name} decoding average time: {average_decoding_time:.6f} seconds")
-
-    def __del__(self):
-        if self.times:
-            average_prefill_time = self.times[0]
-            print(f"{self.name} prefill average time: {average_prefill_time:.6f} seconds")
-            if len(self.times) > 1:
-                average_decoding_time = sum(self.times[1:]) / len(self.times[1:])
-                print(f"{self.name} decoding average time: {average_decoding_time:.6f} seconds")
-        else:
-            print(f"{self.name} no timings recorded")
-
-
-@asynccontextmanager
-async def async_timer(name=""):
-    # (@lry89757) will remove later
-    start_time = time.time()
-    try:
-        yield
-    finally:
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        print(f"{name} took {elapsed_time:.6f} seconds")