Mirror of https://github.com/hpcaitech/ColossalAI.git
[fix] fix zbv llama pp4
@@ -226,7 +226,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # do nothing; cause u are chunk 0 in first rank, u have no prev rank;
                 #################
                 if self.stage_manager.is_first_stage(ignore_chunk=True):
-                    # return None, []
                     return []
 
                 ################
@@ -241,7 +240,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
                         self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
                     self.recv_forward_buffer[model_chunk_id].append(input_tensor)
-                    # return input_tensor, wait_handles
                     return wait_handles
 
             else:
@@ -265,7 +263,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
                         self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
                     self.recv_forward_buffer[model_chunk_id].append(input_tensor)
-                    # return input_tensor, wait_handles
                     return wait_handles
 
     def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
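
Note: all three recv_forward hunks converge on the same pattern: the received tensor goes into a per-chunk buffer and only the async wait handles are returned, so callers later pop the tensor from the buffer instead of threading it through the return value. Below is a minimal sketch of that pattern under simplified assumptions; the class and its `_fake_recv` helper are illustrative stand-ins, not ColossalAI API.

    from typing import Any, Dict, List

    import torch


    class RecvBufferSketch:
        """Illustrative only: buffer received tensors, hand back wait handles."""

        def __init__(self, num_chunks: int = 2, enable_metadata_cache: bool = True):
            self.enable_metadata_cache = enable_metadata_cache
            self.tensor_metadata_recv: Dict[int, Any] = {i: None for i in range(num_chunks)}
            self.recv_forward_buffer: Dict[int, List[torch.Tensor]] = {i: [] for i in range(num_chunks)}

        def _fake_recv(self):
            # Stand-in for an async P2P recv: returns the tensor and its wait handles.
            return torch.zeros(2, 2), []

        def recv_forward(self, model_chunk_id: int) -> List:
            input_tensor, wait_handles = self._fake_recv()
            # Cache shape/dtype metadata once so later receives can skip the metadata exchange.
            if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
                self.tensor_metadata_recv[model_chunk_id] = (input_tensor.shape, input_tensor.dtype)
            # Park the tensor for the matching forward step; return only the handles.
            self.recv_forward_buffer[model_chunk_id].append(input_tensor)
            return wait_handles
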
@@ -313,7 +310,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # do nothing; get loss from local
                 ################
                 if self.stage_manager.is_first_stage(ignore_chunk=True):
-                    # return None, []
                     return []
 
                 ################
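
Note: the early return in this hunk reflects the V-shaped chunk placement: chunk 0 walks down ranks 0..P-1 and chunk 1 walks back up P-1..0, so the model's final layers and the loss live on rank 0, which therefore has no gradient to receive for chunk 1. A small illustrative helper (not part of the scheduler) to make the placement concrete:

    def v_schedule_owner(num_stages: int):
        """Illustrative: list (chunk, rank) pairs in model order for a V-shaped placement."""
        # Chunk 0 walks down the pipeline, chunk 1 walks back up.
        return [(0, r) for r in range(num_stages)] + [(1, r) for r in reversed(range(num_stages))]


    if __name__ == "__main__":
        # With 4 ranks (the pp4 case in this fix), the last layers of chunk 1 land on rank 0,
        # so rank 0 computes the loss locally and its recv_backward for chunk 1 returns [].
        print(v_schedule_owner(4))  # [(0, 0), (0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (1, 1), (1, 0)]
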
@@ -328,7 +324,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
                         self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
                     self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
-                    # return output_tensor_grad, wait_handles
                     return wait_handles
 
     def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
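
Note: recv_backward mirrors recv_forward: gradients land in recv_backward_buffer and only wait handles come back, so a compute step first drains the handles and then pops its input from the buffer. A hedged usage sketch; the handle objects are whatever the comm layer returns (for example torch.distributed P2P work objects, which expose .wait()):

    def consume_recv(wait_handles: list, buffer: dict, model_chunk_id: int):
        """Illustrative: block on the async recv, then take the buffered tensor for this chunk."""
        for handle in wait_handles:
            handle.wait()
        return buffer[model_chunk_id].pop(0)
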
@@ -665,7 +660,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
-        # print(f"stage {self.stage_manager.stage}; model_chunk_id {model_chunk_id}; output_obj {output_obj};")
 
         # Step3:
         # 3-1:detach output; detach output for send fwd;
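
Note: for context around this hunk, the forward step threads `accum_loss` and `outputs` through to the stage that owns the criterion. A rough sketch of that contract; every name not visible in the diff (the function, `target`, the argument layout) is an assumption:

    import torch
    from torch import nn


    def forward_step_sketch(model_chunk: nn.Module, input_obj: torch.Tensor,
                            criterion, target: torch.Tensor,
                            accum_loss: torch.Tensor = None, outputs: list = None):
        """Illustrative only: run one chunk; on the loss-owning stage fold the loss into accum_loss."""
        output_obj = model_chunk(input_obj)
        if accum_loss is not None:
            # Loss-owning stage: add this microbatch's loss into the running accumulator.
            loss = criterion(output_obj, target)
            accum_loss.add_(loss.detach())
            if outputs is not None:
                outputs.append(output_obj.detach())
            return loss
        return output_obj
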
@@ -748,20 +742,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)
 
-        # # save output_tensor_grad for dw
-        # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-        #     # we save loss here
-        #     self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
-        # else:
-        #     # we save output_tensor_grad here
-        #     self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
-        # the_output_obj_grad = []
-        # if isinstance(output_obj, dict):
-        #     for (k, v) in output_obj.items():
-        #         the_output_obj_grad.append(v.requires_grad)
-        # else:
-        #     the_output_obj_grad.append(output_obj.requires_grad)
-
         input_object_grad = self.backward_b_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
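
Note: this hunk drops the commented-out bookkeeping because backward_b_step only needs the saved input/output pair and the incoming output gradient. The zero-bubble idea is to compute the input gradient (B) immediately, since the previous stage is waiting on it, and defer the weight gradient (W). A compressed sketch of that split using plain autograd; the real scheduler's signatures are not reproduced here and `weight_grad_queue` is a stand-in for the store:

    import torch


    def backward_b_sketch(input_obj: torch.Tensor, output_obj: torch.Tensor,
                          output_grad: torch.Tensor, weight_grad_queue: list) -> torch.Tensor:
        """Illustrative B step: propagate grad to the input now, queue weight-grad work for later."""
        # B: gradient w.r.t. the chunk input, needed immediately by the previous stage.
        (input_grad,) = torch.autograd.grad(
            output_obj, input_obj, grad_outputs=output_grad, retain_graph=True
        )
        # W: gradient w.r.t. the weights can wait; remember what is needed to compute it.
        weight_grad_queue.append((output_obj, output_grad))
        return input_grad
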
@@ -804,20 +784,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         Returns:
             Nothing.
         """
-
-        # get y & dy from buffer
-        # output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
-        # output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
         WeightGradStore.pop(chunk=model_chunk_id)
-
-        # self.backward_w_step(
-        #     model_chunk=model_chunk,
-        #     model_chunk_id=model_chunk_id,
-        #     optimizer=optimizer,
-        #     output_obj=output_obj,
-        #     output_obj_grad=output_obj_grad,
-        # )
 
     def run_forward_only(
         self,
         model_chunk: Union[ModuleList, Module],
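
Note: only `WeightGradStore.pop(chunk=model_chunk_id)` survives in the W step, which is enough because the store already holds the weight-gradient work deferred during the B steps. The diff shows only `pop`; the toy store below is an assumption about the general mechanism, not ColossalAI's implementation:

    from collections import defaultdict
    from typing import Callable, Dict, List


    class ToyWeightGradStore:
        """Illustrative: per-chunk queues of deferred weight-gradient closures."""

        _queues: Dict[int, List[Callable[[], None]]] = defaultdict(list)

        @classmethod
        def put(cls, chunk: int, fn: Callable[[], None]) -> None:
            # Queued during backward_b-style steps.
            cls._queues[chunk].append(fn)

        @classmethod
        def pop(cls, chunk: int) -> None:
            # Drained during backward_w-style steps, filling .grad on the chunk's parameters.
            while cls._queues[chunk]:
                cls._queues[chunk].pop(0)()
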
@@ -890,7 +858,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
-            # print(f"rank {torch.distributed.get_rank()}; stage {self.stage_manager.stage}; scheduled_node {scheduled_node};")
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
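
Note: the loop in this last hunk walks the precomputed per-stage schedule and dispatches each node: communication node types go through `communication_map`, everything else is a compute step. A stripped-down version of that dispatch; the node fields, the map contents, and the members of the type set are assumptions:

    from dataclasses import dataclass
    from typing import Callable, Dict, List

    # Assumed node types; the real schedule distinguishes more of them.
    AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "SEND_FORWARD", "RECV_BACKWARD", "SEND_BACKWARD"}


    @dataclass
    class ScheduledNode:
        type: str
        chunk: int


    def run_schedule(schedule: List[ScheduledNode],
                     communication_map: Dict[str, Callable[[int], list]],
                     compute_map: Dict[str, Callable[[int], None]]) -> None:
        """Illustrative dispatch loop: comm nodes return wait handles, compute nodes run in place."""
        for node in schedule:
            if node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                communication_map[node.type](node.chunk)  # e.g. recv_forward / send_forward
            else:
                compute_map[node.type](node.chunk)  # e.g. F, B, W steps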