Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while" --> "for"; add communication dict;

@@ -46,13 +46,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.last_batch_size: Optional[int] = None
         self.microbatch_offset: List[int]

         self.collect_non_loss_data = None
         self.forward_only = None
         self.schedules = schedule
         # TODO: optim post valid
         self.do_post_validation = False
         # self.is_first_run = True
         # self.optimizer = None

         # P2PMeta cache
         # self.enable_metadata_cache = enable_metadata_cache

@@ -166,6 +162,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             model_chunk_id = self.num_model_chunks - model_chunk_id - 1
         return model_chunk_id

+    def communication_func_map(self, node_type: str):
+        return {
+            "SEND_FORWARD": self.send_forward,
+            "RECV_FORWARD": self.recv_forward,
+            "SEND_BACKWARD": self.send_backward,
+            "RECV_BACKWARD": self.recv_backward,
+        }[node_type]
+
     def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
         """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
            For ZBV.

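Below is a minimal, self-contained sketch (not taken from the patch) of the dispatch pattern this new communication dict enables: each communication node type resolves to a bound method through communication_func_map, so the run loops can replace per-type if/elif chains with a single lookup. The toy class and the print bodies are illustrative stand-ins.

    # Toy illustration of the node-type -> bound-method dispatch added above.
    # Only the shape of communication_func_map mirrors the patch; the send/recv
    # bodies here are stand-ins that just print what they would do.
    class ToyScheduler:
        def send_forward(self, chunk):
            print(f"send_forward(chunk={chunk})")

        def recv_forward(self, chunk):
            print(f"recv_forward(chunk={chunk})")

        def send_backward(self, chunk):
            print(f"send_backward(chunk={chunk})")

        def recv_backward(self, chunk):
            print(f"recv_backward(chunk={chunk})")

        def communication_func_map(self, node_type: str):
            # Build the mapping and index it, the same shape as the new method.
            return {
                "SEND_FORWARD": self.send_forward,
                "RECV_FORWARD": self.recv_forward,
                "SEND_BACKWARD": self.send_backward,
                "RECV_BACKWARD": self.recv_backward,
            }[node_type]

    if __name__ == "__main__":
        sched = ToyScheduler()
        for node_type in ("RECV_FORWARD", "F", "SEND_FORWARD"):
            if node_type in {"RECV_FORWARD", "SEND_FORWARD"}:
                communication_func = sched.communication_func_map(node_type)
                communication_func(0)  # chunk id 0
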
@@ -439,10 +443,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):

         if model_chunk_id == 0:
             # bwd step
-            # torch.autograd.backward(
-            #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
-            # )
-            optimizer.backward_b_by_grad(
+            optimizer.backward_b_w_by_grad(
                 tensors=output_obj,
                 grad_tensors=output_obj_grad,
                 inputs=input_obj,

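The call rewritten above is the B step of the zero-bubble split: backward is restricted to input_obj so only the activation gradient (dx) is produced, while the graph is kept alive for the later W step. A minimal sketch with plain torch.autograd.backward, assuming the optimizer's backward_b_w_by_grad wraps a call of this shape; the tensors below are stand-ins.

    # B step sketch: compute dx only, keep the graph for the later dw pass.
    import torch

    x = torch.randn(4, 8, requires_grad=True)     # stand-in for input_obj
    w = torch.nn.Parameter(torch.randn(8, 8))     # stand-in for a model-chunk weight
    y = x @ w                                     # stand-in for output_obj
    dy = torch.ones_like(y)                       # stand-in for output_obj_grad

    # Restricting `inputs` to x accumulates gradients only into x.grad;
    # retain_graph=True leaves the graph intact for the weight-gradient step.
    torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[x], retain_graph=True)
    assert x.grad is not None and w.grad is None
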
@@ -451,8 +452,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 # loss backward; output_obj is loss
-                # torch.autograd.backward(tensors=output_obj, grad_tensors=None, inputs=input_obj, retain_graph=True)
-                optimizer.backward_b_by_grad(
+                optimizer.backward_b_w_by_grad(
                     tensors=output_obj,
                     grad_tensors=None,
                     inputs=input_obj,

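Same idea for the first-stage case above, where output_obj is already the scalar loss, so grad_tensors is None; only dx is produced here and the weight gradient is deferred. A small sketch under the same assumption about what the optimizer helper wraps:

    # First-stage B step sketch: output_obj is the loss, so grad_tensors=None.
    import torch

    x = torch.randn(4, 8, requires_grad=True)     # stand-in for input_obj
    w = torch.nn.Parameter(torch.randn(8, 1))     # stand-in for a model-chunk weight
    loss = (x @ w).mean()                         # stand-in for output_obj when it is the loss

    # grad_tensors=None is valid because loss is a scalar; inputs=[x] again defers
    # the weight gradient, so the graph must stay retained for the W step.
    torch.autograd.backward(tensors=loss, grad_tensors=None, inputs=[x], retain_graph=True)
    assert x.grad is not None and w.grad is None
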
@@ -461,10 +461,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):

             else:
                 # commom bwd step
-                # torch.autograd.backward(
-                #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
-                # )
-                optimizer.backward_b_by_grad(
+                optimizer.backward_b_w_by_grad(
                     tensors=output_obj,
                     grad_tensors=output_obj_grad,
                     inputs=input_obj,

@@ -495,30 +492,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         # calculate bwd w step ; only dw = x*dy;
         if model_chunk_id == 0:
-            # torch.autograd.backward(
-            #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())
-            # )
-            optimizer.backward_w_by_grad(
-                tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())
+            optimizer.backward_b_w_by_grad(
+                tensors=output_obj,
+                grad_tensors=output_obj_grad,
+                inputs=list(model_chunk[model_chunk_id].parameters()),
+                retain_graph=False,
             )

         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                # torch.autograd.backward(tensors=output_obj_grad, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters()))
-                optimizer.backward_w_by_grad(
-                    tensors=output_obj, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters())
+                optimizer.backward_b_w_by_grad(
+                    tensors=output_obj,
+                    grad_tensors=None,
+                    inputs=list(model_chunk[model_chunk_id].parameters()),
+                    retain_graph=False,
                 )
             else:
-                # torch.autograd.backward(
-                #     tensors=output_obj,
-                #     grad_tensors=output_obj_grad,
-                #     inputs=list(model_chunk[model_chunk_id].parameters()),
-                # )
-
-                optimizer.backward_w_by_grad(
+                optimizer.backward_b_w_by_grad(
                     tensors=output_obj,
                     grad_tensors=output_obj_grad,
                     inputs=list(model_chunk[model_chunk_id].parameters()),
+                    retain_graph=False,
                 )

     def schedule_f(

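The W step rewritten in this hunk is the counterpart: backward is restricted to the model chunk's parameters, so only dw = x*dy is formed and the graph can finally be freed (retain_graph=False). A sketch under the same assumption about what the optimizer helper wraps; the module and tensors are stand-ins.

    # W step sketch: compute dw only, over the chunk's parameters, then free the graph.
    import torch

    x = torch.randn(4, 8, requires_grad=True)         # stand-in for input_obj
    model_chunk = torch.nn.Linear(8, 8, bias=False)   # stand-in for model_chunk[model_chunk_id]
    y = model_chunk(x)                                # stand-in for output_obj
    dy = torch.ones_like(y)                           # stand-in for output_obj_grad

    torch.autograd.backward(
        tensors=y,
        grad_tensors=dy,
        inputs=list(model_chunk.parameters()),        # parameters only => dw, no dx
        retain_graph=False,                           # last use of this graph
    )
    assert model_chunk.weight.grad is not None and x.grad is None
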
@@ -718,17 +712,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):

         outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None

-        it = 0
         # while we still have schedules_node in self.schedules
-        while it < len(self.schedules):
+        for it in range(len(self.schedules)):
             scheduled_node = self.schedules[it]

-            if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
+            if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
                 # communication
-                if scheduled_node.type == "RECV_FORWARD":
-                    self.recv_forward(scheduled_node.chunk)
-                elif scheduled_node.type == "SEND_FORWARD":
-                    self.send_forward(scheduled_node.chunk)
+                communication_func = self.communication_func_map(scheduled_node.type)
+                communication_func(scheduled_node.chunk)
             if scheduled_node.type == "F":
                 self.schedule_f(
                     scheduled_node=scheduled_node,

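A toy rendering (stand-in node objects and handlers, not the real scheduler) of the control-flow change in this hunk: the manually indexed while loop with a trailing it += 1 becomes for it in range(len(self.schedules)), and the per-type if/elif chain becomes one lookup through the communication dict.

    # Toy version of the rewritten run loop: for-range indexing plus dict dispatch.
    from collections import namedtuple

    ScheduledNode = namedtuple("ScheduledNode", ["type", "chunk"])

    schedules = [
        ScheduledNode("RECV_FORWARD", 0),
        ScheduledNode("F", 0),
        ScheduledNode("SEND_FORWARD", 0),
    ]

    handlers = {
        "RECV_FORWARD": lambda chunk: print(f"recv_forward(chunk={chunk})"),
        "SEND_FORWARD": lambda chunk: print(f"send_forward(chunk={chunk})"),
    }

    for it in range(len(schedules)):                  # was: it = 0 / while ... / it += 1
        scheduled_node = schedules[it]
        if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
            handlers[scheduled_node.type](scheduled_node.chunk)   # was an if/elif chain
        if scheduled_node.type == "F":
            print(f"schedule_f(chunk={scheduled_node.chunk})")    # stand-in for self.schedule_f(...)
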
@@ -738,7 +729,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     accum_loss=accum_loss,
                     outputs=outputs,
                 )
-            it += 1
         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)

@@ -771,9 +761,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):

         outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None

-        it = 0
         # while we still have schedules_node in self.schedules
-        while it < len(self.schedules):
+        for it in range(len(self.schedules)):
             scheduled_node = self.schedules[it]

             print(

@@ -781,14 +770,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             )
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
-                if scheduled_node.type == "RECV_FORWARD":
-                    self.recv_forward(scheduled_node.chunk)
-                elif scheduled_node.type == "RECV_BACKWARD":
-                    self.recv_backward(scheduled_node.chunk)
-                elif scheduled_node.type == "SEND_FORWARD":
-                    self.send_forward(scheduled_node.chunk)
-                elif scheduled_node.type == "SEND_BACKWARD":
-                    self.send_backward(scheduled_node.chunk)
+                communication_func = self.communication_func_map(scheduled_node.type)
+                communication_func(scheduled_node.chunk)
+
             if scheduled_node.type == "F":
                 self.schedule_f(
                     scheduled_node=scheduled_node,

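In the forward-backward loop above, the same lookup now covers all four communication node types behind the AUTO_SCHEDULE_COMMUNICATION_TYPES guard. A compact stand-alone sketch; the set contents and the lambda handlers below are illustrative assumptions, not read from the patch.

    # Sketch of the gate-then-dispatch shape used in the forward-backward loop.
    AUTO_SCHEDULE_COMMUNICATION_TYPES = {
        "RECV_FORWARD",
        "RECV_BACKWARD",
        "SEND_FORWARD",
        "SEND_BACKWARD",
    }

    def communication_func_map(node_type: str):
        return {
            "SEND_FORWARD": lambda chunk: print(f"send_forward(chunk={chunk})"),
            "RECV_FORWARD": lambda chunk: print(f"recv_forward(chunk={chunk})"),
            "SEND_BACKWARD": lambda chunk: print(f"send_backward(chunk={chunk})"),
            "RECV_BACKWARD": lambda chunk: print(f"recv_backward(chunk={chunk})"),
        }[node_type]

    for node_type, chunk in [("RECV_FORWARD", 0), ("B", 0), ("SEND_BACKWARD", 1)]:
        if node_type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
            communication_func = communication_func_map(node_type)
            communication_func(chunk)
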
@@ -812,7 +796,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-            it += 1

         # return loss & output
         if outputs is not None: