Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-05 21:52:06 +00:00
[fix] fix require_grad & deallocate call;
parent 1f5c7258aa
commit 6ee9584b9a
@@ -137,7 +137,7 @@ def require_grad(x: Any) -> None:
     Args:
         x (Any): Object to be called.
     """
-    if isinstance(x, torch.Tensor) and x.requires_grad:
+    if isinstance(x, torch.Tensor) and not x.requires_grad:
         x.requires_grad_()
 
 
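The guard above was inverted: the old check only matched tensors that already had requires_grad set, so the requires_grad_() call never changed anything. A minimal standalone sketch of the before/after behavior (names are local to this demo, not project code):

import torch

def require_grad_old(x):
    # buggy guard: only fires when grad tracking is already on, so it's a no-op
    if isinstance(x, torch.Tensor) and x.requires_grad:
        x.requires_grad_()

def require_grad_fixed(x):
    # fixed guard: enables grad tracking on tensors that still lack it
    if isinstance(x, torch.Tensor) and not x.requires_grad:
        x.requires_grad_()

t = torch.zeros(2)                 # fresh tensor, requires_grad defaults to False
require_grad_old(t)
assert t.requires_grad is False    # old version changed nothing
require_grad_fixed(t)
assert t.requires_grad is True     # fixed version turns grad tracking on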
|
@@ -12,7 +12,18 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from ._utils import clone, detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
+from ._utils import (
+    clone,
+    deallocate,
+    detach,
+    get_batch_size,
+    get_micro_batch,
+    merge_batch,
+    model_forward,
+    require_grad,
+    retain_grad,
+    to_device,
+)
 from .base import PipelineSchedule
 
 AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
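The one-line import is expanded so the schedule reuses the shared deallocate and require_grad helpers from ._utils instead of its own local copies (deleted in the next hunk). A hedged sketch of what the shared deallocate presumably looks like, inferred from the removed local deallocate_output_tensor; the actual body in _utils may differ:

import torch

def deallocate(x):
    # Assumed one-argument form: the call site below no longer binds a
    # deallocate_pipeline_outputs flag, so the helper frees unconditionally.
    if x is None or not isinstance(x, torch.Tensor):
        return
    assert x._base is None, "counter-productive to free a view of another tensor."
    x.data.untyped_storage().resize_(0)  # keep .grad_fn, drop the payload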
@@ -24,35 +35,6 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
         req.wait()
 
 
-def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
-    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
-
-    This method should be called right after the output tensor has been
-    sent to the next pipeline stage. At this point, the output tensor is
-    only useful for its '.grad_fn' field, and not its '.data'.
-    """
-    if (out is None) or (not deallocate_pipeline_outputs):
-        return
-    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
-    assert out._base is None, "counter-productive to free a view of another tensor."
-    # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
-    out.data.untyped_storage().resize_(0)
-
-
-def require_grad(tensor):
-    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
-
-    This method should be called right after the output tensor has been
-    sent to the next pipeline stage. At this point, the output tensor is
-    only useful for its '.grad_fn' field, and not its '.data'.
-    """
-    if tensor is None:
-        return
-    assert isinstance(tensor, torch.Tensor), "expected Tensor, found %s." % type(tensor).__name__
-    assert tensor._base is None, "counter-productive to free a view of another tensor."
-    tensor.requires_grad_()
-
-
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def __init__(
         self,
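Both deleted functions now live in ._utils, removing the duplication (note that the require_grad docstring here was itself a stale copy-paste from deallocate_output_tensor). The pseudo-deallocation trick they rely on is worth a small demo: resizing the untyped storage to zero releases the activation memory while the Python tensor object, and its .grad_fn, stay alive for backward. A runnable CPU sketch:

import torch

out = torch.randn(1024, 1024, requires_grad=True) * 2   # non-leaf: has a grad_fn
print(out.untyped_storage().size())    # ~4 MiB of float32 payload
out.data.untyped_storage().resize_(0)  # free the payload in place
print(out.untyped_storage().size())    # 0 bytes
print(out.grad_fn)                     # the MulBackward0 node is still reachable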
@@ -590,7 +572,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
 
         # Here, let input_obj.requires_grad_()
-        if input_obj is not None:
+        # if input_obj is not None:
+        if not isinstance(input_obj, torch.Tensor):
             tree_map(require_grad, input_obj)
 
         # Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
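tree_map here is torch.utils._pytree.tree_map: it applies a function to every leaf of an arbitrarily nested container, which is how require_grad reaches each tensor inside whatever structure the receive buffer holds. Since require_grad mutates in place via requires_grad_() and returns nothing, the mapped result is discarded. A small usage sketch (the container layout is invented for the demo):

import torch
from torch.utils._pytree import tree_map

def require_grad(x):
    if isinstance(x, torch.Tensor) and not x.requires_grad:
        x.requires_grad_()

input_obj = {"hidden": torch.zeros(4, 8), "aux": [torch.ones(2), "metadata"]}
tree_map(require_grad, input_obj)    # visits every leaf, including the string
assert input_obj["hidden"].requires_grad
assert input_obj["aux"][0].requires_grad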
@@ -614,7 +597,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     pass
                 else:
                     # deallocate output
-                    tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), deallocate_output_obj)
+                    tree_map(deallocate, deallocate_output_obj)
 
         # add input and output object for backward b
         if input_obj is not None:
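The old call had to bind the deallocate_pipeline_outputs flag with functools.partial so that tree_map saw a single-argument callable; the shared helper takes one argument, so it can be passed directly. A self-contained sketch of the simplification (both helper bodies are stand-ins for the real ones):

import torch
from functools import partial
from torch.utils._pytree import tree_map

def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
    # old two-argument form: frees only when the flag is bound to True
    if out is None or not deallocate_pipeline_outputs:
        return
    out.data.untyped_storage().resize_(0)

def deallocate(out):
    # new single-argument form: no flag, usable with tree_map as-is
    if isinstance(out, torch.Tensor):
        out.data.untyped_storage().resize_(0)

outputs = [torch.randn(8, 8), torch.randn(8, 8)]
# before: flag bound at every call site
tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), outputs)
# after: helper is already single-argument
tree_map(deallocate, outputs)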
|