diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py
index 6f605d22c..94f8b90c1 100644
--- a/colossalai/interface/optimizer.py
+++ b/colossalai/interface/optimizer.py
@@ -58,14 +58,17 @@ class OptimizerWrapper:
     def backward_by_grad(self, tensor: Tensor, grad: Tensor):
         torch.autograd.backward(tensor, grad)
 
-    def backward_b_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
+    def backward_b_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
         """
-        Performs a backward pass for dx, we only calculate dx = w*dy here
+        Performs a backward pass for dx or dw:
+        for dx, we only calculate dx = w*dy here;
+        for dw, we only calculate dw = x*dy here.
 
         Args:
             tensor (Tensor): y or loss of current chunk;
             grad_tensors (Tensor): dy of current chunk;
-            input_obj (Tensor): x of current chunk;
+            input_obj (Tensor): for dx, input_obj is x of current chunk;
+                                for dw, input_obj is w of current chunk;
             retain_graph (bool): default to be True, we retain graph in backward_b
         """
         torch.autograd.backward(
@@ -75,23 +78,6 @@ class OptimizerWrapper:
             retain_graph=retain_graph,
         )
 
-    def backward_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = False):
-        """
-        Performs a backward pass for dw, we only calculate dw = x*dy here
-
-        Args:
-            tensor (Tensor): y or loss of current chunk;
-            grad_tensors (Tensor): dy of current chunk;
-            input_obj (Tensor): w;
-            retain_graph (bool): default to be False, we release graph in backward_w
-        """
-        torch.autograd.backward(
-            tensors=tensors,
-            grad_tensors=grad_tensors,
-            inputs=inputs,
-            retain_graph=retain_graph,
-        )
-
     def state_dict(self):
         """
         Returns the optimizer state.
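A minimal, self-contained sketch of why the two helpers can be merged (illustration only, not part of the patch; the local backward_b_w_by_grad below simply mirrors the wrapper's call): the B pass (dx) and the W pass (dw) issue the same torch.autograd.backward call and differ only in inputs and retain_graph.

    import torch
    import torch.nn as nn

    def backward_b_w_by_grad(tensors, grad_tensors, inputs, retain_graph=True):
        # same call the OptimizerWrapper forwards to
        torch.autograd.backward(
            tensors=tensors, grad_tensors=grad_tensors, inputs=inputs, retain_graph=retain_graph
        )

    layer = nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)
    y = layer(x)
    dy = torch.ones_like(y)

    # B pass (dx = w*dy): accumulate grad only into x, keep the graph for the later W pass
    backward_b_w_by_grad(y, dy, inputs=[x], retain_graph=True)
    assert x.grad is not None and layer.weight.grad is None

    # W pass (dw = x*dy): accumulate grad into the parameters, release the graph
    backward_b_w_by_grad(y, dy, inputs=list(layer.parameters()), retain_graph=False)
    assert layer.weight.grad is not None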
diff --git a/colossalai/pipeline/schedule/v_schedule.py b/colossalai/pipeline/schedule/v_schedule.py
index b5c255e50..9eebebdea 100644
--- a/colossalai/pipeline/schedule/v_schedule.py
+++ b/colossalai/pipeline/schedule/v_schedule.py
@@ -1,6 +1,32 @@
 # Refer from Zero Bubble Pipeline Parallelism.
 # Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism
 # Paper: https://arxiv.org/abs/2401.10241
+# The following applies to all files unless otherwise noted:
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#    * Redistributions of source code must retain the above copyright
+#      notice, this list of conditions and the following disclaimer.
+#    * Redistributions in binary form must reproduce the above copyright
+#      notice, this list of conditions and the following disclaimer in the
+#      documentation and/or other materials provided with the distribution.
+#    * Neither the name of NVIDIA CORPORATION nor the names of its
+#      contributors may be used to endorse or promote products derived
+#      from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from collections import deque
 from dataclasses import dataclass
diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index ee6ad3227..ef3977691 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -46,13 +46,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.last_batch_size: Optional[int] = None
         self.microbatch_offset: List[int]
 
-        self.collect_non_loss_data = None
-        self.forward_only = None
         self.schedules = schedule
         # TODO: optim post valid
         self.do_post_validation = False
-        # self.is_first_run = True
-        # self.optimizer = None
 
         # P2PMeta cache
         # self.enable_metadata_cache = enable_metadata_cache
@@ -166,6 +162,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             model_chunk_id = self.num_model_chunks - model_chunk_id - 1
         return model_chunk_id
 
+    def communication_func_map(self, node_type: str):
+        return {
+            "SEND_FORWARD": self.send_forward,
+            "RECV_FORWARD": self.recv_forward,
+            "SEND_BACKWARD": self.send_backward,
+            "RECV_BACKWARD": self.recv_backward,
+        }[node_type]
+
     def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
         """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. For ZBV.
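The new communication_func_map helper is a plain dictionary dispatch over bound methods; the sketch below (illustration only, not part of the patch, with stub send/recv methods standing in for the real P2P calls) shows the pattern the run loops rely on. One side effect of the dict lookup: an unrecognized node type surfaces as a KeyError rather than silently falling through an if/elif chain.

    class CommDemo:
        def send_forward(self, chunk):
            print(f"send_forward(chunk={chunk})")

        def recv_forward(self, chunk):
            print(f"recv_forward(chunk={chunk})")

        def send_backward(self, chunk):
            print(f"send_backward(chunk={chunk})")

        def recv_backward(self, chunk):
            print(f"recv_backward(chunk={chunk})")

        def communication_func_map(self, node_type: str):
            # same shape as the method added above
            return {
                "SEND_FORWARD": self.send_forward,
                "RECV_FORWARD": self.recv_forward,
                "SEND_BACKWARD": self.send_backward,
                "RECV_BACKWARD": self.recv_backward,
            }[node_type]

    demo = CommDemo()
    for node_type in ("RECV_FORWARD", "SEND_BACKWARD"):
        demo.communication_func_map(node_type)(chunk=0)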
@@ -439,10 +443,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         if model_chunk_id == 0:
             # bwd step
-            # torch.autograd.backward(
-            #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
-            # )
-            optimizer.backward_b_by_grad(
+            optimizer.backward_b_w_by_grad(
                 tensors=output_obj,
                 grad_tensors=output_obj_grad,
                 inputs=input_obj,
@@ -451,8 +452,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 # loss backward; output_obj is loss
-                # torch.autograd.backward(tensors=output_obj, grad_tensors=None, inputs=input_obj, retain_graph=True)
-                optimizer.backward_b_by_grad(
+                optimizer.backward_b_w_by_grad(
                     tensors=output_obj,
                     grad_tensors=None,
                     inputs=input_obj,
@@ -461,10 +461,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             else:
                 # commom bwd step
-                # torch.autograd.backward(
-                #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
-                # )
-                optimizer.backward_b_by_grad(
+                optimizer.backward_b_w_by_grad(
                     tensors=output_obj,
                     grad_tensors=output_obj_grad,
                     inputs=input_obj,
@@ -495,30 +492,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         # calculate bwd w step ; only dw = x*dy;
         if model_chunk_id == 0:
-            # torch.autograd.backward(
-            #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())
-            # )
-            optimizer.backward_w_by_grad(
-                tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())
+            optimizer.backward_b_w_by_grad(
+                tensors=output_obj,
+                grad_tensors=output_obj_grad,
+                inputs=list(model_chunk[model_chunk_id].parameters()),
+                retain_graph=False,
             )
         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                # torch.autograd.backward(tensors=output_obj_grad, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters()))
-                optimizer.backward_w_by_grad(
-                    tensors=output_obj, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters())
+                optimizer.backward_b_w_by_grad(
+                    tensors=output_obj,
+                    grad_tensors=None,
+                    inputs=list(model_chunk[model_chunk_id].parameters()),
+                    retain_graph=False,
                 )
             else:
-                # torch.autograd.backward(
-                #     tensors=output_obj,
-                #     grad_tensors=output_obj_grad,
-                #     inputs=list(model_chunk[model_chunk_id].parameters()),
-                # )
-
-                optimizer.backward_w_by_grad(
+                optimizer.backward_b_w_by_grad(
                     tensors=output_obj,
                     grad_tensors=output_obj_grad,
                     inputs=list(model_chunk[model_chunk_id].parameters()),
+                    retain_graph=False,
                 )
 
     def schedule_f(
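The W-step call sites above now pass retain_graph=False explicitly (the removed backward_w_by_grad defaulted to False), so the chunk's autograd graph is released once dw has been computed. A minimal sketch of that graph lifetime (illustration only, not part of the patch):

    import torch
    import torch.nn as nn

    layer = nn.Linear(3, 3)
    x = torch.randn(1, 3, requires_grad=True)
    y = layer(x)
    dy = torch.ones_like(y)

    # B step: dx only, keep the graph alive so the W step can run later
    torch.autograd.backward(y, dy, inputs=[x], retain_graph=True)
    # W step: dw only, graph is freed afterwards
    torch.autograd.backward(y, dy, inputs=list(layer.parameters()), retain_graph=False)
    try:
        torch.autograd.backward(y, dy, inputs=[x])  # the graph is gone now
    except RuntimeError as err:
        print("graph already freed:", err)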
@@ -718,17 +712,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
 
-        it = 0
         # while we still have schedules_node in self.schedules
-        while it < len(self.schedules):
+        for it in range(len(self.schedules)):
             scheduled_node = self.schedules[it]
-            if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
+            if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
                 # communication
-                if scheduled_node.type == "RECV_FORWARD":
-                    self.recv_forward(scheduled_node.chunk)
-                elif scheduled_node.type == "SEND_FORWARD":
-                    self.send_forward(scheduled_node.chunk)
+                communication_func = self.communication_func_map(scheduled_node.type)
+                communication_func(scheduled_node.chunk)
             if scheduled_node.type == "F":
                 self.schedule_f(
                     scheduled_node=scheduled_node,
@@ -738,7 +729,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     accum_loss=accum_loss,
                     outputs=outputs,
                 )
-            it += 1
 
         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)
@@ -771,9 +761,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
 
-        it = 0
         # while we still have schedules_node in self.schedules
-        while it < len(self.schedules):
+        for it in range(len(self.schedules)):
             scheduled_node = self.schedules[it]
 
             print(
@@ -781,14 +770,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             )
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
-                if scheduled_node.type == "RECV_FORWARD":
-                    self.recv_forward(scheduled_node.chunk)
-                elif scheduled_node.type == "RECV_BACKWARD":
-                    self.recv_backward(scheduled_node.chunk)
-                elif scheduled_node.type == "SEND_FORWARD":
-                    self.send_forward(scheduled_node.chunk)
-                elif scheduled_node.type == "SEND_BACKWARD":
-                    self.send_backward(scheduled_node.chunk)
+                communication_func = self.communication_func_map(scheduled_node.type)
+                communication_func(scheduled_node.chunk)
+
             if scheduled_node.type == "F":
                 self.schedule_f(
                     scheduled_node=scheduled_node,
@@ -812,7 +796,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-            it += 1
 
         # return loss & output
         if outputs is not None:
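Both run loops now walk the precomputed schedule with a plain for over its indices, so the manual it += 1 bookkeeping disappears and communication nodes go through the shared dispatch helper. A condensed sketch of the resulting control flow (illustration only, not part of the patch; ScheduledNode, comm_map and compute are hypothetical stand-ins for the scheduler's node type, communication_func_map and the schedule_f/b/w handlers):

    from dataclasses import dataclass

    @dataclass
    class ScheduledNode:
        type: str
        chunk: int

    COMMUNICATION_TYPES = {"SEND_FORWARD", "RECV_FORWARD", "SEND_BACKWARD", "RECV_BACKWARD"}

    def run(schedules, comm_map, compute):
        for it in range(len(schedules)):   # was: it = 0 / while it < len(self.schedules): ... it += 1
            node = schedules[it]
            if node.type in COMMUNICATION_TYPES:
                comm_map(node.type)(node.chunk)   # dict dispatch instead of an if/elif chain
            elif node.type in ("F", "B", "W"):
                compute(node)                     # schedule_f / schedule_b / schedule_w

    run(
        [ScheduledNode("RECV_FORWARD", 0), ScheduledNode("F", 0), ScheduledNode("SEND_FORWARD", 0)],
        comm_map=lambda t: (lambda chunk: print(f"{t}(chunk={chunk})")),
        compute=lambda node: print(f"compute {node.type} on chunk {node.chunk}"),
    )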