mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[feat] support no_tp Linear for sharderformer.llama
This commit is contained in:
@@ -64,10 +64,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
|
||||
# P2PMeta cache
|
||||
self.enable_metadata_cache = enable_metadata_cache
|
||||
self.send_tensor_metadata = True
|
||||
self.send_grad_metadata = True
|
||||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
self.send_tensor_metadata = [True, True]
|
||||
self.send_grad_metadata = [True, True]
|
||||
# meta cache buffer
|
||||
self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta]
|
||||
self.grad_metadata_recv = [None, None]
|
||||
|
||||
# P2P communication
|
||||
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
||||
@@ -235,10 +236,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
input_tensor, wait_handles = self.comm.recv_forward(
|
||||
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv
|
||||
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
|
||||
)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
||||
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||
# return input_tensor, wait_handles
|
||||
return wait_handles
|
||||
@@ -259,10 +260,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
input_tensor, wait_handles = self.comm.recv_forward(
|
||||
next_rank, metadata_recv=self.tensor_metadata_recv
|
||||
next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
|
||||
)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
||||
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||
# return input_tensor, wait_handles
|
||||
return wait_handles
|
||||
@@ -297,10 +298,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
||||
next_rank, metadata_recv=self.grad_metadata_recv
|
||||
next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
|
||||
)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
# return output_tensor_grad, wait_handles
|
||||
return wait_handles
|
||||
@@ -322,10 +323,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
||||
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv
|
||||
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
|
||||
)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
# return output_tensor_grad, wait_handles
|
||||
return wait_handles
|
||||
@@ -359,9 +360,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_forward(
|
||||
output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata
|
||||
output_object=output_tensor,
|
||||
next_rank=next_rank,
|
||||
send_metadata=self.send_tensor_metadata[model_chunk_id],
|
||||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
else:
|
||||
@@ -380,9 +383,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_forward(
|
||||
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata
|
||||
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id]
|
||||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
|
||||
@@ -415,9 +418,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_backward(
|
||||
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
|
||||
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
# bwd chunk1 is left V;
|
||||
@@ -437,9 +440,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_backward(
|
||||
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata
|
||||
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
def forward_step(
|
||||
@@ -662,6 +665,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
accum_loss=accum_loss,
|
||||
outputs=outputs,
|
||||
)
|
||||
# print(f"stage {self.stage_manager.stage}; model_chunk_id {model_chunk_id}; output_obj {output_obj};")
|
||||
|
||||
# Step3:
|
||||
# 3-1:detach output; detach output for send fwd;
|
||||
@@ -886,6 +890,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
|
||||
for it in range(len(schedule)):
|
||||
scheduled_node = schedule[it]
|
||||
# print(f"rank {torch.distributed.get_rank()}; stage {self.stage_manager.stage}; scheduled_node {scheduled_node};")
|
||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||
# communication
|
||||
communication_func = self.communication_map[scheduled_node.type]
|
||||
|
Reference in New Issue
Block a user