Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-07 04:18:55 +00:00)

Commit 509f3a62ab (parent 55a5dd9dcd): tmp save for profiling
@@ -526,6 +526,7 @@ class BatchBucket:
 class RPCBatchBucket(BatchBucket):
     def __init__(self, *args, **argv):
         self.is_rpc = True
+        self.device = "cpu"
         super().__init__(*args, **argv)

     # For compatibility
@@ -87,12 +87,14 @@ class InputMetaData(RPC_PARAM):

     def to_rpc_param(self) -> Dict[str, any]:
         return {
-            "block_tables": self.block_tables.tolist()
-            if isinstance(self.block_tables, torch.Tensor)
-            else self.block_tables,
-            "sequence_lengths": self.sequence_lengths.tolist()
-            if isinstance(self.block_tables, torch.Tensor)
-            else self.sequence_lengths,
+            "block_tables": self.block_tables,
+            # "block_tables": self.block_tables.tolist()
+            # if isinstance(self.block_tables, torch.Tensor)
+            # else self.block_tables,
+            "sequence_lengths": self.sequence_lengths,
+            # "sequence_lengths": self.sequence_lengths.tolist()
+            # if isinstance(self.block_tables, torch.Tensor)
+            # else self.sequence_lengths,
             "batch_size": self.batch_size,
             "is_prompts": self.is_prompts,
             "use_cuda_kernel": self.use_cuda_kernel,
@@ -114,17 +116,14 @@ class InputMetaData(RPC_PARAM):
         from colossalai.accelerator import get_accelerator

         dtype = getattr(torch, rpc_dict["dtype"])
+        device = get_accelerator().get_current_device()
         return InputMetaData(
-            block_tables=torch.tensor(
-                rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device()
-            )
+            block_tables=torch.tensor(rpc_dict["block_tables"], dtype=torch.int, device=device)
             if isinstance(rpc_dict["block_tables"], list)
-            else rpc_dict["block_tables"],
-            sequence_lengths=torch.tensor(
-                rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
-            )
+            else rpc_dict["block_tables"].to(device),
+            sequence_lengths=torch.tensor(rpc_dict["sequence_lengths"], dtype=torch.int, device=device)
             if isinstance(rpc_dict["sequence_lengths"], list)
-            else rpc_dict["sequence_lengths"],
+            else rpc_dict["sequence_lengths"].to(device),
             batch_size=rpc_dict["batch_size"],
             is_prompts=rpc_dict["is_prompts"],
             use_cuda_kernel=rpc_dict["use_cuda_kernel"],
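Taken together, the two hunks above change how InputMetaData crosses the RPC boundary: to_rpc_param now ships block_tables and sequence_lengths as-is (the old .tolist() flattening is left commented out), and from_rpc_param either builds a tensor from a list or simply moves an already-materialized tensor to the current accelerator device. A minimal sketch of that list-or-tensor round trip, written with free functions and an explicit device argument instead of the real InputMetaData / get_accelerator() machinery:

from typing import Union

import torch


def to_rpc_param(block_tables: Union[list, torch.Tensor]) -> dict:
    # new behaviour: tensors are passed through untouched instead of .tolist()
    return {"block_tables": block_tables}


def from_rpc_param(rpc_dict: dict, device: str = "cpu") -> torch.Tensor:
    bt = rpc_dict["block_tables"]
    # lists still become tensors; tensors are only moved to the target device
    return torch.tensor(bt, dtype=torch.int, device=device) if isinstance(bt, list) else bt.to(device)


print(from_rpc_param(to_rpc_param([[0, 1], [2, 3]])))                         # built from a list
print(from_rpc_param(to_rpc_param(torch.tensor([[0, 1]], dtype=torch.int))))  # moved, not re-listed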
@@ -1,4 +1,5 @@
 import time
+from contextlib import nullcontext
 from itertools import count
 from typing import Dict, List, Optional, Tuple, Type, Union

@@ -24,7 +25,7 @@ from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size
+from colossalai.inference.utils import Timer, get_model_size
 from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -103,6 +104,30 @@ class InferenceEngine:
         self.use_glide = False
         self.n_spec_tokens = self.inference_config.max_n_spec_tokens

+        # profiling only, remove later
+        self.timing = False
+
+        self.t_prepare = Timer("[Timer] prepare the data 1") if self.timing else nullcontext()
+        self.t_exe = Timer("[Timer] execute the model forward") if self.timing else nullcontext()
+        self.t_sampler = Timer("[Timer] sampler time") if self.timing else nullcontext()
+
+        self.profiling = False
+        self.profiler = (
+            torch.profiler.profile(
+                record_shapes=True,
+                with_stack=True,
+                with_modules=True,
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                # schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
+                # on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./tb_log_{args.batch_size}_" + args.mode),
+            )
+            if self.profiling
+            else nullcontext()
+        )
+
         self._verify_args()

     def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
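The block above gates both the Timer instances and torch.profiler behind plain booleans, substituting contextlib.nullcontext() when they are off so that the `with` statements added to step() stay valid no-ops. A self-contained sketch of the same pattern (the engine and the Timer class are omitted; names here are illustrative):

from contextlib import nullcontext

import torch

PROFILING = False  # flip to True to capture a trace

profiler = (
    torch.profiler.profile(
        record_shapes=True,
        with_stack=True,
        with_modules=True,
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
    )
    if PROFILING
    else nullcontext()
)

with profiler:  # a no-op context when PROFILING is False
    pass  # one engine step would run here

if PROFILING:
    # same call the record() helpers below use; the file opens in chrome://tracing or Perfetto
    profiler.export_chrome_trace("trace.json")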
@@ -517,6 +542,7 @@ class InferenceEngine:
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         return_token_ids: bool = False,
         generation_config: Optional[GenerationConfig] = None,
+        step_list: Optional[List[int]] = None,
     ) -> List[str]:
         """
         Executing the inference step.
@@ -559,7 +585,11 @@ class InferenceEngine:
                 output_seqs_list += self.steps_spec_dec()
         else:
             while self.request_handler.check_unfinished_seqs():
+                a = time.perf_counter()
                 output_seqs_list += self.step()
+                b = time.perf_counter()
+                if isinstance(step_list, list):
+                    step_list.append(b - a)

         output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))

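generate() now accepts an optional step_list: when a list is passed in, the wall-clock duration of each step() call, measured with time.perf_counter(), is appended to it so a benchmark script can separate the prefill pass from the decode steps afterwards (the hint says List[int], though the appended values are floats). A self-contained sketch of the same measure-and-summarize pattern, with a sleep standing in for one engine step:

import time


def run_with_step_times(step_fn, n_steps: int = 4):
    # mirrors the new generate() behaviour: time each step and collect the durations
    step_times = []
    for _ in range(n_steps):
        a = time.perf_counter()
        step_fn()  # stands in for one engine.step()
        b = time.perf_counter()
        step_times.append(b - a)
    return step_times


times = run_with_step_times(lambda: time.sleep(0.005))
prefill, decode = times[0], times[1:]
print(f"prefill: {prefill * 1e3:.2f} ms")
print(f"decode:  {sum(decode) / len(decode) * 1e3:.2f} ms/step over {len(decode)} steps")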
@@ -574,6 +604,19 @@ class InferenceEngine:
         else:
             return output_str

+    def __del__(self):
+        if self.timing:
+            del self.t_prepare
+            del self.t_exe
+            del self.t_sampler
+        self.record()
+
+    def record(self):
+        if self.profiling:
+            file = "/home/lurunyu/projects/ColossalAI/test_trace_non_rpc.json"
+            self.profiler.export_chrome_trace(file)
+            self.logger.info(f"trace has been saved into {file}")
+
     @property
     def has_prompt_template(self) -> bool:
         """ """
@@ -741,6 +784,8 @@ class InferenceEngine:
             List[str]: Decoded finished sequences generated by one step.
         """

+        with self.profiler:
+            with self.t_prepare:
                 batch = self.request_handler.schedule()

                 input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
@@ -750,12 +795,18 @@ class InferenceEngine:
             else:
                 model_executable = self.model

+            with self.t_exe:
                 # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
                 logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

+            with self.t_sampler:
                 if self.inference_config.pad_input:
                     logits = logits[:, -1, :]
                 next_tokens = search_tokens(
-                    self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
+                    self.generation_config,
+                    logits,
+                    input_meta_data.is_prompts,
+                    batch_token_ids=input_meta_data.batch_token_ids,
                 )
                 self.request_handler.append_next_tokens(next_tokens)
                 finished_sequences = self.request_handler.update()
@@ -1,4 +1,7 @@
 import asyncio
+import concurrent
+import pickle
+from contextlib import nullcontext
 from itertools import count
 from time import sleep
 from typing import List, Tuple, Union
@@ -14,7 +17,7 @@ from transformers.configuration_utils import PretrainedConfig
 from colossalai.inference.batch_bucket import RPCBatchBucket
 from colossalai.inference.config import InferenceConfig, InputMetaData
 from colossalai.inference.executor.rpc_worker import rpcWorkerService
-from colossalai.inference.utils import find_available_ports
+from colossalai.inference.utils import Timer, find_available_ports
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer.policies.base_policy import Policy
@@ -119,8 +122,21 @@ class RPCInferenceEngine(InferenceEngine):
         self.counter = count()
         self._verify_args()

+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
+        self.timer = False
+        self.t_prepare = Timer("[Timer] prepare the data 2") if self.timer else nullcontext()
+        self.t_exe = Timer("[Timer] execute rpc worker") if self.timer else nullcontext()
+        # self.t_sampler = Timer("[Timer] sampler time")
+
         self.logger.info("engine init over ")

+    def __del__(self):
+        if self.timer:
+            del self.t_prepare
+            del self.t_exe
+
     def _verify_args(self) -> None:
         """Verify the input args"""
         if not isinstance(self.inference_config, InferenceConfig):
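RPCInferenceEngine now creates a single event loop in __init__ and, in the step() hunk further down, drives it with loop.run_until_complete() instead of calling asyncio.run() on every step, which would build and tear down a fresh loop each time. A standalone sketch of that reuse pattern; the class and method names are illustrative, not the ColossalAI API:

import asyncio


class StepDriver:
    def __init__(self):
        # one loop, created once and reused for every synchronous step() call
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    async def _step_async(self, x: int) -> int:
        await asyncio.sleep(0)  # stands in for awaiting the RPC workers
        return x + 1

    def step(self, x: int) -> int:
        # synchronous entry point, same shape as RPCInferenceEngine.step()
        return self.loop.run_until_complete(self._step_async(x))


driver = StepDriver()
print([driver.step(i) for i in range(3)])  # [1, 2, 3]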
@@ -268,7 +284,7 @@ class RPCInferenceEngine(InferenceEngine):
         assert async_res.ready
         return async_res.value

-    async def step_(self, input_token_ids, input_meta_data: InputMetaData):
+    async def step_async(self, input_token_ids, input_meta_data: InputMetaData):
         assert len(self.workers) == self.tp_size, "init workers first"

         init_tasks = []
@@ -277,9 +293,9 @@ class RPCInferenceEngine(InferenceEngine):
                 init_tasks.append(
                     self.async_parallel_forward(
                         async_forward,
-                        input_token_ids,
-                        input_meta_data.to_rpc_param(),
-                        self.generation_config_dict,
+                        pickle.dumps(input_token_ids),
+                        pickle.dumps(input_meta_data.to_rpc_param()),
+                        pickle.dumps(self.generation_config_dict),
                     )
                 )
             else:
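Rank 0's arguments are now pickled into plain bytes before the rpyc call, and unpickled again in the worker-side hunk further down, so each worker receives one opaque payload instead of rpyc proxying every nested list and dict across the connection. A minimal sketch of that serialize-once round trip; send_to_worker is a stand-in for the rpyc call, not a real API:

import pickle


def send_to_worker(blob: bytes) -> bytes:
    # stands in for the rpyc call: the payload crosses as a single bytes object
    param = pickle.loads(blob)
    fake_next_tokens = [0] * param["batch_size"]  # dummy "model forward"
    return pickle.dumps(fake_next_tokens)


rpc_param = {"batch_size": 4, "is_prompts": True, "block_tables": [[0, 1], [2, 3]]}
next_tokens = pickle.loads(send_to_worker(pickle.dumps(rpc_param)))
print(next_tokens)  # [0, 0, 0, 0]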
@@ -296,12 +312,45 @@ class RPCInferenceEngine(InferenceEngine):

         return ret[0]

+    def step_(self, input_token_ids, input_meta_data: InputMetaData):
+        assert len(self.workers) == self.tp_size, "init workers first"
+        init_tasks = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=len(self.workers)) as executor:
+            for rank, worker in enumerate(self.workers):
+                if rank == 0:
+                    init_tasks.append(
+                        executor.submit(
+                            worker.execute_model_forward,
+                            pickle.dumps(input_token_ids),
+                            pickle.dumps(input_meta_data.to_rpc_param()),
+                            pickle.dumps(self.generation_config_dict),
+                        )
+                    )
+                else:
+                    init_tasks.append(
+                        executor.submit(
+                            worker.execute_model_forward,
+                            None,
+                            None,
+                            None,
+                        )
+                    )
+
+            concurrent.futures.wait(init_tasks)
+        results = [future.result() for future in init_tasks]
+        return results[0]
+
     def step(self) -> List[str]:
+        with self.t_prepare:
             batch = self.request_handler.schedule()

             input_token_ids, input_meta_data = self.prepare_input(batch)

+        with self.t_exe:
             # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
-            next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data))
+            next_tokens = self.loop.run_until_complete(self.step_async(input_token_ids, input_meta_data))
+            # with self.t_exe:
+            #     next_tokens = self.step_(input_token_ids, input_meta_data)

         # update the request_handler
         self.request_handler.append_next_tokens(next_tokens)
@@ -1,3 +1,5 @@
+import pickle
+from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union

 import rpyc
@@ -54,9 +56,29 @@ class rpcWorkerService(rpyc.Service):
         self.rank = rank

         # profiling only, remove later
-        self.t_prepare = Timer("[Timer] prepare the data")
-        self.t_exe = Timer("[Timer] execute the model forward")
-        self.t_sampler = Timer("[Timer] sampler time")
+        self.timing = False
+
+        self.t_prepare = Timer("[Timer] prepare the data 1") if self.timing else nullcontext()
+        self.t_exe = Timer("[Timer] execute the model forward") if self.timing else nullcontext()
+        self.t_sampler = Timer("[Timer] sampler time") if self.timing else nullcontext()
+
+        self.profiling = False
+        self.profiler = (
+            torch.profiler.profile(
+                record_shapes=True,
+                with_stack=True,
+                with_modules=True,
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                # schedule=torch.profiler.schedule(wait=0, repeat=1, active=1),
+                # on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./tb_log_{args.batch_size}_" + args.mode),
+            )
+            if self.profiling
+            else nullcontext()
+        )

         logger.info(f"init process group done for rank {rank}")

     def exposed_init_model(
@@ -109,6 +131,7 @@ class rpcWorkerService(rpyc.Service):
         input_meta_data_param: Optional[dict] = None,
         generation_config_param: Optional[dict] = None,
     ):
+        with self.profiler:
             # prepare the data for model forward
             with self.t_prepare:
                 input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers(
@@ -132,6 +155,11 @@ class rpcWorkerService(rpyc.Service):
                     self.v_cache,
                 )

+                if self.profiling:
+                    self.profiler.step()
+
+                self.record()
+
             if self.rank == 0:
                 with self.t_sampler:
                     # sampler
@@ -191,6 +219,10 @@ class rpcWorkerService(rpyc.Service):
         generation_config_param: Optional[dict] = None,
     ):
         if self.rank == 0:
+            input_token_ids_param = pickle.loads(input_token_ids_param)
+            input_meta_data_param = pickle.loads(input_meta_data_param)
+            generation_config_param = pickle.loads(generation_config_param)
+
             input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
             input_meta_data.fd_inter_tensor = self.fd_inter_tensor
             input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device)
@@ -199,7 +231,7 @@ class rpcWorkerService(rpyc.Service):
         if dist.get_world_size() > 1:
             broadcast_list = {}
             for k, v in input_meta_data_param.items():
-                if not isinstance(v, List):
+                if not isinstance(v, torch.Tensor):
                     broadcast_list[k] = v

             # Pass the tensor shape and type in advance for
@@ -248,7 +280,7 @@ class rpcWorkerService(rpyc.Service):
             async3 = torch.distributed.broadcast(input_token_ids, src=0, async_op=True)

             input_meta_data_param["sequence_lengths"] = sequence_lengths
-            input_meta_data_param["blocktables"] = blocktables
+            input_meta_data_param["block_tables"] = blocktables

             input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
             input_meta_data.fd_inter_tensor = self.fd_inter_tensor
@@ -257,9 +289,6 @@ class rpcWorkerService(rpyc.Service):
             async2.wait()
             async3.wait()

-            input_meta_data.block_tables = blocktables
-            input_meta_data.sequence_lengths = sequence_lengths
-
         return input_token_ids, input_meta_data, generation_config

     def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
@@ -408,3 +437,9 @@ class rpcWorkerService(rpyc.Service):
         del self.t_prepare
         del self.t_exe
         del self.t_sampler
+
+    def record(self):
+        if self.profiling:
+            file = "/home/lurunyu/projects/ColossalAI/test_trace_rpc.json"
+            self.profiler.export_chrome_trace(file)
+            logger.info(f"trace has been saved into {file}")
@@ -149,8 +149,8 @@ class Timer:
         end_time = time.time()
         elapsed_time = end_time - self.start_time
         self.times.append(elapsed_time)
-        print(f"{self.name} took {elapsed_time:.6f} seconds")
-        self.print_info()
+        # print(f"{self.name} took {elapsed_time:.6f} seconds")
+        # self.print_info()

     def print_info(self):
         average_prefill_time = self.times[0]
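The Timer from colossalai.inference.utils is used as a context manager throughout this commit (with self.t_prepare:, with self.t_exe:, ...); this last hunk only silences its per-call print while keeping the accumulated times for print_info(). A minimal sketch of a compatible context-manager timer, assuming only the interface visible in this diff (a name, a start_time, a times list, print_info()); it is not the actual Timer implementation:

import time


class SketchTimer:
    def __init__(self, name: str):
        self.name = name
        self.times = []

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        elapsed_time = time.time() - self.start_time
        self.times.append(elapsed_time)

    def print_info(self):
        # first recorded step is the prefill pass, the rest are decode steps
        average_prefill_time = self.times[0]
        average_decode_time = sum(self.times[1:]) / max(len(self.times) - 1, 1)
        print(f"{self.name}: prefill {average_prefill_time:.6f}s, decode avg {average_decode_time:.6f}s")


with SketchTimer("[Timer] demo") as t:
    time.sleep(0.01)
t.print_info()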