mirror of https://github.com/hpcaitech/ColossalAI.git
parent f71e63b0f3
commit f747d13040
@@ -85,8 +85,6 @@ class CaiInferEngine:
         assert max_batch_size <= 64, "Max batch size exceeds the constraint"
         assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
 
-        # TODO: support only tensor parallel inference
-        assert pp_size > 1, "Not support only tensor parallel inference."
         self.pp_size = pp_size
         self.tp_size = tp_size
 
@@ -102,23 +100,21 @@ class CaiInferEngine:
         # Init pg mesh
         pg_mesh = ProcessGroupMesh(pp_size, tp_size)
 
-        stage_manager = None
-        if pp_size > 1:
-            stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
+        stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
         self.cache_manager_list = [
             self._init_manager(model, max_batch_size, max_input_len, max_output_len)
             for _ in range(micro_batch_buffer_size or pp_size)
         ]
         self.mb_manager = MicroBatchManager(
             stage_manager.stage,
             micro_batch_size,
             micro_batch_buffer_size or pp_size,
             max_input_len,
             max_output_len,
             self.cache_manager_list,
         )
         self.verbose = verbose
         self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
 
         self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
 
@@ -146,7 +142,7 @@ class CaiInferEngine:
         shardconfig = ShardConfig(
             tensor_parallel_process_group=tp_group,
             pipeline_stage_manager=stage_manager,
-            enable_tensor_parallelism=False,
+            enable_tensor_parallelism=True if self.tp_size > 1 else False,
             enable_fused_normalization=False,
             enable_all_optimization=False,
             enable_flash_attention=False,
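Note on the engine changes above: with the pp_size > 1 assertion removed and enable_tensor_parallelism now derived from tp_size, a tensor-parallel-only configuration (a single pipeline stage) becomes a valid way to build the engine. A self-contained sketch of the resulting flag logic; build_flags and its dictionary keys are illustrative names, not the repository's API:

# Illustrative sketch only: how the configuration flags relate after this diff.
def build_flags(pp_size: int, tp_size: int) -> dict:
    assert pp_size >= 1 and tp_size >= 1
    return {
        "num_pipeline_stages": pp_size,            # pp_size == 1 is no longer rejected
        "enable_tensor_parallelism": tp_size > 1,  # same truth table as `True if tp_size > 1 else False`
        "world_size": pp_size * tp_size,           # ranks covered by ProcessGroupMesh(pp_size, tp_size)
    }


print(build_flags(pp_size=1, tp_size=2))  # tensor-parallel-only inference
print(build_flags(pp_size=2, tp_size=1))  # pipeline-only inference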
@@ -103,4 +103,4 @@ class MemoryManager:
         self.available_size = len(self.mem_state)
         self.mem_state[:] = 1
         self.max_len_in_batch = 0
-        self.logger.info("freed all space of memory manager")
+        # self.logger.info("freed all space of memory manager")
@@ -93,9 +93,7 @@ class GenerateSchedule(PipelineSchedule):
         Returns:
             dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
         """
-        model_inputs = {
-            'infer_state': self.mb_manager.cur_descrption.infer_state
-        }
+        model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
         return model_inputs
 
     def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@@ -129,8 +127,8 @@ class GenerateSchedule(PipelineSchedule):
 
     def _init_infer_state_action(self) -> None:
         """
        This action is only for no first stage, to load batch and init infer_state.
        1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
         """
         inputs_dict = self.load_micro_batch()
         self.mb_manager.add_descrption(inputs_dict)
@@ -145,19 +143,19 @@ class GenerateSchedule(PipelineSchedule):
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
-        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, inputs_dict, interval_inputs)
 
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _gen_token_action(self, model: Module):
         """
        This action is only for first stage
        1.do the forward with hidden_states to generate new tokens 2.step to update
         """
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
-        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
         logits = model_forward(model, None, interval_inputs)
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
@@ -178,18 +176,18 @@ class GenerateSchedule(PipelineSchedule):
         new_token = self.action_interval_buffer.new_token
         assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
         inputs_dict = self._prepare_inputs_for_new_token(new_token)
-        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, inputs_dict, interval_inputs)
 
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _body_encoding_action(self, model: Module):
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When not first stage, the hidden states should not be None"
-        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, None, interval_inputs)
 
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _comm_action(self, recv_pre: bool) -> torch.Tensor:
         """
@@ -233,12 +231,73 @@ class GenerateSchedule(PipelineSchedule):
 
         return actions
 
+    def _gen_one_stage_action(self, model: Module):
+        """
+        In this function, it will generate a sequence action for current state, and do the action one by one.
+
+        Args:
+            model (Module): Model to be run.
+
+        Returns:
+            List[Callable]: A list of action, each action is a callable function, and it will be called in order.
+        """
+        actions = []
+
+        if self.mb_manager.cur_state is Status.PREFILL:
+            actions.append(partial(self._load_stage_action, model))
+        elif self.mb_manager.cur_state is Status.GENERATE:
+            actions.append(partial(self._gen_token_action, model))
+            actions.append(partial(self._head_encoding_action, model))
+        elif self.mb_manager.cur_state is Status.COOLDOWN:
+            actions.append(partial(self._gen_token_action, model))
+
+        return actions
+
     def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
-        if self.stage_manager.num_stages == 2:
+        if self.stage_manager.num_stages == 1:
+            return self.generate_step_one_stage(model, data_iter)
+        elif self.stage_manager.num_stages == 2:
             return self.generate_step_p2p(model, data_iter)
         else:
             return self.generate_step_broadcast(model, data_iter)
 
+    @torch.no_grad()
+    def generate_step_one_stage(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
+        """
+        Forward one step of the pipeline, when pipeline size is 1.
+
+        Args:
+            model (Module): Model to be run.
+            data_iter (Iterable): Data iterator.
+
+        Returns:
+            Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
+        """
+        output_sequence = []
+        self.load_batch(data_iter)
+        model.eval()
+        self.comm_dtype = model.dtype
+
+        whole_timestamp = []
+
+        # run by round
+        for _ in range(self.round):
+            self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] if self.verbose else None
+            self.action_interval_buffer.clear()
+            while self.mb_manager.is_micro_batch_done() is False:
+                actions = self._gen_one_stage_action(model)
+                for action in actions:
+                    action()
+                self.mb_manager.next()
+            # All microbatch in current round is DONE
+            output_sequence.extend(self.mb_manager.export_new_tokens())
+
+            self.mb_manager.clear()
+            if self.verbose:
+                whole_timestamp.extend(self.timestamps)
+
+        return output_sequence, whole_timestamp
+
     @torch.no_grad()
     def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
         """
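The new one-stage path follows the same action-list pattern used elsewhere in the schedule: inspect the current micro-batch state, collect callables bound with functools.partial, and execute them in order until the micro batch is done. A self-contained sketch of that pattern; the Status values mirror the diff, while make_actions and the print placeholders are illustrative, not the repository's API:

from enum import Enum, auto
from functools import partial


class Status(Enum):
    PREFILL = auto()
    GENERATE = auto()
    COOLDOWN = auto()


def make_actions(state: Status, model_name: str):
    # Build a list of zero-argument callables; nothing runs until the caller invokes them.
    actions = []
    if state is Status.PREFILL:
        actions.append(partial(print, f"prefill micro batch on {model_name}"))
    elif state is Status.GENERATE:
        actions.append(partial(print, f"generate next token on {model_name}"))
        actions.append(partial(print, f"run head encoding on {model_name}"))
    elif state is Status.COOLDOWN:
        actions.append(partial(print, f"generate final token on {model_name}"))
    return actions


for action in make_actions(Status.GENERATE, "llama"):
    action()  # executed in order, mirroring the schedule's per-micro-batch loop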
@@ -319,7 +378,7 @@ class GenerateSchedule(PipelineSchedule):
                     torch.cuda.synchronize()
                     self.timestamps[self.mb_manager.idx].append(time.time())
                 self.mb_manager.add_descrption(inputs_dict)
-                interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+                interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
                 output_dict = model_forward(model, inputs_dict, interval_inputs)
             # In GENERATE phase
             else:
@@ -330,18 +389,23 @@ class GenerateSchedule(PipelineSchedule):
                 assert (
                     hidden_states is not None
                 ), "When first stage in GENERATE phase, the hidden states should not be None"
-                interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+                interval_inputs = {
+                    "hidden_states": hidden_states["hidden_states"],
+                    "infer_state": self.mb_manager.cur_infer_state,
+                }
                 logits = model_forward(model, None, interval_inputs)
                 if self.verbose and self.stage_manager.is_first_stage():
                     torch.cuda.synchronize()
                     self.timestamps[self.mb_manager.idx].append(time.time())
-                assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
-                new_token = self._get_token_id(logits['logits'])
+                assert (
+                    "logits" in logits
+                ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
+                new_token = self._get_token_id(logits["logits"])
                 self.mb_manager.step(new_token)
                 # If the current micro batch is not DONE, go through blocks
                 if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
                     inputs_dict = self._prepare_inputs_for_new_token(new_token)
-                    interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+                    interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
                     output_dict = model_forward(model, inputs_dict, interval_inputs)
         else:
             assert hidden_states is not None, "When not first stage, the hidden states should not be None"
@@ -350,7 +414,10 @@ class GenerateSchedule(PipelineSchedule):
             if self.mb_manager.cur_state is Status.PREFILL:
                 inputs_dict = self.load_micro_batch()
                 self.mb_manager.add_descrption(inputs_dict)
-            interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+            interval_inputs = {
+                "hidden_states": hidden_states["hidden_states"],
+                "infer_state": self.mb_manager.cur_infer_state,
+            }
             output_dict = model_forward(model, inputs_dict, interval_inputs)
 
         # Current microbatch is not DONE, send hidden_state to next stage
@@ -65,6 +65,16 @@ def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch
     torch.cuda.empty_cache()
 
 
+@parameterize("tp_size", [2])
+@parameterize("pp_size", [1])
+@parameterize("max_output_len", [2])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+    torch.cuda.empty_cache()
+
+
 def check_pipeline_inference(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_pipeline_inference_test()
@@ -75,6 +85,11 @@ def check_tp_pipeline_inference(rank, world_size, port):
     run_tp_pipeline_inference_test()
 
 
+def check_tp_inference(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_tp_inference_test()
+
+
 @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
@@ -82,6 +97,7 @@ def check_tp_pipeline_inference(rank, world_size, port):
 def test_pipeline_inference():
     spawn(check_pipeline_inference, nprocs=2)
     spawn(check_tp_pipeline_inference, nprocs=4)
+    spawn(check_tp_inference, nprocs=2)
 
 
 if __name__ == "__main__":
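A small point about the test wiring above: the spawned process count has to cover the whole tp x pp mesh. The new check_tp_inference exercises the tp_size=2, pp_size=1 parameterization, hence nprocs=2; check_tp_pipeline_inference keeps nprocs=4, presumably for a 2x2 mesh that this diff does not show. A tiny illustrative check, not repository code:

def required_world_size(tp_size: int, pp_size: int) -> int:
    # Every (tp rank, pp rank) pair needs its own process.
    return tp_size * pp_size


assert required_world_size(tp_size=2, pp_size=1) == 2  # matches spawn(check_tp_inference, nprocs=2)
assert required_world_size(tp_size=2, pp_size=2) == 4  # would match spawn(..., nprocs=4)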