[fix] rm output.data after send fwd;

Author: duanjunwen
Date:   2024-09-03 14:12:17 +08:00
parent a48afc4a66
commit ab643c9af7
3 changed files with 25 additions and 49 deletions

@@ -25,6 +25,24 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
         req.wait()
 
 
+def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
+    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
+
+    This method should be called right after the output tensor has been
+    sent to the next pipeline stage. At this point, the output tensor is
+    only useful for its '.grad_fn' field, and not its '.data'.
+    """
+    if (out is None) or (not deallocate_pipeline_outputs):
+        print(
+            f"(out is None) or (not deallocate_pipeline_outputs): {(out is None) or (not deallocate_pipeline_outputs)}"
+        )
+        return
+    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
+    assert out._base is None, "counter-productive to free a view of another tensor."
+    # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
+    out.data.storage().resize_(0)
+
+
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def __init__(
         self,
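
For context, a minimal standalone sketch of what this helper does (the toy graph below is illustrative, not part of the commit): resizing the storage to zero empties '.data' while the shape metadata and '.grad_fn' survive, so a later backward pass can still traverse the graph.

import torch

x = torch.randn(4, requires_grad=True)
out = (x * 2).clone()                 # non-view tensor with a grad_fn

print(out.data.storage().size())      # 4 -- storage still holds the elements
out.data.storage().resize_(0)         # what deallocate_output_tensor does
print(out.data.storage().size())      # 0 -- activation memory released
print(out.shape, out.grad_fn)         # torch.Size([4]) <CloneBackward0 ...>

# Backward needs only the graph, not out's freed data:
out.backward(torch.ones(4))
print(x.grad)                         # tensor([2., 2., 2., 2.])
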
@@ -562,10 +580,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         )
         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append(input_obj)
-        self.output_tensors[model_chunk_id].append(output_obj)
+        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
+        detached_output_obj = output_obj.clone()
+        deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
+        self.output_tensors[model_chunk_id].append(detached_output_obj)
         # add output object for backward w
-        self.output_tensors_dw[model_chunk_id].append(output_obj)
+        self.output_tensors_dw[model_chunk_id].append(detached_output_obj)
         # Step3: send fwd
         # add output to send_fwd_buffer
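
Taken together, the hunk stores a storage-freed clone for backward b/w and keeps the real data only in the tensor handed to send_fwd_buffer. A sketch of that bookkeeping, with a made-up matmul stage and shapes standing in for the real model chunk:

import torch

weight = torch.randn(8, 8, requires_grad=True)      # hypothetical stage weight
input_obj = torch.randn(2, 8, requires_grad=True)   # hypothetical micro-batch
output_obj = input_obj @ weight                     # forward of one stage

# Mirror the diff: the clone shares the graph via CloneBackward0; free its data.
detached_output_obj = output_obj.clone()
detached_output_obj.data.storage().resize_(0)
# output_obj (with real data) is what goes into send_fwd_buffer; per the
# commit title, its '.data' can be dropped once the send completes.

# Backward b/w later needs only the graph held by the clone:
torch.autograd.backward(detached_output_obj, grad_tensors=torch.ones(2, 8))
print(input_obj.grad.shape, weight.grad.shape)      # (2, 8) and (8, 8)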