diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py
index 50a30be1b..b641eb364 100644
--- a/colossalai/pipeline/schedule/_utils.py
+++ b/colossalai/pipeline/schedule/_utils.py
@@ -169,8 +169,8 @@ def clone(x: Any) -> Any:
     return x
 
 
-def deallocate(x: Any) -> Any:
-    """Call deallocate() on a tensor.
+def release_tensor_data(x: Any) -> Any:
+    """Call untyped_storage().resize_(0) on a tensor. Use to release tensor.data and keep grad_fn.
 
     Args:
         x (Any): Object to be called.
diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index bc2b0b7bf..9771277e2 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -14,12 +14,12 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 
 from ._utils import (
     clone,
-    deallocate,
     detach,
     get_batch_size,
     get_micro_batch,
     merge_batch,
     model_forward,
+    release_tensor_data,
     require_grad,
     retain_grad,
     to_device,
@@ -488,8 +488,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
             input_obj_, _ = tree_flatten(input_obj)
-            output_obj_, _ = tree_flatten(output_obj)  # LOSS
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # None
+            output_obj_.append(output_obj)  # LOSS
+            output_obj_grad_.append(output_obj_grad)  # None
 
         # For other chunk stage, use input_obj as input_obj_;
         else:
@@ -614,20 +614,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             outputs=outputs,
         )
 
-        # Step3: deallocate output for bwd b & w; (do not detach output)
+        # Step3: release_tensor_data output for bwd b & w; (do not detach output)
         deallocate_output_obj = tree_map(clone, output_obj)
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # We should not deallocate bwd LOSS
+            # We should not release_tensor_data bwd LOSS
             pass
         else:
-            # deallocate output
-            tree_map(deallocate, deallocate_output_obj)
+            # release_tensor_data output
+            tree_map(release_tensor_data, deallocate_output_obj)
 
         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
 
         # for bwd b&w, we only need the graph(grad_fn) of output_obj
-        # Do not deallocate loss, deallocate other output_obj;
+        # Do not release_tensor_data loss, release_tensor_data other output_obj;
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             self.output_tensors[model_chunk_id].append(deallocate_output_obj)
             self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
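
Note on the renamed helper: the point of release_tensor_data is to free a forward output's element storage while keeping its autograd metadata (shape, dtype, grad_fn), so the saved graph can still drive the later backward-b and backward-w passes. A minimal standalone sketch of the idea, assuming PyTorch >= 2.0 for Tensor.untyped_storage(); this is a simplified stand-in, not the library helper itself:

    import torch

    def release_tensor_data(x):
        # Free the element storage; shape/dtype metadata and grad_fn survive,
        # so the tensor can still serve as a root for backward.
        if isinstance(x, torch.Tensor):
            x.untyped_storage().resize_(0)
        return x

    w = torch.randn(3, 3, requires_grad=True)
    out = w.matmul(w)            # non-leaf output carrying a grad_fn
    release_tensor_data(out)     # element data freed, graph left intact

    # Backward needs only out's metadata and grad_fn, not its freed data:
    torch.autograd.backward(out, grad_tensors=torch.ones(3, 3))
    assert w.grad is not None

This is also why Step3 clones the output first (tree_map(clone, output_obj)): the original output_obj can still be sent to the next stage, while the data-released clone keeps just the grad_fn needed for backward.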
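
Note on the tree_flatten change in the zero_bubble_pp.py hunk at line 488: for the loss chunk, output_obj_grad is None, and on recent PyTorch versions torch.utils._pytree flattens None to zero leaves, so the flattened output and grad lists can end up with different lengths and break the one-to-one pairing the backward step relies on. Appending the raw objects keeps the lists aligned (assuming output_obj_ and output_obj_grad_ are initialized as empty lists earlier in the function, outside this hunk's context). A quick illustration; None's pytree handling varies by PyTorch version, so treat this as indicative:

    import torch
    from torch.utils._pytree import tree_flatten

    loss = torch.tensor(1.0, requires_grad=True)
    leaves, _ = tree_flatten(loss)   # [tensor(1.)] -> length 1
    grads, _ = tree_flatten(None)    # [] on recent PyTorch -> length 0

    # Mismatched lengths would silently drop the loss when pairing
    # outputs with grads. Appending directly keeps the lists aligned:
    output_obj_, output_obj_grad_ = [], []
    output_obj_.append(loss)
    output_obj_grad_.append(None)    # pairs one-to-one with output_obj_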