[fix] fix optim bwd;

This commit is contained in:
duanjunwen
2024-09-02 11:19:42 +00:00
parent 77fe44286c
commit 591a13bf7e
3 changed files with 81 additions and 65 deletions

View File

@@ -441,27 +441,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if model_chunk_id == 0:
# bwd step
optimizer.backward_b_w_by_grad(
tensors=output_obj,
grad_tensors=output_obj_grad,
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=input_obj,
retain_graph=True,
)
else:
if self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss
optimizer.backward_b_w_by_grad(
tensors=output_obj,
grad_tensors=None,
optimizer.backward_by_grad(
tensor=output_obj,
grad=None,
inputs=input_obj,
retain_graph=True,
)
else:
# common bwd step
optimizer.backward_b_w_by_grad(
tensors=output_obj,
grad_tensors=output_obj_grad,
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=input_obj,
retain_graph=True,
)
@@ -490,25 +490,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
# calculate the bwd w step: only the weight gradient, dw = x * dy
if model_chunk_id == 0:
optimizer.backward_b_w_by_grad(
tensors=output_obj,
grad_tensors=output_obj_grad,
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)
else:
if self.stage_manager.is_first_stage(ignore_chunk=True):
optimizer.backward_b_w_by_grad(
tensors=output_obj,
grad_tensors=None,
optimizer.backward_by_grad(
tensor=output_obj,
grad=None,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)
else:
optimizer.backward_b_w_by_grad(
tensors=output_obj,
grad_tensors=output_obj_grad,
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)