[feat] add optim backward_b_by_grad

Author: duanjunwen
Date: 2024-08-29 03:16:59 +00:00
Parent: b1419ef76a
Commit: 4c4b01b859
3 changed files with 178 additions and 6 deletions


@@ -413,7 +413,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
-# optimizer: OptimizerWrapper,
+optimizer: OptimizerWrapper,
input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
@@ -447,7 +447,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
else:
# common bwd step
-# BUG:output_obj_grad is None
torch.autograd.backward(
tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
)
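A minimal, self-contained sketch of what this backward call does (the toy layer and tensor names below are illustrative stand-ins, not the scheduler's objects): with inputs= restricted to the activation, autograd accumulates only the input gradient, leaving the parameters' .grad untouched so the weight-gradient work can be deferred to a later W step.

import torch

# Toy stand-ins for input_obj / output_obj / output_obj_grad (assumption: single tensors).
layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
out = layer(x)
grad_out = torch.ones_like(out)

# B step: gradients are accumulated only into x; layer.weight.grad stays None.
torch.autograd.backward(tensors=out, grad_tensors=grad_out, inputs=x, retain_graph=True)
assert x.grad is not None and layer.weight.grad is None

# W step: reuse the retained graph to accumulate the parameter gradients afterwards.
torch.autograd.backward(tensors=out, grad_tensors=grad_out, inputs=list(layer.parameters()))
assert layer.weight.grad is not None

retain_graph=True in the B step is what makes the deferred W step possible, since the graph would otherwise be freed after the first backward.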
@@ -564,7 +563,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
scheduled_node,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
-# optimizer: OptimizerWrapper,
+optimizer: OptimizerWrapper,
# input_obj: Optional[dict],
# output_obj: Union[dict, torch.Tensor],
# output_obj_grad: Optional[dict],
@@ -614,7 +613,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
input_object_grad = self.backward_b_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
-# optimizer: OptimizerWrapper,
+optimizer=optimizer,
input_obj=input_obj,
output_obj=output_obj,
output_obj_grad=output_tensor_grad,
@@ -715,6 +714,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
+optimizer=optimizer,
)
elif scheduled_node.type == "W":
self.schedule_w(
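The remaining hunks thread the OptimizerWrapper from the runtime loop through schedule_b into backward_b_step. A plausible reading, given the commit title, is that the restricted backward will eventually be issued through an optimizer-level helper (for example to respect mixed-precision loss scaling) instead of calling torch.autograd.backward directly. The sketch below is a toy illustration under that assumption; ToyOptimizerWrapper and every other name in it are hypothetical stand-ins, not ColossalAI's API.

import torch


class ToyOptimizerWrapper:
    """Hypothetical stand-in for OptimizerWrapper; only the piece used below is sketched."""

    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim

    def backward_b_by_grad(self, tensors, grad_tensors, inputs, retain_graph=True):
        # Assumption: a thin pass-through to autograd, restricted to `inputs`.
        torch.autograd.backward(
            tensors=tensors, grad_tensors=grad_tensors, inputs=inputs, retain_graph=retain_graph
        )


def backward_b_step(output_obj, output_obj_grad, input_obj, optimizer: ToyOptimizerWrapper):
    # The B step asks the optimizer wrapper to run the restricted backward,
    # then hands the resulting input gradient back to the scheduler.
    optimizer.backward_b_by_grad(
        tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
    )
    return input_obj.grad


layer = torch.nn.Linear(4, 4)
wrapper = ToyOptimizerWrapper(torch.optim.SGD(layer.parameters(), lr=0.1))
x = torch.randn(2, 4, requires_grad=True)
out = layer(x)
input_grad = backward_b_step(out, torch.ones_like(out), x, wrapper)
assert input_grad is not None and layer.weight.grad is None  # weight grads left for the W step

Passing the optimizer explicitly, rather than capturing it elsewhere, keeps backward_b_step usable for each model chunk the scheduler drives.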