Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 02:51:59 +00:00
[feat] add optim backward_b_by_grad
@@ -413,7 +413,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
-        # optimizer: OptimizerWrapper,
+        optimizer: OptimizerWrapper,
         input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
@@ -447,7 +447,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
         else:
             # commom bwd step
-            # BUG:output_obj_grad is None
             torch.autograd.backward(
                 tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
             )
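The retained-graph backward call above is what splits backward into an input-gradient (B) pass, leaving weight gradients for a later W pass. Below is a minimal standalone sketch of that split, not ColossalAI code: it assumes a toy `torch.nn.Linear` layer in place of the real pipeline objects (`output_obj`, `input_obj`, `output_obj_grad`).

```python
import torch

# Toy stand-ins for the pipeline objects (hypothetical, for illustration only).
layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)   # plays the role of input_obj
out = layer(x)                              # plays the role of output_obj
grad_out = torch.ones_like(out)             # plays the role of output_obj_grad

# B step: gradients w.r.t. the input only; retain_graph=True keeps the graph
# alive so the weight-gradient (W) pass can run later.
torch.autograd.backward(tensors=out, grad_tensors=grad_out, inputs=[x], retain_graph=True)
assert x.grad is not None
assert layer.weight.grad is None            # weight grads are deferred to the W step

# W step: gradients w.r.t. the parameters, reusing the retained graph.
torch.autograd.backward(tensors=out, grad_tensors=grad_out, inputs=list(layer.parameters()))
assert layer.weight.grad is not None and layer.bias.grad is not None
```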
@@ -564,7 +563,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         scheduled_node,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
-        # optimizer: OptimizerWrapper,
+        optimizer: OptimizerWrapper,
         # input_obj: Optional[dict],
         # output_obj: Union[dict, torch.Tensor],
         # output_obj_grad: Optional[dict],
@@ -614,7 +613,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         input_object_grad = self.backward_b_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
-            # optimizer: OptimizerWrapper,
+            optimizer=optimizer,
             input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_tensor_grad,
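The `input_object_grad` returned by `backward_b_step` is what gets handed to the previous pipeline stage as its `output_obj_grad`. A toy two-stage sketch of that hand-off, again not ColossalAI code and assuming plain `Linear` stages in place of the real model chunks:

```python
import torch

stage0 = torch.nn.Linear(4, 4)
stage1 = torch.nn.Linear(4, 4)

x = torch.randn(2, 4)
h = stage0(x)                                  # stage 0 forward output
h_recv = h.detach().requires_grad_(True)       # pipeline boundary: stage 1's input_obj
y = stage1(h_recv)                             # stage 1 forward output

# Stage 1 B step: gradient w.r.t. its received input, given the incoming output grad.
# retain_graph=True mirrors the diff, so a later W step could reuse the graph.
torch.autograd.backward(tensors=y, grad_tensors=torch.ones_like(y), inputs=[h_recv], retain_graph=True)
input_object_grad = h_recv.grad                # sent upstream to stage 0

# Stage 0 consumes that gradient as its output_obj_grad; being the first stage here,
# it simply backpropagates into its own parameters.
torch.autograd.backward(tensors=h, grad_tensors=input_object_grad, inputs=list(stage0.parameters()))
assert stage0.weight.grad is not None
```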
@@ -715,6 +714,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
                     model_chunk_id=scheduled_node.chunk,
+                    optimizer=optimizer,
                 )
             elif scheduled_node.type == "W":
                 self.schedule_w(