mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-27 07:47:05 +00:00
[fix] assert output_obj_grad is None at the backward B step; replace input_obj.requires_grad_() with tree_map
This commit is contained in:
parent
7568b34626
commit
ce58d8e8bf
@ -475,8 +475,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
tree_map(retain_grad, input_obj)
|
tree_map(retain_grad, input_obj)
|
||||||
|
|
||||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
# loss backward; output_obj is loss
|
# loss backward; output_obj is loss; so output_obj_grad should be None
|
||||||
output_obj_grad = None
|
assert output_obj_grad is None
|
||||||
|
|
||||||
optimizer.backward_by_grad(
|
optimizer.backward_by_grad(
|
||||||
tensor=output_obj,
|
tensor=output_obj,
|
||||||
grad=output_obj_grad,
|
grad=output_obj_grad,
|
||||||
@ -554,7 +555,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# not last stage; recv from next
|
# not last stage; recv from next
|
||||||
else:
|
else:
|
||||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||||
input_obj.requires_grad_()
|
|
||||||
|
# Enable gradient tracking on every tensor in input_obj (it may be a nested structure, hence tree_map)
|
||||||
|
tree_map(torch.Tensor.requires_grad_, input_obj)
|
||||||
|
|
||||||
# Step2: fwd step
|
# Step2: fwd step
|
||||||
output_obj = self.forward_step(
|
output_obj = self.forward_step(
|
||||||
|
Loading…
Reference in New Issue
Block a user