[fix] fix traversal; traverse dict --> traverse tensor list

commit 83163fa70c
parent fc8b016887
Author: duanjunwen
Date:   2024-09-25 06:38:11 +00:00


@@ -480,26 +480,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # For chunk 0 stage 0, use micro_batch as input_obj_
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            input_obj_, _ = tree_flatten({k: v for k, v in micro_batch.items() if isinstance(v, torch.Tensor)})
-            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
-            output_obj_grad_, _ = tree_flatten(
-                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
-            )  # dy
+            input_obj_, _ = tree_flatten(micro_batch)
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
-            input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
+            input_obj_, _ = tree_flatten(input_obj)
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(output_obj_grad)  # None
         # For other chunk stage, use input_obj as input_obj_;
         else:
-            input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
-            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
-            output_obj_grad_, _ = tree_flatten(
-                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
-            )  # dy
+            input_obj_, _ = tree_flatten(input_obj)
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+        # filter item which is not torch.Tensor
+        input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
         optimizer.backward_by_grad(
             tensor=output_obj_,
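
The point of the change: micro_batch is a dict of tensors only on the first stage of chunk 0, while the commit message indicates that the objects passed between other stages are plain tensor lists, so calling .items() on them breaks. Flattening each object as-is with tree_flatten (assumed here to be torch.utils._pytree.tree_flatten, as in the scheduler's imports) and filtering non-tensor leaves afterwards covers both shapes. A minimal sketch under that assumption, with made-up example inputs:

import torch
from torch.utils._pytree import tree_flatten

# hypothetical stand-ins for the scheduler's inputs
micro_batch = {"input_ids": torch.randint(0, 100, (2, 8)), "labels": torch.randint(0, 100, (2, 8))}
hidden_states = [torch.randn(2, 8, 16), torch.randn(2, 8, 16)]   # tensor list between stages

for obj in (micro_batch, hidden_states):
    leaves, _spec = tree_flatten(obj)   # flattens dicts, lists, tuples alike
    # same post-filter as the diff: keep tensors (and None placeholders, e.g. missing grads)
    leaves = [v for v in leaves if isinstance(v, torch.Tensor) or v is None]
    print(len(leaves))                  # 2 leaves in both cases

The earlier dict-comprehension form assumed every object had .items(), which is exactly what the tensor-list case violated; moving the tensor check after flattening keeps one code path for both.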
@@ -551,10 +552,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
-            output_obj_grad_, _ = tree_flatten(
-                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
-            )  # dy
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+        # filter item which is not torch.Tensor
+        output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
         optimizer.backward_by_grad(
             tensor=output_obj_,
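
For context on where the filtered lists go: optimizer.backward_by_grad(tensor=output_obj_, ...) is ColossalAI's optimizer wrapper, and its full signature is not shown in this hunk. The snippet below is therefore only a rough analogue of backing through one matched (output_obj_, output_obj_grad_) pair with plain torch.autograd; the tensors and shapes are invented for illustration.

import torch

x = torch.randn(4, 8, requires_grad=True)   # activation received from the previous stage
w = torch.randn(8, 8, requires_grad=True)   # a local parameter
y = x @ w                                   # plays the role of an output_obj_ entry
dy = torch.ones_like(y)                     # plays the role of the matching output_obj_grad_ entry

# backward over matched (tensor, grad) lists; per the torch docs, a None grad entry is
# only allowed for scalar tensors (e.g. the loss), which is why the loss branch above
# appends None instead of a gradient
torch.autograd.backward([y], grad_tensors=[dy])
print(x.grad.shape)   # torch.Size([4, 8]) -- the dx that would flow to the previous stage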