mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[fix] fix optim bwd;
This commit is contained in:
@@ -441,27 +441,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
|
||||
if model_chunk_id == 0:
|
||||
# bwd step
|
||||
optimizer.backward_b_w_by_grad(
|
||||
tensors=output_obj,
|
||||
grad_tensors=output_obj_grad,
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj,
|
||||
grad=output_obj_grad,
|
||||
inputs=input_obj,
|
||||
retain_graph=True,
|
||||
)
|
||||
else:
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# loss backward; output_obj is loss
|
||||
optimizer.backward_b_w_by_grad(
|
||||
tensors=output_obj,
|
||||
grad_tensors=None,
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj,
|
||||
grad=None,
|
||||
inputs=input_obj,
|
||||
retain_graph=True,
|
||||
)
|
||||
|
||||
else:
|
||||
# commom bwd step
|
||||
optimizer.backward_b_w_by_grad(
|
||||
tensors=output_obj,
|
||||
grad_tensors=output_obj_grad,
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj,
|
||||
grad=output_obj_grad,
|
||||
inputs=input_obj,
|
||||
retain_graph=True,
|
||||
)
|
||||
@@ -490,25 +490,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
"""
|
||||
# calculate bwd w step ; only dw = x*dy;
|
||||
if model_chunk_id == 0:
|
||||
optimizer.backward_b_w_by_grad(
|
||||
tensors=output_obj,
|
||||
grad_tensors=output_obj_grad,
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj,
|
||||
grad=output_obj_grad,
|
||||
inputs=list(model_chunk[model_chunk_id].parameters()),
|
||||
retain_graph=False,
|
||||
)
|
||||
|
||||
else:
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
optimizer.backward_b_w_by_grad(
|
||||
tensors=output_obj,
|
||||
grad_tensors=None,
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj,
|
||||
grad=None,
|
||||
inputs=list(model_chunk[model_chunk_id].parameters()),
|
||||
retain_graph=False,
|
||||
)
|
||||
else:
|
||||
optimizer.backward_b_w_by_grad(
|
||||
tensors=output_obj,
|
||||
grad_tensors=output_obj_grad,
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj,
|
||||
grad=output_obj_grad,
|
||||
inputs=list(model_chunk[model_chunk_id].parameters()),
|
||||
retain_graph=False,
|
||||
)
|
||||
|
Reference in New Issue
Block a user