[feat] split communication and calculation; fix pop empty send_bwd_buffer error;

duanjunwen 2024-08-27 06:29:13 +00:00
parent 1d75045c37
commit 5e09c8b4e1
2 changed files with 75 additions and 85 deletions
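The core idea of the change, as visible in the diff below: recv_forward/recv_backward/send_forward/send_backward now only perform p2p communication against the recv_*/send_* buffers, while schedule_f/schedule_b pick their inputs from and park their outputs in those buffers (or in local_send_forward_buffer / local_send_backward_buffer when chunk 0 hands data to chunk 1 on the same rank), so the compute path no longer pops from an empty send_backward_buffer. The following is a minimal illustrative sketch of that buffer routing, not the ColossalAI implementation; the class name BufferRoutingSketch and the is_first_stage/is_last_stage flags are invented for the example.

from collections import deque


class BufferRoutingSketch:
    """Hypothetical stand-in for the scheduler's buffer handling (illustration only)."""

    def __init__(self, num_chunks: int = 2):
        # per-chunk p2p buffers, filled/drained by the communication nodes
        self.recv_forward_buffer = [deque() for _ in range(num_chunks)]
        self.send_forward_buffer = [deque() for _ in range(num_chunks)]
        # rank-local hand-off buffers: V turn on the last stage, loss on the first stage
        self.local_send_forward_buffer = deque()
        self.local_send_backward_buffer = deque()

    def schedule_f_recv(self, model_chunk_id, is_first_stage, is_last_stage, input_obj):
        # compute side of schedule_f: choose the input, no p2p here
        if model_chunk_id == 0:
            if is_first_stage:
                return input_obj  # chunk 0 on the first stage reads the micro-batch directly
            return self.recv_forward_buffer[model_chunk_id].popleft()
        if is_last_stage:
            return self.local_send_forward_buffer.popleft()  # V turn: reuse chunk 0's local output
        return self.recv_forward_buffer[model_chunk_id].popleft()

    def schedule_f_send(self, model_chunk_id, is_first_stage, is_last_stage, output_obj):
        # compute side of schedule_f: park the output; the SEND_FORWARD node does the p2p later
        if model_chunk_id == 0:
            if is_last_stage:
                self.local_send_forward_buffer.append(output_obj)
            else:
                self.send_forward_buffer[model_chunk_id].append(output_obj)
        else:
            if is_first_stage:
                self.local_send_backward_buffer.append(output_obj)  # loss stays local for bwd
            else:
                self.send_forward_buffer[model_chunk_id].append(output_obj)

deque is used here only to make the FIFO intent explicit; the actual scheduler in the diff uses plain lists with pop(0).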

View File

@@ -176,7 +176,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # do nothing; cause u are chunk 0 in first rank, u have no prev rank;
                 #################
                 if self.stage_manager.is_first_stage(ignore_chunk=True):
-                    self.recv_forward_buffer[model_chunk_id].append(None)
                     return None, []
                 ################
@@ -186,24 +185,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 else:
                     prev_rank = self.stage_manager.get_prev_rank()
                     input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank)
-                    # metadata_recv=self.tensor_metadata_recv
-                    # if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                    #     self.tensor_metadata_recv = create_send_metadata(input_tensor)
                     self.recv_forward_buffer[model_chunk_id].append(input_tensor)
                     return input_tensor, wait_handles
             else:
                 ################
                 # chunk = 1 & is_last_stage
-                # get y from local_send_forward_buffer as input
+                # do nothing; cause u get y from local_send_forward_buffer in schedule f
                 ################
                 if self.stage_manager.is_last_stage(ignore_chunk=True):
-                    input_tensor = self.local_send_forward_buffer.pop(0)
-                    # if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                    #     self.tensor_metadata_recv = create_send_metadata(input_tensor)
-                    self.recv_forward_buffer[model_chunk_id].append(input_tensor)
-                    return input_tensor, []
+                    return None, []
                 ################
                 # chunk = 1 & not is_last_stage
@@ -212,10 +203,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 else:
                     next_rank = self.stage_manager.get_next_rank()
                     input_tensor, wait_handles = self.comm.recv_forward(next_rank)
-                    # metadata_recv=self.tensor_metadata_recv
-                    # if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                    #     self.tensor_metadata_recv = create_send_metadata(input_tensor)
                     self.recv_forward_buffer[model_chunk_id].append(input_tensor)
                     return input_tensor, wait_handles
@@ -236,14 +223,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # bwd chunk0 is right V;
                 ################
                 # chunk = 0 & is_last_stage
-                # get dy from local recv_bwd_buffer
+                # do nothing; Already get dy from local_send_backward_buffer in schedule b
                 ################
                 if self.stage_manager.is_last_stage(ignore_chunk=True):
-                    output_tensor_grad = self.local_send_backward_buffer.pop(0)
-                    # if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                    #     self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
-                    self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
-                    return output_tensor_grad, []
+                    return None, []
                 ################
                 # chunk = 0 & not is_last_stage
@@ -252,9 +235,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 else:
                     next_rank = self.stage_manager.get_next_rank()
                     output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank)
-                    # metadata_recv=self.grad_metadata_recv
-                    # if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                    #     self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
                     self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
                     return output_tensor_grad, wait_handles
@@ -265,20 +245,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # do nothing; get loss from local
                 ################
                 if self.stage_manager.is_first_stage(ignore_chunk=True):
-                    self.recv_backward_buffer[model_chunk_id].append(None)
                     return None, []
                 ################
-                # chunk = 1 & not is_first_stage
-                # self.comm.recv_backward recv bwd from prev stage;
+                # chunk = 1 & not first stage
+                # recv_backward recv bwd from prev stage;
                 ################
                 else:
                     prev_rank = self.stage_manager.get_prev_rank()
                     output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)
-                    # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} output_tensor_grad {output_tensor_grad};\n buffer {self.recv_backward_buffer}")
-                    # metadata_recv=self.grad_metadata_recv
-                    # if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                    #     self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
                     self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
                     return output_tensor_grad, wait_handles
@@ -296,14 +271,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
-            output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
             if model_chunk_id == 0:
                 ################
                 # chunk = 0 && is_last_stage
-                # hold y on local_send_forward_buffer
+                # do nothing; hold y on local_send_forward_buffer
                 ################
                 if self.stage_manager.is_last_stage(ignore_chunk=True):
-                    self.local_send_forward_buffer.append(output_tensor)
                     return []
                 ################
@@ -312,15 +285,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 ################
                 else:
                     next_rank = self.stage_manager.get_next_rank()
+                    output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
                     send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank)
-                    # send_metadata=self.send_tensor_metadata
-                    # self.send_tensor_metadata = not self.enable_metadata_cache
                     return send_handles
             else:
                 ################
                 # chunk = 1 && is_first_stage
-                # do nothing; cause you are the last chunk on last stage;
+                # do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part
                 ################
                 if self.stage_manager.is_first_stage(ignore_chunk=True):
                     return []
@@ -331,9 +303,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 ################
                 else:
                     prev_rank = self.stage_manager.get_prev_rank()
+                    output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
                     send_handles = self.comm.send_forward(output_tensor, prev_rank)
-                    # send_metadata=self.send_tensor_metadata
-                    # self.send_tensor_metadata = not self.enable_metadata_cache
                     return send_handles

     def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
@@ -355,7 +326,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 ################
                 # chunk = 0 && is_first_stage
                 # do nothing; cause u are the first chunk in first stage; bwd end
-                # send input_tensor_grad to local buffer;
                 ################
                 if self.stage_manager.is_first_stage(ignore_chunk=True):
                     return []
@@ -365,21 +335,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # Send dx to PREV stage;
                 ################
                 else:
-                    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
                     prev_rank = self.stage_manager.get_prev_rank()
+                    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
                     send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
-                    # send_metadata=self.send_grad_metadata
                     return send_handles
             # bwd chunk1 is left V;
             else:
-                # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} self.send_backward_buffer {self.send_backward_buffer}")
                 ################
                 # chunk = 1 && is_last_stage
-                # hold dy to local_send_bwd_buffer;
+                # do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
                 ################
                 if self.stage_manager.is_last_stage(ignore_chunk=True):
-                    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
-                    self.local_send_backward_buffer.append(input_tensor_grad)
                     return []
                 ################
@@ -387,14 +355,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # Send dx to NEXT stage;
                 ################
                 else:
-                    print(
-                        f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} send_backward_buffer {self.send_backward_buffer}"
-                    )
-                    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
                     next_rank = self.stage_manager.get_next_rank()
-                    # print(f"send bwd input_tensor_grad {input_tensor_grad}")
+                    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
                     send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
-                    # send_metadata=self.send_grad_metadata
                     return send_handles

     def forward_step(
@@ -519,18 +482,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         outputs: Optional[List[Any]] = None,
     ):
         # Step1: recv fwd
-        # if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-        #     # first layer
-        #     input_obj = input_obj
-        # else:
-        #     # other layer
-        #     input_obj, wait_handles = self.recv_forward(model_chunk_id)
-        #     # print(f"recv input_obj {input_obj}")
-        #     _wait_p2p(wait_handles)
-        if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            input_obj = input_obj
-            self.recv_forward_buffer[model_chunk_id].pop(0) # pop none
-        else:
-            input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
+        if model_chunk_id == 0:
+            # is first stage; get input from func param
+            if self.stage_manager.is_first_stage(ignore_chunk=True):
+                input_obj = input_obj
+            else:
+                input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
+        else:
+            # is last stage; recv from local
+            if self.stage_manager.is_last_stage(ignore_chunk=True):
+                input_obj = self.local_send_forward_buffer.pop(0)
+            # not last stage; recv from next
+            else:
+                input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
@@ -555,8 +518,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Step3: send fwd
         # add output to send_fwd_buffer
-        self.send_forward_buffer[model_chunk_id].append(output_obj)
-        # send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
+        if model_chunk_id == 0:
+            # is last stage; send to local_send_forward_buffer
+            if self.stage_manager.is_last_stage(ignore_chunk=True):
+                self.local_send_forward_buffer.append(output_obj)
+            else:
+                self.send_forward_buffer[model_chunk_id].append(output_obj)
+        else:
+            # is first stage; end of fwd; append LOSS to local_send_backward_buffer
+            if self.stage_manager.is_first_stage(ignore_chunk=True):
+                self.local_send_backward_buffer.append(output_obj)
+            else:
+                self.send_forward_buffer[model_chunk_id].append(output_obj)

     def schedule_b(
         self,
@@ -569,13 +542,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # output_obj_grad: Optional[dict],
     ):
         # Step1: recv bwd
-        # # not first stage and chunk 1
-        # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-        #     output_tensor_grad, recv_bwd_handles = None, []
-        #     # print(f"recv output_tensor_grad {output_tensor_grad}")
-        # else:
-        #     output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
-        #     # print(f"recv output_tensor_grad {output_tensor_grad}")
-        output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
+        if model_chunk_id == 0:
+            # chunk0 is last stage; recv output_grad from local_send_backward_buffer
+            if self.stage_manager.is_last_stage(ignore_chunk=True):
+                output_tensor_grad = self.local_send_backward_buffer.pop(0)
+            # chunk 0 not last stage; recv output_grad from recv_backward_buffer
+            else:
+                output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
+        else:
+            # chunk1, is first stage; recv LOSS from local send bwd buffer
+            if self.stage_manager.is_first_stage(ignore_chunk=True):
+                output_tensor_grad = self.local_send_backward_buffer.pop(0)
+            # chunk1, not first stage; recv output_grad from recv_backward_buffer
+            else:
+                output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
         # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n")
@@ -593,11 +572,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
         # _wait_p2p(recv_bwd_handles)
-        # print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
         # Step2: bwd step
-        # print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}")
         input_object_grad = self.backward_b_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
@@ -609,7 +584,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}")
         # Step3: send bwd
-        # send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
-        self.send_backward_buffer[model_chunk_id].append(input_object_grad)
+        if model_chunk_id == 0:
+            # do nothing; end of bwd;
+            if self.stage_manager.is_first_stage(ignore_chunk=True):
+                pass
+            # save input_object_grad to send_backward_buffer
+            else:
+                self.send_backward_buffer[model_chunk_id].append(input_object_grad)
+        else:
+            # send to local_send_backward_buffer
+            if self.stage_manager.is_last_stage(ignore_chunk=True):
+                self.local_send_backward_buffer.append(input_object_grad)
+            # send to next
+            else:
+                self.send_backward_buffer[model_chunk_id].append(input_object_grad)

     def schedule_w(
@@ -644,9 +631,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
     ):
         it = self.it
         # while we still have schedules_node in self.schedules
-        # print(f"manger_stage {self.stage_manager.stage} schedule {self.schedules} \n")
         while it < len(self.schedules):
             scheduled_node = self.schedules[it]
-            print(f"it {it}; scheduled_node {scheduled_node};")
+            print(
+                f"it {it}; manger_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};"
+            )
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
                 if scheduled_node.type == "RECV_FORWARD":
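For reference, the dispatch in run_forward_backward (last hunk above) is what keeps the split clean: nodes whose type is in AUTO_SCHEDULE_COMMUNICATION_TYPES (communication types such as RECV_FORWARD, SEND_FORWARD, RECV_BACKWARD, SEND_BACKWARD) only call the p2p methods, while F/B/W nodes only run the compute steps against the buffers filled earlier. A rough sketch of that loop follows; Node, run_schedule_sketch, comm_ops and compute_ops are made-up names, not the real API.

from collections import namedtuple

Node = namedtuple("Node", ["type", "chunk"])
COMM_TYPES = {"RECV_FORWARD", "SEND_FORWARD", "RECV_BACKWARD", "SEND_BACKWARD"}


def run_schedule_sketch(schedules, comm_ops, compute_ops):
    # walk the static schedule; communication and calculation never share a branch
    for node in schedules:
        if node.type in COMM_TYPES:
            comm_ops[node.type](model_chunk_id=node.chunk)      # p2p only
        else:
            compute_ops[node.type](model_chunk_id=node.chunk)   # F/B/W compute only


# tiny usage example: print which side of the split each node lands on
comm_ops = {t: (lambda model_chunk_id, t=t: print("comm", t, model_chunk_id)) for t in COMM_TYPES}
compute_ops = {t: (lambda model_chunk_id, t=t: print("calc", t, model_chunk_id)) for t in ("F", "B", "W")}
run_schedule_sketch([Node("RECV_FORWARD", 0), Node("F", 0), Node("SEND_FORWARD", 0)], comm_ops, compute_ops)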

View File

@@ -486,7 +486,7 @@ def test_run_fwd_bwd_base(
             ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0),
             ScheduledNode(type="B", chunk=0, stage=1, minibatch=0),
             ScheduledNode(type="W", chunk=0, stage=1, minibatch=0),
-            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
         ],
         # stage 2
         [
@@ -547,7 +547,7 @@ def test_run_fwd_bwd_base(
     # init model and input
     num_layers = 8
     in_dim = out_dim = 8
-    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
+    # print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
     input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
@@ -578,9 +578,9 @@ def test_run_fwd_bwd_base(
     for idx, sub_model in enumerate(model.layers):
        if idx == 3 or idx == 4:
            local_chunk.append(sub_model)
-    print(
-        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
-    )
+    # print(
+    #     f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
+    # )
     torch.cuda.synchronize()
     scheduler.run_forward_backward(