Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
[Pipeline Inference] Sync pipeline inference branch to main (#4820)
* [pipeline inference] pipeline inference (#4492)
* add pp stage manager as circle stage
* fix a bug when create process group
* add ppinfer basic framework
* add micro batch manager and support kvcache-pp gpt2 fwd
* add generate schedule
* use mb size to control mb number
* support generate with kv cache
* add output, remove unused code
* add test
* reuse shardformer to build model
* refactor some code and use the same attribute name of hf
* fix review and add test for generation
* remove unused file
* fix CI
* add cache clear
* fix code error
* fix typo
* [Pipeline inference] Modify to tieweight (#4599)
* add pp stage manager as circle stage
* fix a bug when create process group
* add ppinfer basic framework
* add micro batch manager and support kvcache-pp gpt2 fwd
* add generate schedule
* use mb size to control mb number
* support generate with kv cache
* add output, remove unused code
* add test
* reuse shardformer to build model
* refactor some code and use the same attribute name of hf
* fix review and add test for generation
* remove unused file
* modify the way of saving newtokens
* modify to tieweight
* modify test
* remove unused file
* solve review
* add docstring
* [Pipeline inference] support llama pipeline inference (#4647)
* support llama pipeline inference
* remove tie weight operation
* [pipeline inference] Fix the blocking of communication when ppsize is 2 (#4708)
* add benchmark verbose
* fix export tokens
* fix benchmark verbose
* add P2POp style to do p2p communication
* modify schedule as p2p type when ppsize is 2
* remove unused code and add docstring
* [Pipeline inference] Refactor code, add docsting, fix bug (#4790)
* add benchmark script
* update argparse
* fix fp16 load
* refactor code style
* add docstring
* polish code
* fix test bug
* [Pipeline inference] Add pipeline inference docs (#4817)
* add readme doc
* add a ico
* Add performance
* update table of contents
* refactor code (#4873)
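Among the changes listed above, "use mb size to control mb number" describes deriving the number of micro batches from the chosen micro batch size. A minimal illustration of that split on a plain tensor batch (a sketch only, not the actual micro batch manager API from this PR):

    import torch

    def split_into_micro_batches(input_ids: torch.Tensor, micro_batch_size: int) -> list:
        # A batch of 8 sequences with micro_batch_size=2 yields 4 micro batches.
        return list(torch.split(input_ids, micro_batch_size, dim=0))

    micro_batches = split_into_micro_batches(torch.zeros(8, 16, dtype=torch.long), micro_batch_size=2)
    assert len(micro_batches) == 4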
@@ -160,6 +160,86 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
    return object_list[0]


def _p2p_comm(
    tensor_send_next: torch.Tensor,
    recv_prev: bool,
    peer: int,
    group: ProcessGroup,
    comm_dtype: torch.dtype = torch.float16,
):
    """
    Send and recv tensor using P2P communication; used when the pipeline size is 2, to avoid the blocking
    that separate send and recv calls cause between two stages that are each other's peer.

    Args:
        tensor_send_next (torch.Tensor): tensor to be sent to the next stage
        recv_prev (bool): whether to receive a tensor from the previous stage
        peer (int): rank of the peer
        group (ProcessGroup): process group
        comm_dtype (torch.dtype): dtype of the tensor to be sent

    Returns:
        torch.Tensor: tensor received from the previous stage
    """
    # Phase 1: exchange tensor shapes so the receiver can allocate a matching buffer.
    send_next_shape = None
    recv_prev_shape = None

    if tensor_send_next is not None:
        send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
    if recv_prev:
        # The incoming tensor is assumed to be 3-dimensional (e.g. [batch, seq_len, hidden]).
        recv_prev_shape = torch.empty((3,), device=torch.cuda.current_device(), dtype=torch.int64)

    ops = []
    if send_next_shape is not None:
        send_next_op = dist.P2POp(dist.isend, send_next_shape, peer=peer, group=group)
        ops.append(send_next_op)
    if recv_prev_shape is not None:
        recv_prev_op = dist.P2POp(
            dist.irecv,
            recv_prev_shape,
            peer=peer,
            group=group,
        )
        ops.append(recv_prev_op)

    if len(ops) > 0:
        # Batching the send and recv lets torch schedule the pair safely on both ranks.
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

    if recv_prev_shape is not None:
        recv_prev_shape = recv_prev_shape.tolist()

    # Phase 2: exchange the tensor data itself.
    tensor_recv_prev = None
    if recv_prev:
        tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)

    ops = []
    if tensor_send_next is not None:
        send_next_op = dist.P2POp(
            dist.isend,
            tensor_send_next,
            peer=peer,
            group=group,
        )
        ops.append(send_next_op)

    if tensor_recv_prev is not None:
        recv_prev_op = dist.P2POp(
            dist.irecv,
            tensor_recv_prev,
            peer=peer,
            group=group,
        )
        ops.append(recv_prev_op)
    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
    return tensor_recv_prev


class PipelineP2PCommunication:
    def __init__(self, stage_manager: PipelineStageManager) -> None:
        self.stage_manager = stage_manager
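The batched shape-then-data handshake above can be exercised outside the pipeline. A minimal, self-contained sketch of the same exchange between two ranks (not part of this diff; it runs on CPU with the gloo backend, and the address/port values are arbitrary):

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def exchange(rank: int, world_size: int):
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29501"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        peer = 1 - rank  # with two stages, prev and next are the same rank

        tensor = torch.full((2, 4, 8), float(rank))

        # Phase 1: exchange shapes so each receiver can allocate a buffer.
        send_shape = torch.tensor(tensor.size(), dtype=torch.int64)
        recv_shape = torch.empty((3,), dtype=torch.int64)
        reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, send_shape, peer=peer),
            dist.P2POp(dist.irecv, recv_shape, peer=peer),
        ])
        for req in reqs:
            req.wait()

        # Phase 2: exchange the data, batched so neither rank blocks on an unmatched send.
        recv_buf = torch.empty(recv_shape.tolist())
        reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, tensor, peer=peer),
            dist.P2POp(dist.irecv, recv_buf, peer=peer),
        ])
        for req in reqs:
            req.wait()
        print(f"rank {rank} received tensor with mean {recv_buf.mean().item()}")
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(exchange, args=(2,), nprocs=2)

Submitting the matching isend/irecv pairs through dist.batch_isend_irecv is what allows two stages that are simultaneously each other's sender and receiver to make progress without deadlocking.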
@@ -221,3 +301,17 @@ class PipelineP2PCommunication:
        prev_rank = self.stage_manager.get_prev_rank()
        cur_rank = self.stage_manager.get_rank()
        _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))

    def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> Any:
        """
        Sends the output tensor to the next stage in the pipeline and, in the same batched call, optionally
        receives a tensor from the previous stage, using `P2POp` in torch.

        Args:
            output_object (Any): Object to be sent.
            recv_pre (bool): Whether to also receive a tensor from the previous stage.
            peer (int, optional): The rank of the recipient of the tensor; defaults to the next stage's rank.
            comm_dtype (torch.dtype): dtype of the tensor to be communicated.
        """
        if peer is None:
            peer = self.stage_manager.get_next_rank()
        cur_rank = self.stage_manager.get_rank()
        recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype)
        return recv_tensor
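For context, a hypothetical usage sketch of the new method (the PipelineStageManager setup is omitted, and generation_step is an invented name, not ColossalAI API):

    import torch

    def generation_step(comm: "PipelineP2PCommunication", hidden_states: torch.Tensor) -> torch.Tensor:
        # With two stages, the peer defaults to the next rank, which is also the
        # previous rank, so sending and receiving are fused into one batched call.
        return comm.p2p_communicate(hidden_states, recv_pre=True, comm_dtype=torch.float16)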