[shardformer] support interleaved pipeline (#4448)

* support interleaved pipeline

* fix unit test

* remove virtual stage test in stage mgr

* add dropped type hint and updated bwd
Author: LuGY
Date: 2023-08-16 19:29:03 +08:00
Committed by: GitHub
Parent: 26e29d58f0
Commit: a78daf6180
7 changed files with 642 additions and 109 deletions
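The headline change, interleaved pipeline scheduling, splits the model into more chunks than there are pipeline ranks so that each rank hosts several non-contiguous "virtual stages" and pipeline bubbles shrink. A rough, hedged sketch of the usual chunk-to-rank assignment follows; the helper names are illustrative only and are not the ColossalAI API.

# Hedged sketch: with p pipeline ranks and v chunks per rank, layer block k is
# commonly placed on rank k % p as that rank's (k // p)-th local chunk.
def chunk_to_rank(chunk_id: int, num_ranks: int) -> int:
    return chunk_id % num_ranks


def chunks_of_rank(rank: int, num_ranks: int, num_chunks_per_rank: int) -> list[int]:
    return [rank + v * num_ranks for v in range(num_chunks_per_rank)]


# Example: 4 ranks, 2 chunks each -> rank 0 holds chunks [0, 4], rank 1 holds [1, 5], ...
assert chunk_to_rank(5, 4) == 1
print([chunks_of_rank(r, 4, 2) for r in range(4)])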


@@ -53,6 +53,62 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.microbatch_offset += self.microbatch_size
         return tree_map(partial(to_device, device=get_current_device()), micro_batch)

+    def recv_forward(self, prev_rank: int = None) -> Any:
+        """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
+           For 1F1B.
+
+        Args:
+            prev_rank (int, optional): The rank of the source of the tensor.
+
+        Returns:
+            Any: The input tensor or input tensor list.
+        """
+        if self.stage_manager.is_first_stage():
+            input_tensor = None
+        else:
+            input_tensor = self.comm.recv_forward(prev_rank)
+
+        return input_tensor
+
+    def recv_backward(self, next_rank: int = None) -> Any:
+        """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
+           For 1F1B.
+
+        Args:
+            next_rank (int, optional): The rank of the source of the tensor.
+
+        Returns:
+            Any: The input gradient tensor or gradient tensor list.
+        """
+        if self.stage_manager.is_last_stage():
+            output_tensor_grad = None
+        else:
+            output_tensor_grad = self.comm.recv_backward(next_rank)
+
+        return output_tensor_grad
+
+    def send_forward(self, output_object: Any, next_rank: int = None) -> None:
+        """Sends the input tensor to the next stage in pipeline.
+           For 1F1B.
+
+        Args:
+            output_object (Any): Object to be sent.
+            next_rank (int, optional): The rank of the recipient of the tensor.
+        """
+        if not self.stage_manager.is_last_stage():
+            self.comm.send_forward(output_object, next_rank)
+
+    def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
+        """Sends the gradient tensor to the previous stage in pipeline.
+           For 1F1B.
+
+        Args:
+            input_object (Any): Object to be sent.
+            prev_rank (int, optional): The rank of the recipient of the tensor.
+        """
+        if not self.stage_manager.is_first_stage():
+            self.comm.send_backward(input_object, prev_rank)
+
     def forward_step(self,
                      model: Module,
                      input_obj: Optional[dict],
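The four wrappers added above share one idea: consult the stage manager and skip point-to-point communication at the pipeline boundaries, since the first stage takes its input from the dataloader and the last stage produces the loss locally. Below is a minimal sketch of that pattern with simplified stand-ins for the stage manager and communication layer; these are not the actual ColossalAI classes.

from typing import Any, Optional


class StageManager:
    """Simplified stand-in for the pipeline stage manager."""

    def __init__(self, stage: int, num_stages: int) -> None:
        self.stage = stage
        self.num_stages = num_stages

    def is_first_stage(self) -> bool:
        return self.stage == 0

    def is_last_stage(self) -> bool:
        return self.stage == self.num_stages - 1


class BoundarySkippingComm:
    """Skips sends/receives at the pipeline edges; recv_backward and
    send_forward would mirror these two methods, guarding on is_last_stage()."""

    def __init__(self, stage_manager: StageManager, comm: Any) -> None:
        self.stage_manager = stage_manager
        self.comm = comm

    def recv_forward(self, prev_rank: Optional[int] = None) -> Any:
        # The first stage reads real data from the dataloader, so there is
        # nothing to receive from a previous stage.
        if self.stage_manager.is_first_stage():
            return None
        return self.comm.recv_forward(prev_rank)

    def send_backward(self, input_grad: Any, prev_rank: Optional[int] = None) -> None:
        # The first stage has no predecessor to send input gradients to.
        if not self.stage_manager.is_first_stage():
            self.comm.send_backward(input_grad, prev_rank)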
@@ -171,11 +227,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         # Run warmup forward passes.
         for i in range(num_warmup_microbatches):
-            input_obj = self.comm.recv_forward()
+            input_obj = self.recv_forward()

             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)

-            self.comm.send_forward(output_obj)
+            self.send_forward(output_obj)

             if not forward_only:
                 input_objs.append(input_obj)
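For context, the warmup length this loop iterates over is, in a standard 1F1B schedule, derived from the stage index so that the last stage enters the steady state immediately while earlier stages first fill the pipeline ahead of them. A hedged sketch of the usual formula (illustrative only; the exact expression used in this file may differ):

def num_warmup_microbatches(stage: int, num_stages: int, num_microbatches: int) -> int:
    # Each stage runs enough forward-only steps to fill the pipeline ahead of it;
    # the last stage needs none, the first needs num_stages - 1, and the count is
    # capped by the total number of microbatches.
    return min(num_stages - stage - 1, num_microbatches)


# Example: 4 stages, 8 microbatches -> warmup counts [3, 2, 1, 0] for stages 0..3.
print([num_warmup_microbatches(s, 4, 8) for s in range(4)])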
@@ -185,7 +241,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         # If all microbatches are run in warmup / cooldown phase, then no need to
         # receive this tensor here.
         if num_microbatches_remaining > 0:
-            input_obj = self.comm.recv_forward()
+            input_obj = self.recv_forward()

         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
@@ -193,15 +249,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)

             if forward_only:
-                self.comm.send_forward(output_obj)
+                self.send_forward(output_obj)

                 if not last_iteration:
-                    input_obj = self.comm.recv_forward()
+                    input_obj = self.recv_forward()

             else:
                 # TODO adjust here
-                self.comm.send_forward(output_obj)
-                output_obj_grad = self.comm.recv_backward()
+                self.send_forward(output_obj)
+                output_obj_grad = self.recv_backward()

                 # Add input_obj and output_obj to end of list.
                 input_objs.append(input_obj)
@@ -216,8 +272,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 if last_iteration:
                     input_obj = None
                 else:
-                    input_obj = self.comm.recv_forward()
-                self.comm.send_backward(input_obj_grad)
+                    input_obj = self.recv_forward()
+                self.send_backward(input_obj_grad)

         # Run cooldown backward passes.
         if not forward_only:
@@ -225,9 +281,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 input_obj = input_objs.pop(0)
                 output_obj = output_objs.pop(0)

-                output_obj_grad = self.comm.recv_backward()
+                output_obj_grad = self.recv_backward()
                 input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-                self.comm.send_backward(input_obj_grad)
+                self.send_backward(input_obj_grad)

         if outputs is not None:
             outputs = merge_batch(outputs)
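Putting the hunks together, the run loop now routes every exchange through the stage-aware wrappers instead of calling self.comm directly. The sketch below is a condensed, single-process simulation of that control flow, written only to show the ordering of events in the warmup, steady-state, and cooldown phases; all names are stand-ins, and forward-only mode and loss bookkeeping are omitted.

from typing import Any, List


class FakeSchedule:
    """Records the event order of a 1F1B-style run on one simulated stage."""

    def __init__(self) -> None:
        self.events: List[str] = []

    # Stand-ins for the stage-aware communication wrappers.
    def recv_forward(self) -> Any:
        self.events.append("recv_fwd")
        return object()

    def recv_backward(self) -> Any:
        self.events.append("recv_bwd")
        return object()

    def send_forward(self, obj: Any) -> None:
        self.events.append("send_fwd")

    def send_backward(self, grad: Any) -> None:
        self.events.append("send_bwd")

    # Stand-ins for the compute steps.
    def forward_step(self, input_obj: Any) -> Any:
        self.events.append("fwd")
        return object()

    def backward_step(self, input_obj: Any, output_obj: Any, output_grad: Any) -> Any:
        self.events.append("bwd")
        return object()

    def run(self, num_warmup: int, num_remaining: int) -> None:
        input_objs, output_objs = [], []

        # Warmup: forward-only passes that fill the pipeline.
        for _ in range(num_warmup):
            input_obj = self.recv_forward()
            output_obj = self.forward_step(input_obj)
            self.send_forward(output_obj)
            input_objs.append(input_obj)
            output_objs.append(output_obj)

        if num_remaining > 0:
            input_obj = self.recv_forward()

        # Steady state: one forward and one backward per iteration.
        for i in range(num_remaining):
            last = i == num_remaining - 1
            output_obj = self.forward_step(input_obj)
            self.send_forward(output_obj)
            output_grad = self.recv_backward()
            input_objs.append(input_obj)
            output_objs.append(output_obj)
            input_grad = self.backward_step(input_objs.pop(0), output_objs.pop(0), output_grad)
            input_obj = None if last else self.recv_forward()
            self.send_backward(input_grad)

        # Cooldown: drain the backward passes queued up during warmup.
        for _ in range(num_warmup):
            output_grad = self.recv_backward()
            input_grad = self.backward_step(input_objs.pop(0), output_objs.pop(0), output_grad)
            self.send_backward(input_grad)


if __name__ == "__main__":
    sched = FakeSchedule()
    sched.run(num_warmup=2, num_remaining=2)
    print(sched.events)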