mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[fix] fix send_tensor_metadata & send_grad_metadata;
This commit is contained in:
@@ -64,8 +64,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
|
||||
# P2PMeta cache
|
||||
self.enable_metadata_cache = enable_metadata_cache
|
||||
self.send_tensor_metadata = [True, True]
|
||||
self.send_grad_metadata = [True, True]
|
||||
|
||||
# check send_tensor_metadata, send_grad_metadata
|
||||
# pp4 as sample, we should follow this meta strategy
|
||||
# send_tensor_meta(fwd) send_grad_meta(bwd)
|
||||
# chunk0 | chunk1 chunk0 | chunk 1
|
||||
# stage 0 T | F F | T
|
||||
# stage 1 T | T T | T
|
||||
# stage 2 T | T T | T
|
||||
# stage 3 F | T F | T
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.send_tensor_metadata = [True, False]
|
||||
self.send_grad_metadata = [False, True]
|
||||
elif stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.send_tensor_metadata = [False, True]
|
||||
self.send_grad_metadata = [True, False]
|
||||
else:
|
||||
self.send_tensor_metadata = [True, True]
|
||||
self.send_grad_metadata = [True, True]
|
||||
|
||||
# meta cache buffer
|
||||
self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta]
|
||||
self.grad_metadata_recv = [None, None]
|
||||
@@ -84,6 +101,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# init buffer
|
||||
self._free_buffers()
|
||||
|
||||
def _set_send_metadata_buffers(self, model_chunk_id):
|
||||
pass
|
||||
|
||||
def _free_buffers(self):
|
||||
# free local buffer
|
||||
# two dim array, first dim is the model chunk, second dim is the microbatch queue
|
||||
@@ -285,7 +305,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# do nothing; Already get dy from local_send_backward_buffer in schedule b
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
# return None, []
|
||||
return []
|
||||
|
||||
################
|
||||
@@ -300,7 +319,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
# return output_tensor_grad, wait_handles
|
||||
return wait_handles
|
||||
|
||||
else:
|
||||
@@ -345,6 +363,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# do nothing; hold y on local_send_forward_buffer
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return []
|
||||
|
||||
################
|
||||
@@ -368,6 +387,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part
|
||||
################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return []
|
||||
|
||||
################
|
||||
@@ -403,6 +423,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# do nothing; cause u are the first chunk in first stage; bwd end
|
||||
################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return []
|
||||
|
||||
################
|
||||
@@ -425,6 +446,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return []
|
||||
|
||||
################
|
||||
@@ -889,7 +911,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
for h in self.wait_handles:
|
||||
for hh in h:
|
||||
hh.wait()
|
||||
|
||||
# print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
|
||||
# return loss & output
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
|
Reference in New Issue
Block a user