[fix] fix zero-bubble PP for shardformer-type input;

duanjunwen
2024-09-18 07:14:34 +00:00
parent 9bc3b6e220
commit 3dbad102cf
3 changed files with 224 additions and 66 deletions


@@ -12,7 +12,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
from ._utils import clone, detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@@ -39,6 +39,20 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
out.data.untyped_storage().resize_(0)
def require_grad(tensor):
"""Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'.
"""
if tensor is None:
return
assert isinstance(tensor, torch.Tensor), "expected Tensor, found %s." % type(tensor).__name__
assert tensor._base is None, "should not call requires_grad_() on a view of another tensor."
tensor.requires_grad_()
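For context, a minimal plain-PyTorch sketch (separate from this file, names illustrative only) of how such a helper behaves when mapped over a received input dict with tree_map:

import torch
from torch.utils._pytree import tree_map

def require_grad(tensor):
    # Mark a received activation in place so autograd tracks it and
    # its '.grad' field can be read after the backward b step.
    if tensor is None:
        return
    assert isinstance(tensor, torch.Tensor)
    tensor.requires_grad_()

recv_obj = {"hidden_states": torch.randn(2, 4)}  # illustrative payload
tree_map(require_grad, recv_obj)                 # applied for its in-place side effect
assert recv_obj["hidden_states"].requires_grad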
class ZeroBubbleVPipeScheduler(PipelineSchedule):
def __init__(
self,
@@ -409,6 +423,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
micro_batch: Optional[dict],
input_obj: Optional[dict],
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
@@ -427,18 +442,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
# Load input ids, attention mask and labels
# micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# for the first stage, input_obj is None
# for the first stage, input_obj is None; so we use micro_batch as input_obj
# for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
# fwd calculate
output_obj = model_chunk[model_chunk_id](input_obj)
# fwd calculate
if isinstance(model_chunk, ModuleList):
# fwd for ModuleList model
if input_obj is None:
output_obj = model_chunk[model_chunk_id](**micro_batch)
else:
output_obj = model_chunk[model_chunk_id](**input_obj)
else:
# fwd for shardformer
# NOTE: in shardformer, each device still holds the entire model, so we only run the layers relevant to the current stage
internal_inputs = {} if input_obj is None else input_obj
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
# last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
loss = criterion(output_obj) / self.num_microbatch
loss = criterion(output_obj, micro_batch) / self.num_microbatch
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
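The body of model_forward lives in ._utils and is not shown in this diff; a plausible minimal version (an assumption, for illustration only) merges the raw micro-batch kwargs with the internal inputs received from the adjacent stage before calling the chunk:

# Hypothetical sketch only; the real model_forward in
# colossalai/pipeline/schedule/_utils.py may differ.
def model_forward_sketch(model_chunk, micro_batch, internal_inputs):
    # internal inputs (e.g. hidden_states from the previous stage) take priority
    kwargs = {**(micro_batch or {}), **(internal_inputs or {})}
    return model_chunk(**kwargs)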
@@ -472,19 +496,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# calculate bwd b step; only dx = w*dy
# Retain the grad on the input_obj.
tree_map(retain_grad, input_obj)
if input_obj is None:
return None
else:
tree_map(retain_grad, input_obj)
input_obj_ = input_obj["hidden_states"]
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss; so output_obj_grad should be None
assert output_obj_grad is None
output_obj_ = output_obj
else:
output_obj_ = output_obj["hidden_states"]
optimizer.backward_by_grad(
tensor=output_obj,
tensor=output_obj_,
grad=output_obj_grad,
inputs=input_obj,
inputs=input_obj_,
retain_graph=True,
)
return input_obj.grad
return input_obj_.grad
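A plain-PyTorch sketch (not the OptimizerWrapper API) of what the B step above computes: restricting inputs to the stage input yields only dx, while retain_graph=True keeps the graph alive for the later W step.

import torch

layer = torch.nn.Linear(4, 4)              # stand-in for one model chunk
x = torch.randn(2, 4, requires_grad=True)  # input_obj_ ("hidden_states")
y = layer(x)                               # output_obj_
dy = torch.randn_like(y)                   # output_obj_grad from the next stage
torch.autograd.backward(y, grad_tensors=dy, inputs=[x], retain_graph=True)
dx = x.grad                                # gradient sent back to the previous stage
assert layer.weight.grad is None           # dW is deferred to backward_w_step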
def backward_w_step(
self,
@@ -511,8 +541,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss
output_obj_grad = None
output_obj_ = output_obj
else:
output_obj_ = output_obj["hidden_states"]
optimizer.backward_by_grad(
tensor=output_obj,
tensor=output_obj_,
grad=output_obj_grad,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
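Continuing the toy B-step sketch above, the W step replays backward through the retained graph, this time accumulating only into the chunk's parameters:

# reuses layer / y / dy from the B-step sketch after backward_b_step
torch.autograd.backward(y, grad_tensors=dy, inputs=list(layer.parameters()), retain_graph=False)
assert layer.weight.grad is not None and layer.bias.grad is not None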
@@ -543,9 +576,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# Step1: recv fwd
if model_chunk_id == 0:
# is first stage; get input from func param
# is first stage; get input from microbatch
if self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj = micro_batch
input_obj = None
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
else:
@@ -557,45 +590,68 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
# Here, let input_obj.requires_grad_()
tree_map(torch.Tensor.requires_grad_, input_obj)
if input_obj is not None:
tree_map(require_grad, input_obj)
# Optionally, requires_grad_ could also be set on micro_batch for stage 0 chunk 0 fwd:
# tree_map(torch.Tensor.requires_grad_, micro_batch)
# Step2: fwd step
output_obj = self.forward_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
micro_batch=micro_batch,
input_obj=input_obj,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
# Step3: deallocate output for bwd b & w; (do not detach output)
deallocate_output_obj = tree_map(clone, output_obj)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not deallocate bwd LOSS
pass
else:
# deallocate output
tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), deallocate_output_obj)
# add input and output object for backward b
if input_obj is not None:
self.input_tensors[model_chunk_id].append(input_obj)
else:
self.input_tensors[model_chunk_id].append(micro_batch)
# for bwd b&w, we only need the graph(grad_fn) of output_obj
# Do not deallocate loss, deallocate other output_obj;
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
self.output_tensors[model_chunk_id].append(deallocate_output_obj)
self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
else:
self.output_tensors[model_chunk_id].append(deallocate_output_obj)
self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
# Step4: detach output for send fwd;
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not detach bwd LOSS
pass
else:
detached_output_obj = output_obj.clone().detach()
# detach output
output_obj = tree_map(detach, output_obj)
# Step5: send fwd
# add output to send_fwd_buffer
if model_chunk_id == 0:
if model_chunk_id == 0: # chunk 0
# is last stage; send to local_send_forward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.local_send_forward_buffer.append(detached_output_obj)
self.local_send_forward_buffer.append(output_obj)
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
else:
# is first stage; end of fwd; append LOSS to local_send_backward_buffer
self.send_forward_buffer[model_chunk_id].append(output_obj)
else: # chunk 1
# is first stage; end of fwd; do nothing
if self.stage_manager.is_first_stage(ignore_chunk=True):
pass
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
# add input and output object for backward b
self.input_tensors[model_chunk_id].append(input_obj)
# detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True)
self.output_tensors[model_chunk_id].append(output_obj)
# add output object for backward w
self.output_tensors_dw[model_chunk_id].append(output_obj)
self.send_forward_buffer[model_chunk_id].append(output_obj)
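Why the clone-then-deallocate step in schedule_f is safe, as a standalone sketch: freeing the cloned output's storage releases the activation memory, while its grad_fn (all that the later B/W backward calls need) survives.

import torch

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
out = layer(x)
keep = out.clone()                          # handle stored for bwd b & w
keep.data.untyped_storage().resize_(0)      # free data; shape metadata + grad_fn remain
send = out.detach()                         # tensor actually sent to the next stage
torch.autograd.backward(keep, grad_tensors=torch.randn(2, 4))
assert x.grad is not None                   # graph was still usable after deallocation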
def schedule_b(
self,
@@ -603,9 +659,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
# input_obj: Optional[dict],
# output_obj: Union[dict, torch.Tensor],
# output_obj_grad: Optional[dict],
):
"""A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
@@ -616,20 +669,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
Returns:
Nothing.
"""
# Step1: recv bwd
if model_chunk_id == 0:
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
output_tensor_grad = self.local_send_backward_buffer.pop(0)
# chunk 0 not last stage; recv output_grad from recv_backward_buffer
# chunk0 not last stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
else:
# chunk1, is first stage; recv LOSS from local send bwd buffer
if self.stage_manager.is_first_stage(ignore_chunk=True):
output_tensor_grad = None
# chunk1, not first stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
@@ -645,7 +697,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# we save output_tensor_grad here
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
# _wait_p2p(recv_bwd_handles)
# Step2: bwd step
input_object_grad = self.backward_b_step(
model_chunk=model_chunk,
@@ -777,8 +828,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# communication
communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk)
if scheduled_node.type == "F":
elif scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,