[fix] fix traverse; traverse dict --> traverse tensor List;

commit 83163fa70c
parent fc8b016887

Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-06 20:10:08 +00:00
@@ -480,26 +480,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # For chunk 0 stage 0, use micro_batch as input_obj_
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            input_obj_, _ = tree_flatten({k: v for k, v in micro_batch.items() if isinstance(v, torch.Tensor)})
-            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
-            output_obj_grad_, _ = tree_flatten(
-                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
-            )  # dy
+            input_obj_, _ = tree_flatten(micro_batch)
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy

         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
-            input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
+            input_obj_, _ = tree_flatten(input_obj)
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(output_obj_grad)  # None

         # For other chunk stage, use input_obj as input_obj_;
         else:
-            input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
-            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
-            output_obj_grad_, _ = tree_flatten(
-                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
-            )  # dy
+            input_obj_, _ = tree_flatten(input_obj)
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+
+        # filter item which is not torch.Tensor
+        input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]

         optimizer.backward_by_grad(
             tensor=output_obj_,
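Why the change: the old code assumed every pipeline object was a dict and filtered tensor values with a dict comprehension before flattening, which breaks as soon as micro_batch / output_obj is a bare tensor or a nested structure. The new code flattens the whole pytree first and filters the resulting leaf list afterwards. A minimal runnable sketch of that pattern (the helper name flatten_tensor_leaves and the sample inputs are illustrative, not from the ColossalAI codebase):

import torch
from torch.utils._pytree import tree_flatten


def flatten_tensor_leaves(obj):
    # Flatten any pytree (dict, list, bare tensor, ...) into its leaves,
    # then keep only tensor leaves (and None placeholders) -- the
    # "traverse tensor List" approach this commit switches to.
    leaves, _ = tree_flatten(obj)
    return [v for v in leaves if isinstance(v, torch.Tensor) or v is None]


# A dict micro-batch with a non-tensor entry: the old dict comprehension
# called `.items()`, so anything other than a dict crashed.
micro_batch = {"input_ids": torch.ones(2, 4), "batch_size": 2}
print(flatten_tensor_leaves(micro_batch))  # [tensor of shape (2, 4)]; int leaf dropped

# A bare tensor flattens to a one-element leaf list instead of raising.
print(flatten_tensor_leaves(torch.zeros(3)))  # [tensor([0., 0., 0.])]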
@@ -551,10 +552,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
-            output_obj_grad_, _ = tree_flatten(
-                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
-            )  # dy
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+
+        # filter item which is not torch.Tensor
+        output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]

         optimizer.backward_by_grad(
             tensor=output_obj_,
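The same flatten-then-filter rewrite is applied here for the final backward. Flattening y (output_obj_) and dy (output_obj_grad_) with the same tree_flatten traversal matters because optimizer.backward_by_grad(tensor=..., grad=...) needs the two lists position-aligned. A sketch assuming backward_by_grad ultimately drives torch.autograd.backward with matching tensor/grad lists (an assumption about ColossalAI internals, not confirmed by this diff); the sample pytrees are illustrative:

import torch
from torch.utils._pytree import tree_flatten

x = torch.randn(2, 3, requires_grad=True)
output_obj = {"aux": x.sum(), "hidden": x * 2}  # y, a pytree of outputs
output_obj_grad = {"aux": torch.tensor(1.0), "hidden": torch.ones(2, 3)}  # dy, same structure

ys, _ = tree_flatten(output_obj)        # leaves in traversal order
dys, _ = tree_flatten(output_obj_grad)  # same structure => same leaf order

# ys[i] pairs with dys[i]; autograd requires this alignment, which is why
# both sides go through the same tree_flatten before filtering.
torch.autograd.backward(ys, dys)
print(x.grad)  # d(aux)/dx + d(hidden)/dx = 1 + 2 = 3 everywhere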