Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[fix] rm output.data after send fwd;
@@ -25,6 +25,24 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
            req.wait()


def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.

    This method should be called right after the output tensor has been
    sent to the next pipeline stage. At this point, the output tensor is
    only useful for its '.grad_fn' field, and not its '.data'.
    """
    if (out is None) or (not deallocate_pipeline_outputs):
        print(
            f"(out is None) or (not deallocate_pipeline_outputs): {(out is None) or (not deallocate_pipeline_outputs)}"
        )
        return
    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
    assert out._base is None, "counter-productive to free a view of another tensor."
    # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
    out.data.storage().resize_(0)


class ZeroBubbleVPipeScheduler(PipelineSchedule):
    def __init__(
        self,
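
For readers unfamiliar with the trick, the following standalone sketch (not part of the commit; tensor names are illustrative) shows why freeing the cloned output's storage is safe: what the backward pass needs is the autograd graph reachable from 'grad_fn', not the output's '.data'.

import torch

# Toy single-stage forward: the output would normally stay alive until
# the next stage's gradient comes back.
weight = torch.randn(1024, 1024, requires_grad=True)
activation = torch.randn(1024, 1024)
output = weight @ activation

# Clone to keep the graph, then release the clone's storage. The tensor's
# shape/dtype metadata and grad_fn survive; only the element buffer is freed.
detached_output = output.clone()
detached_output.data.storage().resize_(0)
print(detached_output.data.storage().size())  # 0 -> the clone's activation memory is released

# Backward still runs once the gradient from the "next stage" arrives,
# because the engine only needs grad_fn plus the incoming gradient.
grad_from_next_stage = torch.ones(1024, 1024)
torch.autograd.backward(detached_output, grad_tensors=grad_from_next_stage)
print(weight.grad.shape)  # torch.Size([1024, 1024])
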
@@ -562,10 +580,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        )
        # add input and output object for backward b
        self.input_tensors[model_chunk_id].append(input_obj)
        self.output_tensors[model_chunk_id].append(output_obj)

        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
        detached_output_obj = output_obj.clone()
        deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
        self.output_tensors[model_chunk_id].append(detached_output_obj)
        # add output object for backward w
        self.output_tensors_dw[model_chunk_id].append(output_obj)
        self.output_tensors_dw[model_chunk_id].append(detached_output_obj)

        # Step3: send fwd
        # add output to send_fwd_buffer
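
The two buffers reflect the zero-bubble split of the backward pass: backward-b propagates the received gradient to the chunk's input, while backward-w fills in the weight gradients later from the same retained graph. The sketch below is a simplified illustration of that split; the variable names and the direct 'torch.autograd.backward' calls are assumptions for the example, not the scheduler's actual bookkeeping.

import torch

# Minimal b/w-split sketch (illustrative; not the ZeroBubbleVPipeScheduler API).
layer = torch.nn.Linear(8, 8)
input_obj = torch.randn(2, 8, requires_grad=True)
output_obj = layer(input_obj)

# Graph-only copy kept for both backward passes; the live output_obj is what
# would go into send_fwd_buffer before its data is released.
detached_output_obj = output_obj.clone()
detached_output_obj.data.storage().resize_(0)

grad_from_next_stage = torch.ones(2, 8)

# backward-b: gradient w.r.t. the stage input only; keep the graph for the w step.
torch.autograd.backward(
    detached_output_obj,
    grad_tensors=grad_from_next_stage,
    inputs=[input_obj],
    retain_graph=True,
)

# backward-w: weight gradients, scheduled later to fill pipeline bubbles.
torch.autograd.backward(
    detached_output_obj,
    grad_tensors=grad_from_next_stage,
    inputs=list(layer.parameters()),
)

print(input_obj.grad.shape, layer.weight.grad.shape)
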