mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-05 19:48:23 +00:00
[fix] fix traverse; traverse dict --> traverse tensor List;
This commit is contained in:
parent
fc8b016887
commit
83163fa70c
@ -480,26 +480,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
|
||||
# For chunk 0 stage 0, use micro_batch as input_obj_
|
||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
input_obj_, _ = tree_flatten({k: v for k, v in micro_batch.items() if isinstance(v, torch.Tensor)})
|
||||
output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) # y
|
||||
output_obj_grad_, _ = tree_flatten(
|
||||
{k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
|
||||
) # dy
|
||||
input_obj_, _ = tree_flatten(micro_batch)
|
||||
output_obj_, _ = tree_flatten(output_obj) # y
|
||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||
|
||||
# For loss backward; output_obj is loss; output_obj_grad should be None
|
||||
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
assert output_obj_grad is None
|
||||
input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
|
||||
input_obj_, _ = tree_flatten(input_obj)
|
||||
output_obj_.append(output_obj) # LOSS
|
||||
output_obj_grad_.append(output_obj_grad) # None
|
||||
|
||||
# For other chunk stage, use input_obj as input_obj_;
|
||||
else:
|
||||
input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
|
||||
output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) # y
|
||||
output_obj_grad_, _ = tree_flatten(
|
||||
{k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
|
||||
) # dy
|
||||
input_obj_, _ = tree_flatten(input_obj)
|
||||
output_obj_, _ = tree_flatten(output_obj) # y
|
||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||
|
||||
# filter item which is not torch.Tensor
|
||||
input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None]
|
||||
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
|
||||
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
|
||||
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj_,
|
||||
@ -551,10 +552,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
output_obj_.append(output_obj) # LOSS
|
||||
output_obj_grad_.append(None) # None
|
||||
else:
|
||||
output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) # y
|
||||
output_obj_grad_, _ = tree_flatten(
|
||||
{k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
|
||||
) # dy
|
||||
output_obj_, _ = tree_flatten(output_obj) # y
|
||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||
|
||||
# filter item which is not torch.Tensor
|
||||
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
|
||||
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
|
||||
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj_,
|
||||
|
Loading…
Reference in New Issue
Block a user