Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-04 02:57:20 +00:00)
remove timer
This commit is contained in:
parent 6dcc127c42
commit 01ca9b8133
@@ -1,7 +1,5 @@
 import asyncio
-import concurrent
 import pickle
-from contextlib import nullcontext
 from itertools import count
 from time import sleep
 from typing import List, Tuple, Union
@@ -17,7 +15,7 @@ from transformers.configuration_utils import PretrainedConfig
 from colossalai.inference.batch_bucket import RPCBatchBucket
 from colossalai.inference.config import InferenceConfig, InputMetaData
 from colossalai.inference.executor.rpc_worker import rpcWorkerService
-from colossalai.inference.utils import Timer, find_available_ports
+from colossalai.inference.utils import find_available_ports
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer.policies.base_policy import Policy
 
@@ -126,18 +124,8 @@ class RPCInferenceEngine(InferenceEngine):
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.loop)
 
-        self.timer = False
-        self.t_prepare = Timer("[Timer] prepare the data 2") if self.timer else nullcontext()
-        self.t_exe = Timer("[Timer] execute rpc worker") if self.timer else nullcontext()
-        # self.t_sampler = Timer("[Timer] sampler time")
-
         self.logger.info("engine init over ")
 
-    def __del__(self):
-        if self.timer:
-            del self.t_prepare
-            del self.t_exe
-
     def _verify_args(self) -> None:
         """Verify the input args"""
         if not isinstance(self.inference_config, InferenceConfig):
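Note: the removed `Timer(...) if self.timer else nullcontext()` lines follow a common toggle pattern: when timing is disabled, `contextlib.nullcontext()` stands in for the timer so downstream `with` blocks keep working unchanged. A minimal sketch of that pattern (the `log_time` helper below is illustrative only, not ColossalAI code):

import time
from contextlib import contextmanager, nullcontext


@contextmanager
def log_time(name: str):
    # illustrative stand-in for a timer context manager
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{name} took {time.perf_counter() - start:.6f} seconds")


enable_timing = False  # mirrors the removed `self.timer = False`
t_prepare = log_time("prepare the data") if enable_timing else nullcontext()

with t_prepare:  # a no-op context when timing is disabled
    batch = list(range(8))  # stand-in for request_handler.schedule()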
@@ -313,34 +301,6 @@
 
         return ret[0]
 
-    def step_(self, input_token_ids, input_meta_data: InputMetaData):
-        assert len(self.workers) == self.tp_size, "init workers first"
-        init_tasks = []
-        with concurrent.futures.ThreadPoolExecutor(max_workers=len(self.workers)) as executor:
-            for rank, worker in enumerate(self.workers):
-                if rank == 0:
-                    init_tasks.append(
-                        executor.submit(
-                            worker.execute_model_forward,
-                            pickle.dumps(input_token_ids),
-                            pickle.dumps(input_meta_data.to_rpc_param()),
-                            pickle.dumps(self.generation_config_dict),
-                        )
-                    )
-                else:
-                    init_tasks.append(
-                        executor.submit(
-                            worker.execute_model_forward,
-                            None,
-                            None,
-                            None,
-                        )
-                    )
-
-        concurrent.futures.wait(init_tasks)
-        results = [future.result() for future in init_tasks]
-        return results[0]
-
     def step(self) -> List[str]:
         with self.t_prepare:
             batch = self.request_handler.schedule()
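For context on what is being deleted: `step_()` fanned the forward call out to every worker over a thread pool, passing pickled inputs to rank 0 and `None` to the other ranks. A rough, self-contained sketch of that fan-out pattern, with a dummy worker standing in for the real rpyc proxy:

import concurrent.futures
import pickle


class DummyWorker:
    def execute_model_forward(self, token_ids, meta, gen_cfg):
        # the real worker would run the model; here we just echo rank 0's payload
        return None if token_ids is None else pickle.loads(token_ids)


def fan_out(workers, input_token_ids, meta_param, gen_cfg_dict):
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(workers)) as executor:
        futures = []
        for rank, worker in enumerate(workers):
            if rank == 0:
                args = (pickle.dumps(input_token_ids), pickle.dumps(meta_param), pickle.dumps(gen_cfg_dict))
            else:
                args = (None, None, None)  # non-zero ranks receive the data via broadcast
            futures.append(executor.submit(worker.execute_model_forward, *args))
        concurrent.futures.wait(futures)
        return [f.result() for f in futures][0]  # rank 0 holds the result


print(fan_out([DummyWorker(), DummyWorker()], [1, 2, 3], {"batch_size": 1}, {"max_new_tokens": 16}))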
@@ -350,8 +310,6 @@
         with self.t_exe:
             # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
             next_tokens = self.loop.run_until_complete(self.step_async(input_token_ids, input_meta_data))
-        # with self.t_exe:
-        #     next_tokens = self.step_(input_token_ids, input_meta_data)
 
         # update the request_handler
         self.request_handler.append_next_tokens(next_tokens)
@@ -360,7 +318,7 @@
 
     def kill_workers(self):
         """
-        I don't find a good way to implicit invoke self.kill_workers
+        NOTE(@lry89757) Don't find a good way to implicit invoke self.kill_workers
         """
         assert len(self.workers) != 0
         for proc in self.worker_processes:
@@ -55,13 +55,6 @@ class rpcWorkerService(rpyc.Service):
         colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address)
         self.rank = rank
 
-        # profiling only, remove later
-        self.timing = False
-
-        self.t_prepare = Timer("[Timer] prepare the data 1") if self.timing else nullcontext()
-        self.t_exe = Timer("[Timer] execute the model forward") if self.timing else nullcontext()
-        self.t_sampler = Timer("[Timer] sampler time") if self.timing else nullcontext()
-
         self.profiling = False
         self.profiler = (
             torch.profiler.profile(
@@ -133,7 +126,6 @@ class rpcWorkerService(rpyc.Service):
     ):
         with self.profiler:
             # prepare the data for model forward
-            with self.t_prepare:
             input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers(
                 input_token_ids_param=input_token_ids_param,
                 input_meta_data_param=input_meta_data_param,
@@ -146,7 +138,6 @@ class rpcWorkerService(rpyc.Service):
             n_tokens = input_meta_data.batch_size
 
             # execute the model
-            with self.t_exe:
             logits = self.model(
                 input_token_ids,
                 self.output_tensor[:n_tokens],
@@ -161,7 +152,6 @@ class rpcWorkerService(rpyc.Service):
             self.record()
 
             if self.rank == 0:
-                with self.t_sampler:
                 # sampler
                 if self.inference_config.pad_input:
                     logits = logits[:, -1, :]
@@ -432,14 +422,6 @@ class rpcWorkerService(rpyc.Service):
 
         return data.item()
 
-    def __del__(self):
-        """
-        profiling only, remove later
-        """
-        del self.t_prepare
-        del self.t_exe
-        del self.t_sampler
-
     def record(self):
         if self.profiling:
             file = "/home/lurunyu/projects/ColossalAI/test_trace_rpc.json"
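The `torch.profiler` plumbing (`self.profiling`, `self.profiler`, and `record()` writing a Chrome trace to a JSON file) remains in the worker. A hedged sketch of that style of usage, with a placeholder workload and output path rather than the worker's real forward pass:

import torch
from torch.profiler import ProfilerActivity, profile

# profile a small region on CPU and dump a Chrome trace afterwards
with profile(activities=[ProfilerActivity.CPU]) as prof:
    x = torch.randn(64, 64)
    for _ in range(10):
        x = x @ x  # stand-in for the model forward

prof.export_chrome_trace("trace.json")  # open in chrome://tracing or Perfetto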
@@ -194,70 +194,3 @@ def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
             """
     else:
         return ModelType.UNKNOWN
-
-
-"""
-below just for profiling temporarily, will removed before merge
-"""
-import time
-from contextlib import asynccontextmanager, contextmanager
-
-
-@contextmanager
-def timer(name=""):
-    # (@lry89757) will remove later
-    start_time = time.time()
-    try:
-        yield
-    finally:
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        print(f"{name} took {elapsed_time:.6f} seconds")
-
-
-class Timer:
-    # (@lry89757) will remove later
-    def __init__(self, name=""):
-        print(f"init timer, {name}")
-        self.name = name
-        self.times = []
-
-    def __enter__(self):
-        self.start_time = time.time()
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        end_time = time.time()
-        elapsed_time = end_time - self.start_time
-        self.times.append(elapsed_time)
-        # print(f"{self.name} took {elapsed_time:.6f} seconds")
-        # self.print_info()
-
-    def print_info(self):
-        average_prefill_time = self.times[0]
-        print(f"{self.name} prefill average time: {average_prefill_time:.6f} seconds")
-        if len(self.times) > 1:
-            average_decoding_time = sum(self.times[1:]) / len(self.times[1:])
-            print(f"{self.name} decoding average time: {average_decoding_time:.6f} seconds")
-
-    def __del__(self):
-        if self.times:
-            average_prefill_time = self.times[0]
-            print(f"{self.name} prefill average time: {average_prefill_time:.6f} seconds")
-            if len(self.times) > 1:
-                average_decoding_time = sum(self.times[1:]) / len(self.times[1:])
-                print(f"{self.name} decoding average time: {average_decoding_time:.6f} seconds")
-        else:
-            print(f"{self.name} no timings recorded")
-
-
-@asynccontextmanager
-async def async_timer(name=""):
-    # (@lry89757) will remove later
-    start_time = time.time()
-    try:
-        yield
-    finally:
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        print(f"{name} took {elapsed_time:.6f} seconds")
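For reference, the `Timer` utility deleted above treated the first recorded interval as prefill and averaged the remaining intervals as decoding steps. A minimal re-creation of that behaviour, using `time.perf_counter()` in place of the original `time.time()` for monotonic measurements:

import time


class Timer:
    def __init__(self, name: str = ""):
        self.name = name
        self.times = []

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.times.append(time.perf_counter() - self.start_time)

    def print_info(self):
        if not self.times:
            print(f"{self.name} no timings recorded")
            return
        # first interval = prefill, remaining intervals = decoding steps
        print(f"{self.name} prefill time: {self.times[0]:.6f} seconds")
        if len(self.times) > 1:
            avg_decode = sum(self.times[1:]) / len(self.times[1:])
            print(f"{self.name} decoding average time: {avg_decode:.6f} seconds")


t = Timer("demo")
for _ in range(3):
    with t:
        time.sleep(0.01)
t.print_info()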