diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py
index a2215d0fc..50a30be1b 100644
--- a/colossalai/pipeline/schedule/_utils.py
+++ b/colossalai/pipeline/schedule/_utils.py
@@ -137,7 +137,7 @@ def require_grad(x: Any) -> None:
     Args:
         x (Any): Object to be called.
     """
-    if isinstance(x, torch.Tensor) and x.requires_grad:
+    if isinstance(x, torch.Tensor) and not x.requires_grad:
         x.requires_grad_()
 
 
diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 365125ba3..65bb49aa1 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -12,7 +12,18 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from ._utils import clone, detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
+from ._utils import (
+    clone,
+    deallocate,
+    detach,
+    get_batch_size,
+    get_micro_batch,
+    merge_batch,
+    model_forward,
+    require_grad,
+    retain_grad,
+    to_device,
+)
 from .base import PipelineSchedule
 
 AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@@ -24,35 +35,6 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
             req.wait()
 
 
-def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
-    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
-
-    This method should be called right after the output tensor has been
-    sent to the next pipeline stage. At this point, the output tensor is
-    only useful for its '.grad_fn' field, and not its '.data'.
-    """
-    if (out is None) or (not deallocate_pipeline_outputs):
-        return
-    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
-    assert out._base is None, "counter-productive to free a view of another tensor."
-    # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
-    out.data.untyped_storage().resize_(0)
-
-
-def require_grad(tensor):
-    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
-
-    This method should be called right after the output tensor has been
-    sent to the next pipeline stage. At this point, the output tensor is
-    only useful for its '.grad_fn' field, and not its '.data'.
-    """
-    if tensor is None:
-        return
-    assert isinstance(tensor, torch.Tensor), "expected Tensor, found %s." % type(tensor).__name__
-    assert tensor._base is None, "counter-productive to free a view of another tensor."
-    tensor.requires_grad_()
-
-
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def __init__(
         self,
@@ -590,7 +572,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
 
         # Here, let input_obj.requires_grad_()
-        if input_obj is not None:
+        # if input_obj is not None:
+        if not isinstance(input_obj, torch.Tensor):
             tree_map(require_grad, input_obj)
 
         # Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
@@ -614,7 +597,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             pass
         else:
            # deallocate output
-            tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), deallocate_output_obj)
+            tree_map(deallocate, deallocate_output_obj)
 
         # add input and output object for backward b
         if input_obj is not None:
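
For context, a minimal sketch (not part of the patch) of how the two `_utils` helpers referenced above are expected to behave after this change. Only the new body of `require_grad` appears in the diff; the body of `deallocate` shown here is an assumption, modeled on the removed local `deallocate_output_tensor` helper.

```python
import torch


def require_grad(x):
    # Fixed check from this patch: enable grad tracking only when the tensor
    # does not already require grad (the old condition was inverted, making
    # the call a no-op).
    if isinstance(x, torch.Tensor) and not x.requires_grad:
        x.requires_grad_()


def deallocate(x):
    # Assumed to mirror the removed deallocate_output_tensor helper: free the
    # tensor's storage while keeping its .grad_fn alive for the backward pass.
    if isinstance(x, torch.Tensor) and x._base is None:
        x.data.untyped_storage().resize_(0)


t = torch.zeros(4)
require_grad(t)
assert t.requires_grad                      # grad tracking is now actually enabled

out = t * 2                                 # stand-in for a pipeline-stage output
deallocate(out)                             # storage freed; out.grad_fn is still usable
assert out.untyped_storage().size() == 0
```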