Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 19:40:28 +00:00
[shardformer] support interleaved pipeline (#4448)
* support interleaved pipeline
* fix unit test
* remove virtual stage test in stage mgr
* add dropped type hint and updated bwd
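For context: an interleaved (virtual-stage) pipeline, in the Megatron-LM style, splits the model into more chunks than there are pipeline ranks, so each rank holds several non-adjacent chunks and the pipeline bubble shrinks. A minimal sketch of the round-robin chunk-to-rank assignment; `distribute_chunks` and its parameters are illustrative names, not this repo's API:

```python
# Sketch: interleaved assignment of model chunks to pipeline ranks.
# With pp_size ranks and num_model_chunks chunks per rank, chunk k of
# the model lives on rank k % pp_size (Megatron-style interleaving).
def distribute_chunks(num_layers: int, pp_size: int, num_model_chunks: int):
    chunks_total = pp_size * num_model_chunks
    assert num_layers % chunks_total == 0
    layers_per_chunk = num_layers // chunks_total
    assignment = {rank: [] for rank in range(pp_size)}
    for chunk in range(chunks_total):
        rank = chunk % pp_size
        start = chunk * layers_per_chunk
        assignment[rank].append(list(range(start, start + layers_per_chunk)))
    return assignment

# e.g. 16 layers, 4 ranks, 2 chunks per rank:
# rank 0 -> layers [0, 1] and [8, 9], rank 1 -> [2, 3] and [10, 11], ...
print(distribute_chunks(16, 4, 2))
```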
@@ -53,6 +53,62 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.microbatch_offset += self.microbatch_size
         return tree_map(partial(to_device, device=get_current_device()), micro_batch)
 
+    def recv_forward(self, prev_rank: int = None) -> Any:
+        """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
+           For 1F1B.
+
+        Args:
+            prev_rank (int, optional): The rank of the source of the tensor.
+
+        Returns:
+            Any: The input tensor or input tensor list.
+        """
+        if self.stage_manager.is_first_stage():
+            input_tensor = None
+        else:
+            input_tensor = self.comm.recv_forward(prev_rank)
+
+        return input_tensor
+
+    def recv_backward(self, next_rank: int = None) -> Any:
+        """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
+           For 1F1B.
+
+        Args:
+            next_rank (int, optional): The rank of the source of the tensor.
+
+        Returns:
+            Any: The input gradient tensor or gradient tensor list.
+        """
+        if self.stage_manager.is_last_stage():
+            output_tensor_grad = None
+        else:
+            output_tensor_grad = self.comm.recv_backward(next_rank)
+
+        return output_tensor_grad
+
+    def send_forward(self, output_object: Any, next_rank: int = None) -> None:
+        """Sends the input tensor to the next stage in pipeline.
+           For 1F1B.
+
+        Args:
+            output_object (Any): Object to be sent.
+            next_rank (int, optional): The rank of the recipient of the tensor.
+        """
+        if not self.stage_manager.is_last_stage():
+            self.comm.send_forward(output_object, next_rank)
+
+    def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
+        """Sends the gradient tensor to the previous stage in pipeline.
+           For 1F1B.
+
+        Args:
+            input_object (Any): Object to be sent.
+            prev_rank (int, optional): The rank of the recipient of the tensor.
+        """
+        if not self.stage_manager.is_first_stage():
+            self.comm.send_backward(input_object, prev_rank)
+
     def forward_step(self,
                      model: Module,
                      input_obj: Optional[dict],
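These four wrappers move the stage-boundary checks out of the schedule body: the first stage has no upstream peer, so `recv_forward` returns None (its input comes from the dataloader), and the last stage has no downstream peer, so `send_forward` is a no-op. A minimal sketch of that boundary behavior with a stand-in stage manager (`ToyStageManager` is a hypothetical name, not the repo's stage manager class):

```python
class ToyStageManager:
    def __init__(self, stage: int, num_stages: int):
        self.stage, self.num_stages = stage, num_stages

    def is_first_stage(self) -> bool:
        return self.stage == 0

    def is_last_stage(self) -> bool:
        return self.stage == self.num_stages - 1

def recv_forward(stage_mgr, recv_fn):
    # The first stage reads microbatches from the dataloader, not a peer.
    return None if stage_mgr.is_first_stage() else recv_fn()

# Stage 0 of 4 gets None; stages 1-3 would block on their predecessor.
assert recv_forward(ToyStageManager(0, 4), lambda: "tensor") is None
assert recv_forward(ToyStageManager(2, 4), lambda: "tensor") == "tensor"
```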
@@ -171,11 +227,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
 
         # Run warmup forward passes.
         for i in range(num_warmup_microbatches):
-            input_obj = self.comm.recv_forward()
+            input_obj = self.recv_forward()
 
             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
 
-            self.comm.send_forward(output_obj)
+            self.send_forward(output_obj)
 
             if not forward_only:
                 input_objs.append(input_obj)
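The warm-up count used here follows the standard 1F1B rule: a stage must push enough microbatches downstream to fill the rest of the pipeline before it can start alternating forwards and backwards. A hedged sketch of the usual arithmetic (the variable names match the diff, but the exact expressions in this file are assumed):

```python
def one_f_one_b_counts(stage: int, num_stages: int, num_microbatches: int):
    # Stages closer to the input warm up longer: stage s must push
    # (num_stages - s - 1) microbatches downstream before its first backward.
    num_warmup_microbatches = min(num_stages - stage - 1, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches
    return num_warmup_microbatches, num_microbatches_remaining

# 4 stages, 8 microbatches: stage 0 warms up with 3 forwards, then
# alternates forward/backward for the remaining 5 microbatches.
assert one_f_one_b_counts(0, 4, 8) == (3, 5)
assert one_f_one_b_counts(3, 4, 8) == (0, 8)  # last stage starts 1F1B at once
```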
@@ -185,7 +241,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         # If all microbatches are run in warmup / cooldown phase, then no need to
         # receive this tensor here.
         if num_microbatches_remaining > 0:
-            input_obj = self.comm.recv_forward()
+            input_obj = self.recv_forward()
 
         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
@@ -193,15 +249,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
 
             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
             if forward_only:
-                self.comm.send_forward(output_obj)
+                self.send_forward(output_obj)
 
                 if not last_iteration:
-                    input_obj = self.comm.recv_forward()
+                    input_obj = self.recv_forward()
 
             else:
                 # TODO adjust here
-                self.comm.send_forward(output_obj)
-                output_obj_grad = self.comm.recv_backward()
+                self.send_forward(output_obj)
+                output_obj_grad = self.recv_backward()
 
                 # Add input_obj and output_obj to end of list.
                 input_objs.append(input_obj)
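The `# TODO adjust here` branch issues a blocking `send_forward` before `recv_backward`, which serializes the two transfers and risks a deadlock if a peer posts the mirror-image order. The usual remedy, presumably what the TODO points at, is to batch the point-to-point ops; a sketch using PyTorch's `batch_isend_irecv`, assuming an initialized process group, with ranks and tensors as placeholders:

```python
import torch
import torch.distributed as dist

def send_forward_recv_backward(output: torch.Tensor, next_rank: int) -> torch.Tensor:
    # Post the send and the receive together so neither blocks the other.
    grad = torch.empty_like(output)
    ops = [
        dist.P2POp(dist.isend, output, next_rank),
        dist.P2POp(dist.irecv, grad, next_rank),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return grad
```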
@@ -216,8 +272,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 if last_iteration:
                     input_obj = None
                 else:
-                    input_obj = self.comm.recv_forward()
-                self.comm.send_backward(input_obj_grad)
+                    input_obj = self.recv_forward()
+                self.send_backward(input_obj_grad)
 
         # Run cooldown backward passes.
         if not forward_only:
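Taken together, the last two hunks give the per-iteration order in the steady state: one forward, send activations, receive gradients, one backward, prefetch the next input, send gradients upstream. A toy trace of that ordering (pure bookkeeping, no communication):

```python
def steady_state_trace(num_remaining: int) -> list:
    # Each 1F1B iteration performs exactly one forward and one backward,
    # so activation memory stays bounded by the warm-up depth.
    events = []
    for i in range(num_remaining):
        events += ["forward", "send_forward", "recv_backward", "backward"]
        if i != num_remaining - 1:
            events.append("recv_forward")  # prefetch the next microbatch's input
        events.append("send_backward")
    return events

print(steady_state_trace(2))
```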
@@ -225,9 +281,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 input_obj = input_objs.pop(0)
                 output_obj = output_objs.pop(0)
 
-                output_obj_grad = self.comm.recv_backward()
+                output_obj_grad = self.recv_backward()
                 input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-                self.comm.send_backward(input_obj_grad)
+                self.send_backward(input_obj_grad)
 
         if outputs is not None:
             outputs = merge_batch(outputs)
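The cooldown loop drains the activations buffered during warm-up, so each stage finishes having run exactly num_microbatches forwards and backwards; `merge_batch` then reassembles the per-microbatch outputs into one batch. A hedged sketch of what such a merge typically does for plain tensor outputs (`merge_batch_sketch` is illustrative; the repo's `merge_batch` may also handle nested structures):

```python
import torch

def merge_batch_sketch(outputs: list) -> torch.Tensor:
    # Each element is one microbatch's output; concatenate along the
    # batch dimension to recover a full-batch result.
    return torch.cat(outputs, dim=0)

micro_outs = [torch.ones(2, 4), torch.zeros(2, 4)]
assert merge_batch_sketch(micro_outs).shape == (4, 4)
```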