mirror of https://github.com/hpcaitech/ColossalAI.git

[feat] add comments for ZBV func;

commit 1b4bb2beeb
parent f1c1a87246
@@ -40,9 +40,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.num_microbatch = num_microbatch
         self.collect_non_loss_data = None
         self.forward_only = None

         self.schedules = schedule
-        self.it = 0  # curr iteration
+        # TODO: optim post valid
         self.do_post_validation = False
         self.is_first_run = True
         self.optimizer = None
@@ -69,16 +68,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.input_tensors = [[], []]
         self.output_tensors = [[], []]

-        # y & dy buffer for schedule b
+        # y & dy buffer for schedule w
         self.output_tensors_dw = [[], []]
         self.output_tensors_grad_dw = [[], []]

+        # buffer for communication
         self.send_forward_buffer = [[], []]
         self.recv_forward_buffer = [[], []]
         self.send_backward_buffer = [[], []]
         self.recv_backward_buffer = [[], []]
         self.forward_data_store = []

+        # y buffer for local send fwd
         self.local_send_forward_buffer = []

+        # dy buffer for local send bwd
         self.local_send_backward_buffer = []

     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
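A note on the `[[], []]` initializers above: the ZBV schedule places two model chunks on every rank, so each buffer keeps one FIFO queue per chunk, indexed by `model_chunk_id`. A minimal sketch of that convention (buffer name from the diff, contents hypothetical stand-ins for tensors):

```python
# One queue per model chunk; entries are consumed FIFO with pop(0).
recv_forward_buffer = [[], []]  # index 0 -> chunk 0, index 1 -> chunk 1

recv_forward_buffer[0].append("x_mb0_for_chunk0")  # stand-in for a tensor
recv_forward_buffer[1].append("x_mb0_for_chunk1")

x = recv_forward_buffer[0].pop(0)  # oldest input queued for chunk 0
print(x)  # x_mb0_for_chunk0
```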
@@ -263,7 +265,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):

         Args:
             model_chunk_id (int): The current model chunk idx.
-            output_object (Any): Object to be sent.
             next_rank (int, optional): The rank of the recipient of the tensor.

         Returns:
@@ -313,7 +314,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):

         Args:
             model_chunk_id (int): The current model chunk idx.
-            input_object (Any): Object to be sent.
             prev_rank (int, optional): The rank of the recipient of the tensor

         Returns:
@@ -371,9 +371,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
     ) -> Union[torch.Tensor, dict]:
         """Forward one step of the pipeline
         Args:
-            model (ModuleList or Module): Model Chunk to be run
-            input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
-            criterion (Callable): Criterion to calculate loss.
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+            input_obj (Optional[dict]): x;
+            criterion (Callable): loss function;
             accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
             outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.

@@ -410,16 +411,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ) -> Optional[dict]:
-        """Backward one step of the pipeline
+        """Backward dx step of the pipeline; we calculate "dx = w*dy" here;

         Args:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
             optimizer (OptimizerWrapper): Optimizer to update the model
-            input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
-            output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
-            output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
+            input_obj (Optional[dict]): x.
+            output_obj (Union[dict, torch.Tensor]): y.
+            output_obj_grad (dict): dy.

         Returns:
-            Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
+            Optional[dict]: dx.
         """
         # calculate bwd b step ; only dx = w*dy;

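The renamed docstring marks the core zero-bubble idea: backward is split into a B step that produces dx (needed immediately by the upstream stage) and a W step that produces dw (which can be deferred). A self-contained sketch of that split on a single linear weight, using plain `torch.autograd` instead of the scheduler's buffers:

```python
import torch

w = torch.randn(4, 4, requires_grad=True)  # stand-in for the chunk's weight
x = torch.randn(2, 4, requires_grad=True)
y = x @ w.T                                # forward: y
dy = torch.randn_like(y)                   # upstream gradient: dy

# B step: only dx = dy @ w; retain_graph keeps the graph for the W step.
(dx,) = torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)

# W step: only dw = dy^T @ x; accumulated into w.grad, graph freed here.
torch.autograd.backward(y, grad_tensors=dy, inputs=[w])

assert torch.allclose(dx, dy @ w)
assert torch.allclose(w.grad, dy.T @ x)
```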
@@ -451,10 +454,21 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         # optimizer: OptimizerWrapper,
-        # input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ):
+        """Backward dw step of the pipeline; we calculate "dw = x*dy" here;
+
+        Args:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+            optimizer (OptimizerWrapper): Optimizer to update the model
+            output_obj (Union[dict, torch.Tensor]): y.
+            output_obj_grad (dict): dy.
+
+        Returns:
+            Nothing need to return; we only calculate dw then update w;
+        """
         # calculate bwd w step ; only dw = x*dy;
         if model_chunk_id == 0:
             torch.autograd.backward(
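For a whole model chunk, the same dw-only pass is expressible through the `inputs=` argument of `torch.autograd.backward`, which restricts gradient accumulation to the listed tensors. A hedged sketch with a hypothetical one-layer chunk (the real method pulls y and dy out of its buffers instead):

```python
import torch
import torch.nn as nn

chunk = nn.Linear(4, 4)            # hypothetical stand-in for one model chunk
x = torch.randn(2, 4)
y = chunk(x)                       # y (the scheduler stashed this earlier)
dy = torch.randn_like(y)           # dy (popped from the grad buffer)

# dw = x*dy only: gradients land in chunk.weight.grad / chunk.bias.grad,
# after which the optimizer can update w.
torch.autograd.backward(y, grad_tensors=dy, inputs=list(chunk.parameters()))
print(chunk.weight.grad.shape)     # torch.Size([4, 4])
```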
@@ -481,6 +495,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         accum_loss: Optional[torch.Tensor] = None,
         outputs: Optional[List[Any]] = None,
     ):
+        """A complete forward schedule; Include recv fwd --> cal fwd --> send fwd;
+
+        Args:
+            scheduled_node:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+            input_obj (Optional[dict]): x;
+            criterion (Callable): loss function;
+            accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
+            outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
+
+        Returns:
+            Nothing.
+        """
         # Step1: recv fwd
         if model_chunk_id == 0:
             # is first stage; get input from func param
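The recv fwd --> cal fwd --> send fwd shape documented here reads cleanly without the p2p machinery. Below is a deliberately simplified, single-rank schematic; the deques stand in for communication and only the control flow mirrors the method:

```python
from collections import deque

recv_forward_buffer = [deque(), deque()]   # filled by p2p recv in reality
send_forward_buffer = [deque(), deque()]   # drained by p2p send in reality
local_send_forward_buffer = deque()        # chunk 0 -> chunk 1 on the same rank

def schedule_f_sketch(model_chunk_id, input_obj, forward):
    # Step1: recv fwd (chunk 0 on the first stage reads the function argument)
    if model_chunk_id == 0 and input_obj is not None:
        x = input_obj
    else:
        x = recv_forward_buffer[model_chunk_id].popleft()
    # Step2: cal fwd
    y = forward(x)
    # Step3: send fwd (single-rank view: chunk 0 output is queued locally)
    if model_chunk_id == 0:
        local_send_forward_buffer.append(y)
    else:
        send_forward_buffer[model_chunk_id].append(y)
    return y

print(schedule_f_sketch(0, 1.0, lambda v: v * 2))  # 2.0, queued locally
```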
@@ -541,6 +569,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # output_obj: Union[dict, torch.Tensor],
         # output_obj_grad: Optional[dict],
     ):
+        """A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
+
+        Args:
+            scheduled_node:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+
+        Returns:
+            Nothing.
+        """
+
         # Step1: recv bwd
         if model_chunk_id == 0:
             # chunk0 is last stage; recv output_grad from local_send_backward_buffer
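One thing the backward b docstring leaves implicit is why `output_tensors_dw` and `output_tensors_grad_dw` exist at all: the B step consumes dy to produce dx right away, while the same (y, dy) pair is queued so the later W step can still compute dw. A sketch of that hand-off, under the same hypothetical single-layer setup as above:

```python
import torch
import torch.nn as nn

output_tensors_dw = [[], []]       # y queued per chunk for the W step
output_tensors_grad_dw = [[], []]  # dy queued per chunk for the W step

chunk_id = 0
chunk = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
y = chunk(x)
dy = torch.randn_like(y)

# B step: dx only; retain_graph=True keeps the graph alive for the W step.
(dx,) = torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)

# stash y & dy so the w schedule can compute dw = x*dy later
output_tensors_dw[chunk_id].append(y)
output_tensors_grad_dw[chunk_id].append(dy)
print(dx.shape)  # torch.Size([2, 4])
```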
@@ -606,6 +644,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk_id: int,
         # optimizer: OptimizerWrapper,
     ):
+        """A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w);
+
+        Args:
+            scheduled_node:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+
+        Returns:
+            Nothing.
+        """

         # get y & dy from buffer
         output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
@@ -629,7 +676,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         return_loss: bool = False,
         return_outputs: bool = False,
     ):
-        it = self.it
+        """
+        Runs Zerobubble schedule, with communication between pipeline stages.
+        """
+        it = 0
         # while we still have schedules_node in self.schedules
         # print(f"manger_stage {self.stage_manager.stage} schedule {self.schedules} \n")
         while it < len(self.schedules):
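The loop gains a docstring here, and `self.it` from the first hunk becomes a local counter. Its overall shape is a dispatcher over the precomputed node list; a hypothetical, runnable skeleton (node and handler names are illustrative, not the ZBV planner's real classes):

```python
from dataclasses import dataclass

@dataclass
class Node:
    type: str        # "F" = forward, "B" = backward dx, "W" = backward dw
    chunk: int       # model chunk idx (0 or 1 in the V schedule)
    microbatch: int

def run_schedule(schedules, handlers):
    it = 0  # local iteration counter, as in the diff above
    # while we still have schedule nodes, dispatch each to its handler
    while it < len(schedules):
        node = schedules[it]
        handlers[node.type](node.chunk, node.microbatch)
        it += 1

def handle(kind):
    def _h(chunk, mb):
        print(f"{kind} chunk{chunk} mb{mb}")
    return _h

run_schedule(
    [Node("F", 0, 0), Node("F", 1, 0), Node("B", 1, 0), Node("W", 1, 0)],
    {"F": handle("F"), "B": handle("B"), "W": handle("W")},
)
```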