[feat] add test run_fwd_bwd automatic scheduling;

duanjunwen
2024-08-26 11:21:56 +00:00
parent fd5526b76e
commit 1d75045c37
4 changed files with 259 additions and 48 deletions


@@ -176,6 +176,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; this is chunk 0 on the first rank, so there is no previous rank to receive from;
#################
if self.stage_manager.is_first_stage(ignore_chunk=True):
+ self.recv_forward_buffer[model_chunk_id].append(None)
return None, []
################
@@ -188,6 +189,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# metadata_recv=self.tensor_metadata_recv
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
+ self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles
else:
@@ -200,7 +202,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
+ self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, []
################
@@ -214,7 +216,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# metadata_recv=self.tensor_metadata_recv
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
+ self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles
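The hunks above all apply one pattern: recv_forward no longer just returns the received tensor, it also appends it to recv_forward_buffer[model_chunk_id], so a later F node can pop its input without caring whether it came from P2P, a local hand-off, or a None placeholder on the first stage. A minimal sketch of that buffered-receive idea, assuming a stubbed comm layer (the class name and string payloads are illustrative, not the real ColossalAI P2P API):

```python
from collections import deque
from typing import Any, List, Tuple

class BufferedRecvSketch:
    """Illustrative only: mirrors the recv-into-buffer pattern from the diff."""

    def __init__(self, num_chunks: int = 2):
        # one FIFO per model chunk, like self.recv_forward_buffer
        self.recv_forward_buffer = [deque() for _ in range(num_chunks)]

    def recv_forward(self, model_chunk_id: int, is_first_stage: bool) -> Tuple[Any, List]:
        if is_first_stage and model_chunk_id == 0:
            # no previous rank: enqueue a placeholder so later pops stay aligned
            self.recv_forward_buffer[model_chunk_id].append(None)
            return None, []
        input_tensor = f"tensor_for_chunk_{model_chunk_id}"  # stub for comm.recv_forward
        self.recv_forward_buffer[model_chunk_id].append(input_tensor)
        return input_tensor, []  # real code also returns P2P wait handles
```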
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
@@ -240,6 +242,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_tensor_grad = self.local_send_backward_buffer.pop(0)
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+ self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, []
################
@@ -252,6 +255,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# metadata_recv=self.grad_metadata_recv
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+ self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles
else:
@@ -261,6 +265,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; the gradient comes from the local loss on this stage
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
+ self.recv_backward_buffer[model_chunk_id].append(None)
return None, []
################
@@ -268,16 +273,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# self.comm.recv_backward receives the gradient from the prev stage;
################
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} output_tensor_grad {output_tensor_grad};\n buffer {self.recv_backward_buffer}")
# metadata_recv=self.grad_metadata_recv
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles
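recv_backward gets the same buffering treatment, with one wrinkle: in the V-shaped schedule, chunk 0 on the last stage takes its gradient from local_send_backward_buffer (the fold of the V) rather than from P2P, and chunk 1 on the first stage appends a None placeholder because its gradient originates from the local loss. A rough sketch of the three cases, with the is_first/is_last flags standing in for the stage-manager queries:

```python
from typing import Any, List, Tuple

class RecvBackwardSketch:
    """Illustrative only: the three gradient sources in ZBV's recv_backward."""

    def __init__(self):
        self.local_send_backward_buffer: List[Any] = []
        self.recv_backward_buffer: List[List[Any]] = [[], []]  # one list per chunk

    def recv_backward(self, chunk: int, is_first: bool, is_last: bool) -> Tuple[Any, List]:
        if chunk == 0 and is_last:
            # fold of the V: chunk 1 on this same rank handed dy over locally
            grad = self.local_send_backward_buffer.pop(0)
        elif chunk == 1 and is_first:
            # backward starts from the local loss; a None placeholder keeps
            # later pop(0) calls aligned with the schedule
            grad = None
        else:
            grad = f"dy_for_chunk_{chunk}"  # stub for self.comm.recv_backward(...)
        self.recv_backward_buffer[chunk].append(grad)
        return grad, []  # real code also returns P2P wait handles
```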
- def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
+ def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
"""Sends the input tensor to the next stage in pipeline.
For ZBV.
@@ -291,6 +296,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+ output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
if model_chunk_id == 0:
################
# chunk = 0 && is_last_stage
@@ -330,7 +336,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# self.send_tensor_metadata = not self.enable_metadata_cache
return send_handles
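Dropping output_tensor from send_forward's signature is what lets a bare SEND_FORWARD node drive the send: schedule_f deposits its output in send_forward_buffer, and the communication node later pops and ships it knowing only the chunk id. A hypothetical sketch of that producer/consumer split (buffer names from the diff, the actual P2P call stubbed out); send_backward below gets the same treatment:

```python
from typing import Any, List

class SendForwardSketch:
    """Illustrative only: the producer/consumer split behind the new signature."""

    def __init__(self):
        self.send_forward_buffer: List[List[Any]] = [[], []]  # one list per chunk

    def deposit_from_schedule_f(self, chunk: int, output_obj: Any) -> None:
        # producer side: the F node leaves its output in the buffer
        self.send_forward_buffer[chunk].append(output_obj)

    def send_forward(self, chunk: int) -> List:
        # consumer side: the SEND_FORWARD node needs only the chunk id
        output_tensor = self.send_forward_buffer[chunk].pop(0)
        print(f"chunk {chunk}: sending {output_tensor!r}")  # stands in for the P2P send
        return []  # real code returns wait handles
```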
- def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
+ def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
For ZBV.
@@ -359,6 +365,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Send dx to PREV stage;
################
else:
+ input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
prev_rank = self.stage_manager.get_prev_rank()
send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
# send_metadata=self.send_grad_metadata
@@ -371,6 +378,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# hold dy to local_send_bwd_buffer;
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
+ input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
self.local_send_backward_buffer.append(input_tensor_grad)
return []
@@ -379,6 +387,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Send dx to NEXT stage;
################
else:
+ print(
+ f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} send_backward_buffer {self.send_backward_buffer}"
+ )
+ input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
next_rank = self.stage_manager.get_next_rank()
# print(f"send bwd input_tensor_grad {input_tensor_grad}")
send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
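Note the asymmetry these hunks preserve: chunk 0 sends dx to the previous rank, chunk 1 sends dx to the next rank, and on the last stage chunk 1 hands dy to chunk 0 through local_send_backward_buffer without any P2P at all. That is the V shape folding back on itself. A hedged sketch of the peer selection, using plain rank arithmetic where the real code asks the stage manager:

```python
from typing import Optional

def backward_peer(chunk: int, rank: int, world_size: int) -> Optional[int]:
    """Who receives this chunk's dx? None means the hand-off is local or the
    backward pass terminates here (illustrative rank arithmetic only)."""
    if chunk == 0:
        # chunk 0's backward flows toward rank 0; rank 0 has no one to send to
        return rank - 1 if rank > 0 else None
    # chunk 1's backward flows toward the last rank, where the V folds and
    # dy is passed to chunk 0 via local_send_backward_buffer
    return rank + 1 if rank < world_size - 1 else None
```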
@@ -413,6 +425,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
# fwd calculate
output_obj = model_chunk[model_chunk_id](input_obj)
# last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -463,6 +476,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# common bwd step
# print(f"bwd output_obj {output_obj} output_obj_grad {output_obj_grad} input_obj {input_obj}")
# BUG: output_obj_grad is None
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; tensor {output_obj};\n grad_tensors {output_obj_grad};\n inputs {input_obj}\n")
torch.autograd.backward(
tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
)
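The retain_graph=True in this backward_b_step call is the heart of the zero-bubble split: the B step propagates gradients only to the chunk's input (so the neighbouring stage can proceed) while the graph is kept alive for a separate W step that fills pipeline bubbles later by computing the weight gradients. A self-contained PyTorch sketch of the same two-phase backward, with a toy Linear layer standing in for a model chunk:

```python
import torch

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
y = lin(x)
dy = torch.randn_like(y)

# B step: grads flow only to the activation input; retain_graph keeps the
# graph alive for the deferred W step
torch.autograd.backward(y, grad_tensors=dy, inputs=[x], retain_graph=True)
dx = x.grad  # this is what send_backward ships to the neighbouring stage

# W step (scheduled later, inside a bubble): accumulate parameter grads only
torch.autograd.backward(y, grad_tensors=dy, inputs=list(lin.parameters()))
print(dx.shape, lin.weight.grad.shape)  # torch.Size([2, 4]) torch.Size([4, 4])
```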
@@ -505,14 +519,21 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
outputs: Optional[List[Any]] = None,
):
# Step1: recv fwd
+ # if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
+ # # first layer
+ # input_obj = input_obj
+ # else:
+ # # other layer
+ # input_obj, wait_handles = self.recv_forward(model_chunk_id)
+ # # print(f"recv input_obj {input_obj}")
+ # _wait_p2p(wait_handles)
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
# first layer
input_obj = input_obj
+ self.recv_forward_buffer[model_chunk_id].pop(0) # pop none
else:
# other layer
- input_obj, wait_handles = self.recv_forward(model_chunk_id)
- # print(f"recv input_obj {input_obj}")
- _wait_p2p(wait_handles)
+ input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
# Step2: fwd step
output_obj = self.forward_step(
model_chunk=model_chunk,
@@ -522,6 +543,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss=accum_loss,
outputs=outputs,
)
# print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")
# add input and output object for backward b
@@ -532,7 +554,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self.output_tensors_dw[model_chunk_id].append(output_obj)
# Step3: send fwd
- send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
+ # add output to send_fwd_buffer
+ self.send_forward_buffer[model_chunk_id].append(output_obj)
+ # send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
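With both ends buffered, schedule_f collapses to pop input, run forward, push output; all P2P waits live in the communication nodes. A condensed sketch of its new shape (buffer names from the diff; the stage test, micro-batch handling, and forward_step signature are simplified assumptions):

```python
from typing import Any, Callable, List

def schedule_f_sketch(
    recv_forward_buffer: List[List[Any]],
    send_forward_buffer: List[List[Any]],
    chunk: int,
    is_first_stage: bool,
    micro_batch: Any,
    forward_step: Callable[[Any], Any],
) -> None:
    # Step 1: recv fwd -- the RECV_FORWARD node already filled the buffer
    if chunk == 0 and is_first_stage:
        recv_forward_buffer[chunk].pop(0)   # discard the None placeholder
        input_obj = micro_batch             # the first layer reads the micro-batch
    else:
        input_obj = recv_forward_buffer[chunk].pop(0)
    # Step 2: fwd step (the real forward_step also stashes tensors for B and W)
    output_obj = forward_step(input_obj)
    # Step 3: send fwd -- leave the output for the SEND_FORWARD node
    send_forward_buffer[chunk].append(output_obj)
```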
def schedule_b(
self,
@@ -545,17 +569,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# output_obj_grad: Optional[dict],
):
# Step1: recv bwd
- # not first stage and chunk 1
- if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
- output_tensor_grad, recv_bwd_handles = None, []
- # print(f"recv output_tensor_grad {output_tensor_grad}")
- else:
- output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
- # print(f"recv output_tensor_grad {output_tensor_grad}")
+ # # not first stage and chunk 1
+ # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+ # output_tensor_grad, recv_bwd_handles = None, []
+ # # print(f"recv output_tensor_grad {output_tensor_grad}")
+ # else:
+ # output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
+ # # print(f"recv output_tensor_grad {output_tensor_grad}")
+ output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n")
# get input and output object from buffer;
- input_obj = self.input_tensors[model_chunk_id].pop()
- output_obj = self.output_tensors[model_chunk_id].pop()
+ input_obj = self.input_tensors[model_chunk_id].pop(0)
+ output_obj = self.output_tensors[model_chunk_id].pop(0)
# save output_tensor_grad for dw
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -565,9 +592,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# we save output_tensor_grad here
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
- _wait_p2p(recv_bwd_handles)
+ # _wait_p2p(recv_bwd_handles)
# print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
# Step2: bwd step
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}")
input_object_grad = self.backward_b_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
@@ -576,23 +606,23 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj=output_obj,
output_obj_grad=output_tensor_grad,
)
print(f"input_object_grad {input_object_grad}")
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}")
# Step3: send bwd
- send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
+ # send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
+ self.send_backward_buffer[model_chunk_id].append(input_object_grad)
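schedule_b mirrors schedule_f: the gradient is popped from recv_backward_buffer, the matching activation pair is popped from the front of input_tensors/output_tensors, and the resulting dx is left in send_backward_buffer. The switch from pop() to pop(0) here (and in schedule_w just below) matters: under the interleaved schedule, B and W steps retire micro-batches in the same order their F steps produced them, so these buffers are FIFOs, not stacks. A condensed sketch under the same simplifying assumptions as the schedule_f sketch above:

```python
from typing import Any, Callable, List

def schedule_b_sketch(
    recv_backward_buffer: List[List[Any]],
    input_tensors: List[List[Any]],
    output_tensors: List[List[Any]],
    send_backward_buffer: List[List[Any]],
    chunk: int,
    backward_b_step: Callable[[Any, Any, Any], Any],
) -> None:
    # Step 1: recv bwd -- dy (or None for the local-loss case) is already buffered
    output_tensor_grad = recv_backward_buffer[chunk].pop(0)
    # FIFO, not LIFO: the oldest micro-batch is the one whose B step runs now
    input_obj = input_tensors[chunk].pop(0)
    output_obj = output_tensors[chunk].pop(0)
    # Step 2: bwd step -- B computes dx only; W is deferred to schedule_w
    input_object_grad = backward_b_step(input_obj, output_obj, output_tensor_grad)
    # Step 3: send bwd -- leave dx for the SEND_BACKWARD node
    send_backward_buffer[chunk].append(input_object_grad)
```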
def schedule_w(
self,
scheduled_node,
non_w_pending,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
# optimizer: OptimizerWrapper,
):
# get y & dy from buffer
- output_obj = self.output_tensors_dw[model_chunk_id].pop()
- output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop()
+ output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
+ output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
self.backward_w_step(
model_chunk=model_chunk,
@@ -605,6 +635,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
def run_forward_backward(
self,
model_chunk: Union[ModuleList, Module],
+ input_obj: Optional[dict],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
@@ -615,19 +646,37 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# while we still have schedules_node in self.schedules
while it < len(self.schedules):
scheduled_node = self.schedules[it]
print(f"it {it}; scheduled_node {scheduled_node};")
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
if scheduled_node.type == "RECV_FORWARD":
- self.recv_forward()
+ self.recv_forward(scheduled_node.chunk)
elif scheduled_node.type == "RECV_BACKWARD":
- self.recv_backward()
+ self.recv_backward(scheduled_node.chunk)
elif scheduled_node.type == "SEND_FORWARD":
- self.send_forward()
+ self.send_forward(scheduled_node.chunk)
elif scheduled_node.type == "SEND_BACKWARD":
- self.send_backward()
- elif scheduled_node.type == "F":
- self.schedule_f()
+ self.send_backward(scheduled_node.chunk)
+ if scheduled_node.type == "F":
+ self.schedule_f(
+ scheduled_node=scheduled_node,
+ model_chunk=model_chunk,
+ model_chunk_id=scheduled_node.chunk,
+ input_obj=input_obj,
+ criterion=criterion,
+ accum_loss=return_loss,
+ outputs=return_outputs,
+ )
elif scheduled_node.type == "B":
- self.schedule_b()
+ self.schedule_b(
+ scheduled_node=scheduled_node,
+ model_chunk=model_chunk,
+ model_chunk_id=scheduled_node.chunk,
+ )
elif scheduled_node.type == "W":
- self.schedule_w()
+ self.schedule_w(
+ scheduled_node=scheduled_node,
+ model_chunk=model_chunk,
+ model_chunk_id=scheduled_node.chunk,
+ )
it += 1
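After these changes the main loop is a pure interpreter over the precomputed schedule: every node carries a type and a chunk, communication nodes move tensors between the buffers and the wire, and compute nodes touch only the buffers. A minimal model of that dispatch, with a hypothetical ScheduledNode limited to the two fields the loop actually reads (the real node type in the repo carries more scheduling metadata):

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

@dataclass
class ScheduledNode:
    type: str   # "RECV_FORWARD", "SEND_FORWARD", "RECV_BACKWARD",
                # "SEND_BACKWARD", "F", "B", or "W"
    chunk: int  # model chunk the node operates on

def run_schedule(schedules: List[ScheduledNode],
                 handlers: Dict[str, Callable[[int], None]]) -> None:
    """Interpret the schedule: each node type maps to one scheduler method,
    parameterized only by the chunk id (a simplification of run_forward_backward)."""
    for node in schedules:
        handlers[node.type](node.chunk)

# usage sketch: handlers = {"RECV_FORWARD": sched.recv_forward, "F": run_f, ...}
```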