From 26783776f166d6b59611980d5760f68c2054d851 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Fri, 20 Sep 2024 06:41:19 +0000
Subject: [PATCH] [fix] fix input_tensors buffer: append (micro_batch,
 input_obj) tuple instead of input_obj (dict), and update the related bwd b
 calculation logic

---
 .../pipeline/schedule/zero_bubble_pp.py       | 61 +++++++++++----------
 1 file changed, 33 insertions(+), 28 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 09ea4000c..d6aee7c1e 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -458,6 +458,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
+        micro_batch: Optional[dict],
         input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
@@ -468,7 +469,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             model_chunk (ModuleList or Module): Model Chunk to be run;
             model_chunk_id (int): The current model chunk idx;
             optimizer (OptimizerWrapper): Optimizer to update the model
-            input_obj (Optional[dict]): x.
+            micro_batch (Optional[dict]): micro batch; used as x for model chunk 0 on the first stage.
+            input_obj (Optional[dict]): x.
             output_obj (Union[dict, torch.Tensor]): y.
             output_obj_grad (dict): dy.
 
@@ -477,10 +479,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         # calculate bwd b step ; only dx = w*dy;
 
-        # Retain the grad on the input_obj.
-        if input_obj is None:
-            return None
-        else:
+        # Retain the grad on the input_obj; no need to retain grad on the micro_batch.
+        if input_obj is not None:
             tree_map(retain_grad, input_obj)
 
         # x, y, dy list for backward_by_grad; Type: list[tensor];
@@ -488,22 +488,28 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_obj_ = []
         output_obj_grad_ = []
 
-        # get x from input_obj to input_obj_
-        for k, v in input_obj.items():
-            if v.requires_grad:
-                input_obj_.append(input_obj[k])
-
-        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # loss backward; output_obj is loss; so output_obj_grad should be None
+        # For model chunk 0 on the first stage, use micro_batch as input_obj_
+        if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            for k, v in micro_batch.items():
+                if v.requires_grad:
+                    input_obj_.append(micro_batch[k])
+                    output_obj_.append(output_obj[k])  # y
+                    output_obj_grad_.append(output_obj_grad[k])  # dy
+        # For the loss backward: output_obj is the loss, so output_obj_grad should be None
+        elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
-            output_obj_grad_.append(output_obj_grad)  # None
+            for k, v in input_obj.items():
+                if v.requires_grad:
+                    input_obj_.append(input_obj[k])
             output_obj_.append(output_obj)  # LOSS
-
+            output_obj_grad_.append(output_obj_grad)  # None
+        # For all other chunks and stages, use input_obj as input_obj_
         else:
             for k, v in input_obj.items():
                 if v.requires_grad:
-                    output_obj_.append(output_obj[k])
-                    output_obj_grad_.append(output_obj_grad[k])
+                    input_obj_.append(input_obj[k])
+                    output_obj_.append(output_obj[k])  # y
+                    output_obj_grad_.append(output_obj_grad[k])  # dy
 
         optimizer.backward_by_grad(
             tensor=output_obj_,
@@ -512,9 +518,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             retain_graph=True,
         )
 
-        # format output_obj_grad
-        if input_obj is not None:
-            input_obj_grad = {}
+        # Format input_obj_grad (dx)
+        input_obj_grad = {}
+        if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            for k, v in micro_batch.items():
+                if isinstance(v, torch.Tensor) and v.grad is not None:
+                    input_obj_grad[k] = v.grad
+        else:
             for k, v in input_obj.items():
                 if isinstance(v, torch.Tensor) and v.grad is not None:
                     input_obj_grad[k] = v.grad
@@ -551,10 +561,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            # for k, v in input_obj.items():
-            #     if v.requires_grad:
-            #         output_obj_.append(output_obj[k])
-            #         output_obj_grad_.append(output_obj_grad[k])
             for k, v in output_obj.items():
                 if v.requires_grad:
                     output_obj_.append(output_obj[k])
@@ -634,10 +640,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             tree_map(deallocate, deallocate_output_obj)
 
         # add input and output object for backward b
-        if input_obj is not None:
-            self.input_tensors[model_chunk_id].append(input_obj)
-        else:
-            self.input_tensors[model_chunk_id].append(micro_batch)
+
+        self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
 
         # for bwd b&w, we only need the graph(grad_fn) of output_obj
         # Do not deallocate loss, deallocate other output_obj;
@@ -703,7 +707,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
 
         # get input and output object from buffer;
-        input_obj = self.input_tensors[model_chunk_id].pop(0)
+        micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)
 
         # save output_tensor_grad for dw
@@ -719,6 +723,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
             optimizer=optimizer,
+            micro_batch=micro_batch,
             input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_tensor_grad,
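
Note (editorial, not part of the patch): below is a minimal runnable sketch of
the buffer discipline this change adopts. Before the fix, forward appended
either input_obj or micro_batch to self.input_tensors, and the popping code
could not tell which one it had received; the fix always appends the
(micro_batch, input_obj) pair and lets backward b choose the gradient source by
stage. The names here (input_tensors, run_forward, run_backward_b,
is_first_stage) are hypothetical stand-ins for the scheduler's internals, and
plain torch.autograd.backward stands in for OptimizerWrapper.backward_by_grad.

    import torch

    # One FIFO buffer per model chunk: forward appends, backward b pops in order.
    input_tensors = {0: [], 1: []}

    def run_forward(model_chunk_id, micro_batch, input_obj, output_obj):
        # Always store the pair, even when input_obj is None (first stage, chunk 0),
        # so backward b can unpack both candidates unambiguously.
        input_tensors[model_chunk_id].append((micro_batch, input_obj))
        return output_obj

    def run_backward_b(model_chunk_id, is_first_stage, output_obj, output_obj_grad):
        # Pop the oldest pair (FIFO, matching micro batch order).
        micro_batch, input_obj = input_tensors[model_chunk_id].pop(0)
        # dx is read from micro_batch on chunk 0 of the first stage, else from input_obj.
        src = micro_batch if (model_chunk_id == 0 and is_first_stage) else input_obj
        torch.autograd.backward(output_obj, grad_tensors=output_obj_grad, retain_graph=True)
        return {k: v.grad for k, v in src.items()
                if isinstance(v, torch.Tensor) and v.grad is not None}

    # Tiny demo: first stage, chunk 0 -- dx comes from the micro batch itself.
    x = torch.randn(2, 4, requires_grad=True)
    w = torch.randn(4, 4, requires_grad=True)
    y = x @ w
    run_forward(0, micro_batch={"x": x}, input_obj=None, output_obj={"y": y})
    dx = run_backward_b(0, True, output_obj=y, output_obj_grad=torch.ones_like(y))
    print(sorted(dx.keys()))  # ['x']

Storing the tuple rather than one dict or the other keeps the pop site simple:
the unpacking line documents exactly what is in the buffer, and the stage check
is made once, next to the backward call that consumes it.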