mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[fix] fix zerobubble pp for shardformer type input;
This commit updates ZeroBubbleVPipeScheduler so that its forward and backward steps accept Shardformer-style dict inputs (the raw micro_batch plus inter-stage hidden_states) in addition to plain ModuleList model chunks.
@@ -12,7 +12,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
+from ._utils import clone, detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule
 
 AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@@ -39,6 +39,20 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
     out.data.untyped_storage().resize_(0)
 
 
+def require_grad(tensor):
+    """Make a received tensor require gradients (in place).
+
+    This method should be called right after a tensor has been received
+    from the previous pipeline stage, so that autograd can later produce
+    the gradient w.r.t. this stage's input; None entries are skipped.
+    """
+    if tensor is None:
+        return
+    assert isinstance(tensor, torch.Tensor), "expected Tensor, found %s." % type(tensor).__name__
+    assert tensor._base is None, "counter-productive to free a view of another tensor."
+    tensor.requires_grad_()
+
+
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def __init__(
         self,
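For context on how the require_grad helper added above is used: tensors received over p2p arrive as plain leaves, so schedule_f walks the received dict with tree_map and flips requires_grad in place before the chunk's forward runs. A minimal, runnable sketch; the dict below is a hypothetical stand-in for an entry of recv_forward_buffer, and tree_map here is taken from torch.utils._pytree:

import torch
from torch.utils._pytree import tree_map

def require_grad(tensor):
    # same idea as the helper above: skip None entries, flip requires_grad in place
    if tensor is None:
        return
    assert isinstance(tensor, torch.Tensor)
    tensor.requires_grad_()

# hypothetical stand-in for a dict popped from recv_forward_buffer
received = {"hidden_states": torch.randn(2, 4), "attention_mask": None}

# the mapped result is discarded; only the in-place side effect matters
tree_map(require_grad, received)
assert received["hidden_states"].requires_grad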
@@ -409,6 +423,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
+        micro_batch: Optional[dict],
         input_obj: Optional[dict],
         criterion: Callable,
         accum_loss: Optional[torch.Tensor] = None,
@@ -427,18 +442,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
         """
         # Load input ids, attention mask and labels
         # micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
 
-        # for the first stage, input_obj is None
+        # for the first stage, input_obj is None; so we use micro_batch as input_obj
+        # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
+        # Only attention_mask from micro_batch is used
 
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
-            # fwd calculate
-            output_obj = model_chunk[model_chunk_id](input_obj)
+            # fwd calculate
+            if isinstance(model_chunk, ModuleList):
+                # fwd for ModuleList model
+                if input_obj is None:
+                    output_obj = model_chunk[model_chunk_id](**micro_batch)
+                else:
+                    output_obj = model_chunk[model_chunk_id](**input_obj)
+            else:
+                # fwd for shardformer
+                # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
+                internal_inputs = {} if input_obj is None else input_obj
+                # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+                output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
 
         # last layer in model
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            loss = criterion(output_obj) / self.num_microbatch
+            loss = criterion(output_obj, micro_batch) / self.num_microbatch
             if accum_loss is not None:
                 accum_loss.add_(loss.detach())
             if outputs is not None:
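The model_forward helper imported from ._utils is what lets the shardformer branch above feed both the raw micro_batch and the dict received from the previous stage into the chunk. The sketch below is an assumption about its behaviour for illustration only (merge the two dicts into keyword arguments, inter-stage fields winning on collisions); the real helper lives in the schedule package's _utils module and may differ in detail. ToyStage and model_forward_sketch are hypothetical names.

from typing import Any, Dict, Optional

import torch
from torch import nn

def model_forward_sketch(model: nn.Module, micro_batch: Optional[Dict[str, Any]], internal_inputs: Optional[Dict[str, Any]]):
    # assumed behaviour: micro-batch fields (input_ids, attention_mask, ...) and
    # inter-stage fields (hidden_states, ...) are both passed as keyword arguments
    micro_batch = micro_batch or {}
    internal_inputs = internal_inputs or {}
    return model(**{**micro_batch, **internal_inputs})

class ToyStage(nn.Module):
    # stands in for one shardformer pipeline stage of the real model
    def forward(self, input_ids=None, attention_mask=None, hidden_states=None, **kwargs):
        if hidden_states is None:          # first stage: "embed" the ids
            hidden_states = input_ids.float().unsqueeze(-1)
        return {"hidden_states": hidden_states * 2}

stage = ToyStage()
first = model_forward_sketch(stage, {"input_ids": torch.ones(2, 3, dtype=torch.long)}, None)
later = model_forward_sketch(stage, {"attention_mask": torch.ones(2, 3)}, {"hidden_states": first["hidden_states"]})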
@@ -472,19 +496,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # calculate bwd b step ; only dx = w*dy;
 
         # Retain the grad on the input_obj.
-        tree_map(retain_grad, input_obj)
+        if input_obj is None:
+            return None
+        else:
+            tree_map(retain_grad, input_obj)
+            input_obj_ = input_obj["hidden_states"]
 
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             # loss backward; output_obj is loss; so output_obj_grad should be None
             assert output_obj_grad is None
-
+            output_obj_ = output_obj
+        else:
+            output_obj_ = output_obj["hidden_states"]
         optimizer.backward_by_grad(
-            tensor=output_obj,
+            tensor=output_obj_,
             grad=output_obj_grad,
-            inputs=input_obj,
+            inputs=input_obj_,
             retain_graph=True,
         )
-        return input_obj.grad
+        return input_obj_.grad
 
     def backward_w_step(
         self,
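backward_b_step now unwraps hidden_states and calls backward_by_grad with inputs=input_obj_ and retain_graph=True: this is the B step of the zero-bubble split, which produces only dL/dx (to be sent upstream) while keeping the graph alive for the later W step. A toy sketch of the same idea using plain torch.autograd instead of the OptimizerWrapper API; x, W, y, dy are hypothetical stand-ins, not the scheduler's objects:

import torch

x = torch.randn(4, 8, requires_grad=True)   # received activation ("hidden_states")
W = torch.nn.Parameter(torch.randn(8, 8))   # a weight owned by this chunk
y = x @ W                                   # chunk output
dy = torch.randn_like(y)                    # output grad received from the next stage

# B step: accumulate only into x and keep the graph for the later W step
torch.autograd.backward(y, grad_tensors=dy, inputs=[x], retain_graph=True)
assert x.grad is not None and W.grad is None   # dx ready to send upstream, dW deferred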
@@ -511,8 +541,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             # loss backward; output_obj is loss
             output_obj_grad = None
+            output_obj_ = output_obj
+        else:
+            output_obj_ = output_obj["hidden_states"]
         optimizer.backward_by_grad(
-            tensor=output_obj,
+            tensor=output_obj_,
             grad=output_obj_grad,
             inputs=list(model_chunk[model_chunk_id].parameters()),
             retain_graph=False,
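backward_w_step is the matching W step: the same output/grad pair is replayed, but this time the gradient is accumulated only into the chunk's parameters and the graph is released (retain_graph=False). Again a toy torch.autograd sketch rather than the real OptimizerWrapper.backward_by_grad call, showing the full B-then-W ordering:

import torch

model = torch.nn.Linear(8, 8, bias=False)   # stands in for one model chunk
x = torch.randn(4, 8, requires_grad=True)
y = model(x)
dy = torch.randn_like(y)

# B step first: dx only, graph retained
torch.autograd.backward(y, grad_tensors=dy, inputs=[x], retain_graph=True)
# W step later: dW only, graph freed afterwards
torch.autograd.backward(y, grad_tensors=dy, inputs=list(model.parameters()), retain_graph=False)
assert x.grad is not None and model.weight.grad is not None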
@@ -543,9 +576,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
         # Step1: recv fwd
         if model_chunk_id == 0:
-            # is first stage; get input from func param
+            # is first stage; get input from microbatch
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                input_obj = micro_batch
+                input_obj = None
             else:
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
         else:
@@ -557,45 +590,68 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
 
         # Here, let input_obj.requires_grad_()
-        tree_map(torch.Tensor.requires_grad_, input_obj)
+        if input_obj is not None:
+            tree_map(require_grad, input_obj)
+
+        # Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
+        # tree_map(torch.Tensor.requires_grad_, micro_batch)
 
         # Step2: fwd step
         output_obj = self.forward_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
+            micro_batch=micro_batch,
             input_obj=input_obj,
             criterion=criterion,
             accum_loss=accum_loss,
             outputs=outputs,
         )
 
+        # Step3: deallocate output for bwd b & w; (do not detach output)
+        deallocate_output_obj = tree_map(clone, output_obj)
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # We should not deallocate bwd LOSS
+            pass
+        else:
+            # deallocate output
+            tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), deallocate_output_obj)
+
+        # add input and output object for backward b
+        if input_obj is not None:
+            self.input_tensors[model_chunk_id].append(input_obj)
+        else:
+            self.input_tensors[model_chunk_id].append(micro_batch)
+
+        # for bwd b&w, we only need the graph(grad_fn) of output_obj
+        # Do not deallocate loss, deallocate other output_obj;
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            self.output_tensors[model_chunk_id].append(deallocate_output_obj)
+            self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
+        else:
+            self.output_tensors[model_chunk_id].append(deallocate_output_obj)
+            self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
 
-        detached_output_obj = output_obj.clone().detach()
+        # Step4: detach output for send fwd;
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # We should not detach bwd LOSS
+            pass
+        else:
+            # detach output
+            output_obj = tree_map(detach, output_obj)
 
         # Step3: send fwd
         # add output to send_fwd_buffer
-        if model_chunk_id == 0:
+        if model_chunk_id == 0:  # chunk 0
             # is last stage; send to local_send_forward_buffer
             if self.stage_manager.is_last_stage(ignore_chunk=True):
-                self.local_send_forward_buffer.append(detached_output_obj)
+                self.local_send_forward_buffer.append(output_obj)
             else:
-                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
-        else:
-            # is first stage; end of fwd; append LOSS to local_send_backward_buffer
+                self.send_forward_buffer[model_chunk_id].append(output_obj)
+        else:  # chunk 1
+            # is first stage; end of fwd; do nothing
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 pass
             else:
-                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
-
-        # add input and output object for backward b
-        self.input_tensors[model_chunk_id].append(input_obj)
-        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
-        deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True)
-        self.output_tensors[model_chunk_id].append(output_obj)
-        # add output object for backward w
-        self.output_tensors_dw[model_chunk_id].append(output_obj)
+                self.send_forward_buffer[model_chunk_id].append(output_obj)
 
     def schedule_b(
         self,
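In the reworked schedule_f above, each forward output ends up in two forms with different jobs: a graph-carrying clone kept for the local B/W backward passes (its storage can later be freed by deallocate_output_tensor because only its grad_fn is needed), and a detached copy placed in the send buffer so the autograd graph never crosses a stage boundary. A small runnable sketch of that clone/detach split on a shardformer-style dict, using lambdas in place of the clone/detach helpers imported from ._utils:

import torch
from torch.utils._pytree import tree_map

x = torch.randn(2, 4, requires_grad=True)
output_obj = {"hidden_states": (x @ torch.randn(4, 4)).tanh()}   # toy stage output

saved_for_backward = tree_map(lambda t: t.clone(), output_obj)   # keeps grad_fn for bwd b/w
to_send = tree_map(lambda t: t.detach(), output_obj)             # no graph attached, safe to ship

assert saved_for_backward["hidden_states"].grad_fn is not None
assert to_send["hidden_states"].grad_fn is None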
@@ -603,9 +659,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
-        # input_obj: Optional[dict],
-        # output_obj: Union[dict, torch.Tensor],
-        # output_obj_grad: Optional[dict],
     ):
         """A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
 
@@ -616,20 +669,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         Returns:
             Nothing.
         """
 
         # Step1: recv bwd
         if model_chunk_id == 0:
             # chunk0 is last stage; recv output_grad from local_send_backward_buffer
             if self.stage_manager.is_last_stage(ignore_chunk=True):
                 output_tensor_grad = self.local_send_backward_buffer.pop(0)
-            # chunk 0 not last stage; recv output_grad from recv_backward_buffer
+            # chunk0 not last stage; recv output_grad from recv_backward_buffer
             else:
                 output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
         else:
             # chunk1, is first stage; recv LOSS from local send bwd buffer
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 output_tensor_grad = None
             # chunk1, not first stage; recv output_grad from recv_backward_buffer
             else:
                 output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
 
@@ -645,7 +697,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # we save output_tensor_grad here
         self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
 
-        # _wait_p2p(recv_bwd_handles)
         # Step2: bwd step
         input_object_grad = self.backward_b_step(
             model_chunk=model_chunk,
@@ -777,8 +828,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
                 communication_func(scheduled_node.chunk)
-
-            if scheduled_node.type == "F":
+            elif scheduled_node.type == "F":
                 self.schedule_f(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
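The if-to-elif change above matters because each scheduled node should be handled exactly once per loop iteration: it is either one of the AUTO_SCHEDULE_COMMUNICATION_TYPES or a compute node such as "F". A stripped-down sketch of that dispatch pattern with a hypothetical Node type and print-only handlers (the real run loop also covers "B" and "W" nodes and calls the scheduler's methods):

from dataclasses import dataclass

@dataclass
class Node:
    type: str   # e.g. "RECV_FORWARD", "SEND_FORWARD", "F", "B", "W"
    chunk: int

COMM_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}

def run(schedule):
    for node in schedule:
        if node.type in COMM_TYPES:
            print(f"communicate: {node.type} on chunk {node.chunk}")
        elif node.type == "F":
            print(f"forward on chunk {node.chunk}")
        elif node.type == "B":
            print(f"backward b (dx only) on chunk {node.chunk}")
        elif node.type == "W":
            print(f"backward w (dw only) on chunk {node.chunk}")

run([Node("RECV_FORWARD", 0), Node("F", 0), Node("B", 0), Node("W", 0)])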