mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-16 23:16:56 +00:00
parent
f71e63b0f3
commit
f747d13040
@ -85,8 +85,6 @@ class CaiInferEngine:
|
||||
assert max_batch_size <= 64, "Max batch size exceeds the constraint"
|
||||
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
|
||||
|
||||
# TODO: support only tensor parallel inference
|
||||
assert pp_size > 1, "Not support only tensor parallel inference."
|
||||
self.pp_size = pp_size
|
||||
self.tp_size = tp_size
|
||||
|
||||
@ -102,23 +100,21 @@ class CaiInferEngine:
|
||||
# Init pg mesh
|
||||
pg_mesh = ProcessGroupMesh(pp_size, tp_size)
|
||||
|
||||
stage_manager = None
|
||||
if pp_size > 1:
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
|
||||
self.cache_manager_list = [
|
||||
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
|
||||
for _ in range(micro_batch_buffer_size or pp_size)
|
||||
]
|
||||
self.mb_manager = MicroBatchManager(
|
||||
stage_manager.stage,
|
||||
micro_batch_size,
|
||||
micro_batch_buffer_size or pp_size,
|
||||
max_input_len,
|
||||
max_output_len,
|
||||
self.cache_manager_list,
|
||||
)
|
||||
self.verbose = verbose
|
||||
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
|
||||
self.cache_manager_list = [
|
||||
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
|
||||
for _ in range(micro_batch_buffer_size or pp_size)
|
||||
]
|
||||
self.mb_manager = MicroBatchManager(
|
||||
stage_manager.stage,
|
||||
micro_batch_size,
|
||||
micro_batch_buffer_size or pp_size,
|
||||
max_input_len,
|
||||
max_output_len,
|
||||
self.cache_manager_list,
|
||||
)
|
||||
self.verbose = verbose
|
||||
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
|
||||
|
||||
self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
|
||||
|
||||
@ -146,7 +142,7 @@ class CaiInferEngine:
|
||||
shardconfig = ShardConfig(
|
||||
tensor_parallel_process_group=tp_group,
|
||||
pipeline_stage_manager=stage_manager,
|
||||
enable_tensor_parallelism=False,
|
||||
enable_tensor_parallelism=True if self.tp_size > 1 else False,
|
||||
enable_fused_normalization=False,
|
||||
enable_all_optimization=False,
|
||||
enable_flash_attention=False,
|
||||
|
@ -103,4 +103,4 @@ class MemoryManager:
|
||||
self.available_size = len(self.mem_state)
|
||||
self.mem_state[:] = 1
|
||||
self.max_len_in_batch = 0
|
||||
self.logger.info("freed all space of memory manager")
|
||||
# self.logger.info("freed all space of memory manager")
|
||||
|
@ -93,9 +93,7 @@ class GenerateSchedule(PipelineSchedule):
|
||||
Returns:
|
||||
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
|
||||
"""
|
||||
model_inputs = {
|
||||
'infer_state': self.mb_manager.cur_descrption.infer_state
|
||||
}
|
||||
model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
|
||||
return model_inputs
|
||||
|
||||
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
|
||||
@ -129,8 +127,8 @@ class GenerateSchedule(PipelineSchedule):
|
||||
|
||||
def _init_infer_state_action(self) -> None:
|
||||
"""
|
||||
This action is only for no first stage, to load batch and init infer_state.
|
||||
1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
|
||||
This action is only for no first stage, to load batch and init infer_state.
|
||||
1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
|
||||
"""
|
||||
inputs_dict = self.load_micro_batch()
|
||||
self.mb_manager.add_descrption(inputs_dict)
|
||||
@ -145,19 +143,19 @@ class GenerateSchedule(PipelineSchedule):
|
||||
if self.verbose and self.stage_manager.is_first_stage():
|
||||
torch.cuda.synchronize()
|
||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
|
||||
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
|
||||
output_dict = model_forward(model, inputs_dict, interval_inputs)
|
||||
|
||||
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
|
||||
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
|
||||
|
||||
def _gen_token_action(self, model: Module):
|
||||
"""
|
||||
This action is only for first stage
|
||||
This action is only for first stage
|
||||
1.do the forward with hidden_states to generate new tokens 2.step to update
|
||||
"""
|
||||
hidden_states = self.action_interval_buffer.hidden_states
|
||||
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
|
||||
interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
|
||||
interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
|
||||
logits = model_forward(model, None, interval_inputs)
|
||||
if self.verbose and self.stage_manager.is_first_stage():
|
||||
torch.cuda.synchronize()
|
||||
@ -178,18 +176,18 @@ class GenerateSchedule(PipelineSchedule):
|
||||
new_token = self.action_interval_buffer.new_token
|
||||
assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
|
||||
inputs_dict = self._prepare_inputs_for_new_token(new_token)
|
||||
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
|
||||
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
|
||||
output_dict = model_forward(model, inputs_dict, interval_inputs)
|
||||
|
||||
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
|
||||
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
|
||||
|
||||
def _body_encoding_action(self, model: Module):
|
||||
hidden_states = self.action_interval_buffer.hidden_states
|
||||
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
|
||||
interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
|
||||
interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
|
||||
output_dict = model_forward(model, None, interval_inputs)
|
||||
|
||||
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
|
||||
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
|
||||
|
||||
def _comm_action(self, recv_pre: bool) -> torch.Tensor:
|
||||
"""
|
||||
@ -233,12 +231,73 @@ class GenerateSchedule(PipelineSchedule):
|
||||
|
||||
return actions
|
||||
|
||||
def _gen_one_stage_action(self, model: Module):
|
||||
"""
|
||||
In this function, it will generate a sequence action for current state, and do the action one by one.
|
||||
|
||||
Args:
|
||||
model (Module): Model to be run.
|
||||
|
||||
Returns:
|
||||
List[Callable]: A list of action, each action is a callable function, and it will be called in order.
|
||||
"""
|
||||
actions = []
|
||||
|
||||
if self.mb_manager.cur_state is Status.PREFILL:
|
||||
actions.append(partial(self._load_stage_action, model))
|
||||
elif self.mb_manager.cur_state is Status.GENERATE:
|
||||
actions.append(partial(self._gen_token_action, model))
|
||||
actions.append(partial(self._head_encoding_action, model))
|
||||
elif self.mb_manager.cur_state is Status.COOLDOWN:
|
||||
actions.append(partial(self._gen_token_action, model))
|
||||
|
||||
return actions
|
||||
|
||||
def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
|
||||
if self.stage_manager.num_stages == 2:
|
||||
if self.stage_manager.num_stages == 1:
|
||||
return self.generate_step_one_stage(model, data_iter)
|
||||
elif self.stage_manager.num_stages == 2:
|
||||
return self.generate_step_p2p(model, data_iter)
|
||||
else:
|
||||
return self.generate_step_broadcast(model, data_iter)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_step_one_stage(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
|
||||
"""
|
||||
Forward one step of the pipeline, when pipeline size is 1.
|
||||
|
||||
Args:
|
||||
model (Module): Model to be run.
|
||||
data_iter (Iterable): Data iterator.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
|
||||
"""
|
||||
output_sequence = []
|
||||
self.load_batch(data_iter)
|
||||
model.eval()
|
||||
self.comm_dtype = model.dtype
|
||||
|
||||
whole_timestamp = []
|
||||
|
||||
# run by round
|
||||
for _ in range(self.round):
|
||||
self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] if self.verbose else None
|
||||
self.action_interval_buffer.clear()
|
||||
while self.mb_manager.is_micro_batch_done() is False:
|
||||
actions = self._gen_one_stage_action(model)
|
||||
for action in actions:
|
||||
action()
|
||||
self.mb_manager.next()
|
||||
# All microbatch in current round is DONE
|
||||
output_sequence.extend(self.mb_manager.export_new_tokens())
|
||||
|
||||
self.mb_manager.clear()
|
||||
if self.verbose:
|
||||
whole_timestamp.extend(self.timestamps)
|
||||
|
||||
return output_sequence, whole_timestamp
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
|
||||
"""
|
||||
@ -319,7 +378,7 @@ class GenerateSchedule(PipelineSchedule):
|
||||
torch.cuda.synchronize()
|
||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||
self.mb_manager.add_descrption(inputs_dict)
|
||||
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
|
||||
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
|
||||
output_dict = model_forward(model, inputs_dict, interval_inputs)
|
||||
# In GENERATE phase
|
||||
else:
|
||||
@ -330,18 +389,23 @@ class GenerateSchedule(PipelineSchedule):
|
||||
assert (
|
||||
hidden_states is not None
|
||||
), "When first stage in GENERATE phase, the hidden states should not be None"
|
||||
interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
|
||||
interval_inputs = {
|
||||
"hidden_states": hidden_states["hidden_states"],
|
||||
"infer_state": self.mb_manager.cur_infer_state,
|
||||
}
|
||||
logits = model_forward(model, None, interval_inputs)
|
||||
if self.verbose and self.stage_manager.is_first_stage():
|
||||
torch.cuda.synchronize()
|
||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||
assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
||||
new_token = self._get_token_id(logits['logits'])
|
||||
assert (
|
||||
"logits" in logits
|
||||
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
||||
new_token = self._get_token_id(logits["logits"])
|
||||
self.mb_manager.step(new_token)
|
||||
# If the current micro batch is not DONE, go through blocks
|
||||
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
|
||||
inputs_dict = self._prepare_inputs_for_new_token(new_token)
|
||||
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
|
||||
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
|
||||
output_dict = model_forward(model, inputs_dict, interval_inputs)
|
||||
else:
|
||||
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
|
||||
@ -350,7 +414,10 @@ class GenerateSchedule(PipelineSchedule):
|
||||
if self.mb_manager.cur_state is Status.PREFILL:
|
||||
inputs_dict = self.load_micro_batch()
|
||||
self.mb_manager.add_descrption(inputs_dict)
|
||||
interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
|
||||
interval_inputs = {
|
||||
"hidden_states": hidden_states["hidden_states"],
|
||||
"infer_state": self.mb_manager.cur_infer_state,
|
||||
}
|
||||
output_dict = model_forward(model, inputs_dict, interval_inputs)
|
||||
|
||||
# Current microbatch is not DONE, send hidden_state to next stage
|
||||
|
@ -65,6 +65,16 @@ def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize("tp_size", [2])
|
||||
@parameterize("pp_size", [1])
|
||||
@parameterize("max_output_len", [2])
|
||||
@parameterize("micro_batch_size", [1])
|
||||
@clear_cache_before_run()
|
||||
def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
|
||||
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_pipeline_inference(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_pipeline_inference_test()
|
||||
@ -75,6 +85,11 @@ def check_tp_pipeline_inference(rank, world_size, port):
|
||||
run_tp_pipeline_inference_test()
|
||||
|
||||
|
||||
def check_tp_inference(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_tp_inference_test()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@ -82,6 +97,7 @@ def check_tp_pipeline_inference(rank, world_size, port):
|
||||
def test_pipeline_inference():
|
||||
spawn(check_pipeline_inference, nprocs=2)
|
||||
spawn(check_tp_pipeline_inference, nprocs=4)
|
||||
spawn(check_tp_inference, nprocs=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
Loading…
Reference in New Issue
Block a user