diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 307d1035c..0272cc113 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -480,26 +480,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # For chunk 0 stage 0, use micro_batch as input_obj_ if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): - input_obj_, _ = tree_flatten({k: v for k, v in micro_batch.items() if isinstance(v, torch.Tensor)}) - output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) # y - output_obj_grad_, _ = tree_flatten( - {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)} - ) # dy + input_obj_, _ = tree_flatten(micro_batch) + output_obj_, _ = tree_flatten(output_obj) # y + output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy # For loss backward; output_obj is loss; output_obj_grad should be None elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): assert output_obj_grad is None - input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)}) + input_obj_, _ = tree_flatten(input_obj) output_obj_.append(output_obj) # LOSS output_obj_grad_.append(output_obj_grad) # None # For other chunk stage, use input_obj as input_obj_; else: - input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)}) - output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) # y - output_obj_grad_, _ = tree_flatten( - {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)} - ) # dy + input_obj_, _ = tree_flatten(input_obj) + output_obj_, _ = tree_flatten(output_obj) # y + output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy + + # filter item which is not torch.Tensor + input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None] optimizer.backward_by_grad( tensor=output_obj_, @@ -551,10 +552,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_obj_.append(output_obj) # LOSS output_obj_grad_.append(None) # None else: - output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) # y - output_obj_grad_, _ = tree_flatten( - {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)} - ) # dy + output_obj_, _ = tree_flatten(output_obj) # y + output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy + + # filter item which is not torch.Tensor + output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None] optimizer.backward_by_grad( tensor=output_obj_,