mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-17 08:51:59 +00:00
Merge pull request #6114 from duanjunwen/dev/zero_bubble
[Zerobubble] Support LinearWithAsyncCommunication for sharderformer policy
This commit is contained in:
commit
810cafb2f9
@ -432,7 +432,6 @@ def _communicate(
|
|||||||
overlap_p2p=overlap_p2p,
|
overlap_p2p=overlap_p2p,
|
||||||
send_first=send_first if send_first != None else True,
|
send_first=send_first if send_first != None else True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if metadata_recv is not None:
|
if metadata_recv is not None:
|
||||||
assert isinstance(metadata_recv, P2PMetadata)
|
assert isinstance(metadata_recv, P2PMetadata)
|
||||||
tree_spec = metadata_recv.tree_spec
|
tree_spec = metadata_recv.tree_spec
|
||||||
|
@ -64,10 +64,28 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
|
|
||||||
# P2PMeta cache
|
# P2PMeta cache
|
||||||
self.enable_metadata_cache = enable_metadata_cache
|
self.enable_metadata_cache = enable_metadata_cache
|
||||||
self.send_tensor_metadata = True
|
|
||||||
self.send_grad_metadata = True
|
# check send_tensor_metadata, send_grad_metadata
|
||||||
self.tensor_metadata_recv = None
|
# pp4 as sample, we should follow this meta strategy
|
||||||
self.grad_metadata_recv = None
|
# send_tensor_meta(fwd) send_grad_meta(bwd)
|
||||||
|
# chunk0 | chunk1 chunk0 | chunk 1
|
||||||
|
# stage 0 T | F F | T
|
||||||
|
# stage 1 T | T T | T
|
||||||
|
# stage 2 T | T T | T
|
||||||
|
# stage 3 F | T F | T
|
||||||
|
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
self.send_tensor_metadata = [True, False]
|
||||||
|
self.send_grad_metadata = [False, True]
|
||||||
|
elif stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
self.send_tensor_metadata = [False, True]
|
||||||
|
self.send_grad_metadata = [True, False]
|
||||||
|
else:
|
||||||
|
self.send_tensor_metadata = [True, True]
|
||||||
|
self.send_grad_metadata = [True, True]
|
||||||
|
|
||||||
|
# meta cache buffer
|
||||||
|
self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta]
|
||||||
|
self.grad_metadata_recv = [None, None]
|
||||||
|
|
||||||
# P2P communication
|
# P2P communication
|
||||||
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
||||||
@ -96,10 +114,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
self.output_tensors_grad_dw = [[], []]
|
self.output_tensors_grad_dw = [[], []]
|
||||||
|
|
||||||
# buffer for communication
|
# buffer for communication
|
||||||
self.send_forward_buffer = [[], []]
|
self.send_forward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
|
||||||
self.recv_forward_buffer = [[], []]
|
self.recv_forward_buffer = [
|
||||||
self.send_backward_buffer = [[], []]
|
[],
|
||||||
self.recv_backward_buffer = [[], []]
|
[],
|
||||||
|
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
|
||||||
|
self.send_backward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
|
||||||
|
self.recv_backward_buffer = [
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
|
||||||
|
|
||||||
# y buffer for local send fwd
|
# y buffer for local send fwd
|
||||||
self.local_send_forward_buffer = []
|
self.local_send_forward_buffer = []
|
||||||
@ -225,7 +249,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
|
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
|
||||||
#################
|
#################
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
# return None, []
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
@ -235,12 +258,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
else:
|
else:
|
||||||
prev_rank = self.stage_manager.get_prev_rank()
|
prev_rank = self.stage_manager.get_prev_rank()
|
||||||
input_tensor, wait_handles = self.comm.recv_forward(
|
input_tensor, wait_handles = self.comm.recv_forward(
|
||||||
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv
|
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
|
||||||
)
|
)
|
||||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
||||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
||||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
|
||||||
# return input_tensor, wait_handles
|
|
||||||
return wait_handles
|
return wait_handles
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -259,12 +281,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
else:
|
else:
|
||||||
next_rank = self.stage_manager.get_next_rank()
|
next_rank = self.stage_manager.get_next_rank()
|
||||||
input_tensor, wait_handles = self.comm.recv_forward(
|
input_tensor, wait_handles = self.comm.recv_forward(
|
||||||
next_rank, metadata_recv=self.tensor_metadata_recv
|
next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
|
||||||
)
|
)
|
||||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
||||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
||||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
|
||||||
# return input_tensor, wait_handles
|
|
||||||
return wait_handles
|
return wait_handles
|
||||||
|
|
||||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
||||||
@ -287,7 +308,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# do nothing; Already get dy from local_send_backward_buffer in schedule b
|
# do nothing; Already get dy from local_send_backward_buffer in schedule b
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
# return None, []
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
@ -297,12 +317,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
else:
|
else:
|
||||||
next_rank = self.stage_manager.get_next_rank()
|
next_rank = self.stage_manager.get_next_rank()
|
||||||
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
||||||
next_rank, metadata_recv=self.grad_metadata_recv
|
next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
|
||||||
)
|
)
|
||||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
|
||||||
# return output_tensor_grad, wait_handles
|
|
||||||
return wait_handles
|
return wait_handles
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -312,7 +331,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# do nothing; get loss from local
|
# do nothing; get loss from local
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
# return None, []
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
@ -322,12 +340,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
else:
|
else:
|
||||||
prev_rank = self.stage_manager.get_prev_rank()
|
prev_rank = self.stage_manager.get_prev_rank()
|
||||||
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
||||||
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv
|
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
|
||||||
)
|
)
|
||||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
|
||||||
# return output_tensor_grad, wait_handles
|
|
||||||
return wait_handles
|
return wait_handles
|
||||||
|
|
||||||
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
||||||
@ -349,6 +366,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# do nothing; hold y on local_send_forward_buffer
|
# do nothing; hold y on local_send_forward_buffer
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
@ -359,9 +377,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
next_rank = self.stage_manager.get_next_rank()
|
next_rank = self.stage_manager.get_next_rank()
|
||||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||||
send_handles = self.comm.send_forward(
|
send_handles = self.comm.send_forward(
|
||||||
output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata
|
output_object=output_tensor,
|
||||||
|
next_rank=next_rank,
|
||||||
|
send_metadata=self.send_tensor_metadata[model_chunk_id],
|
||||||
)
|
)
|
||||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||||
return send_handles
|
return send_handles
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -370,6 +390,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part
|
# do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
@ -380,9 +401,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
prev_rank = self.stage_manager.get_prev_rank()
|
prev_rank = self.stage_manager.get_prev_rank()
|
||||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||||
send_handles = self.comm.send_forward(
|
send_handles = self.comm.send_forward(
|
||||||
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata
|
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id]
|
||||||
)
|
)
|
||||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||||
return send_handles
|
return send_handles
|
||||||
|
|
||||||
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
|
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
|
||||||
@ -405,6 +426,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# do nothing; cause u are the first chunk in first stage; bwd end
|
# do nothing; cause u are the first chunk in first stage; bwd end
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
@ -415,9 +437,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
prev_rank = self.stage_manager.get_prev_rank()
|
prev_rank = self.stage_manager.get_prev_rank()
|
||||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||||
send_handles = self.comm.send_backward(
|
send_handles = self.comm.send_backward(
|
||||||
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
|
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
|
||||||
)
|
)
|
||||||
self.send_grad_metadata = not self.enable_metadata_cache
|
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||||
return send_handles
|
return send_handles
|
||||||
|
|
||||||
# bwd chunk1 is left V;
|
# bwd chunk1 is left V;
|
||||||
@ -427,6 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
|
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
@ -437,9 +460,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
next_rank = self.stage_manager.get_next_rank()
|
next_rank = self.stage_manager.get_next_rank()
|
||||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||||
send_handles = self.comm.send_backward(
|
send_handles = self.comm.send_backward(
|
||||||
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata
|
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
|
||||||
)
|
)
|
||||||
self.send_grad_metadata = not self.enable_metadata_cache
|
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||||
return send_handles
|
return send_handles
|
||||||
|
|
||||||
def forward_step(
|
def forward_step(
|
||||||
@ -519,8 +542,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
output_obj_grad_ = []
|
output_obj_grad_ = []
|
||||||
|
|
||||||
# For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx.
|
# For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx.
|
||||||
# if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
||||||
# return None
|
|
||||||
|
|
||||||
# For loss backward; output_obj is loss; output_obj_grad should be None
|
# For loss backward; output_obj is loss; output_obj_grad should be None
|
||||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
@ -633,9 +654,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
if model_chunk_id == 0:
|
if model_chunk_id == 0:
|
||||||
# is first stage; get input from microbatch
|
# is first stage; get input from microbatch
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
input_obj = None
|
input_obj = None # (tensor, wait_handle)
|
||||||
else:
|
else:
|
||||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||||
|
for h in input_obj[1]:
|
||||||
|
h.wait()
|
||||||
|
input_obj = input_obj[0]
|
||||||
else:
|
else:
|
||||||
# is last stage; recv from local
|
# is last stage; recv from local
|
||||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
@ -643,7 +667,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# not last stage; recv from next
|
# not last stage; recv from next
|
||||||
else:
|
else:
|
||||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||||
|
for h in input_obj[1]:
|
||||||
|
h.wait()
|
||||||
|
input_obj = input_obj[0]
|
||||||
# Here, let input_obj.requires_grad_()
|
# Here, let input_obj.requires_grad_()
|
||||||
# if input_obj is not None:
|
# if input_obj is not None:
|
||||||
if not isinstance(input_obj, torch.Tensor):
|
if not isinstance(input_obj, torch.Tensor):
|
||||||
@ -689,10 +715,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# Do not release_tensor_data loss, release_tensor_data other output_obj;
|
# Do not release_tensor_data loss, release_tensor_data other output_obj;
|
||||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
self.output_tensors[model_chunk_id].append(output_obj)
|
self.output_tensors[model_chunk_id].append(output_obj)
|
||||||
# self.output_tensors_dw[model_chunk_id].append(output_obj)
|
|
||||||
else:
|
else:
|
||||||
self.output_tensors[model_chunk_id].append(output_obj)
|
self.output_tensors[model_chunk_id].append(output_obj)
|
||||||
# self.output_tensors_dw[model_chunk_id].append(output_obj)
|
|
||||||
|
|
||||||
# add output to send_fwd_buffer
|
# add output to send_fwd_buffer
|
||||||
if model_chunk_id == 0: # chunk 0
|
if model_chunk_id == 0: # chunk 0
|
||||||
@ -732,6 +756,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# chunk0 not last stage; recv output_grad from recv_backward_buffer
|
# chunk0 not last stage; recv output_grad from recv_backward_buffer
|
||||||
else:
|
else:
|
||||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||||
|
for h in output_tensor_grad[1]:
|
||||||
|
h.wait()
|
||||||
|
output_tensor_grad = output_tensor_grad[0]
|
||||||
else:
|
else:
|
||||||
# chunk1, is first stage; recv LOSS from local send bwd buffer
|
# chunk1, is first stage; recv LOSS from local send bwd buffer
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
@ -739,25 +766,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# chunk1, not first stage; recv output_grad from recv_backward_buffer
|
# chunk1, not first stage; recv output_grad from recv_backward_buffer
|
||||||
else:
|
else:
|
||||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||||
|
for h in output_tensor_grad[1]:
|
||||||
|
h.wait()
|
||||||
|
output_tensor_grad = output_tensor_grad[0]
|
||||||
|
|
||||||
# get input and output object from buffer;
|
# get input and output object from buffer;
|
||||||
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
||||||
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
||||||
|
|
||||||
# # save output_tensor_grad for dw
|
|
||||||
# if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
||||||
# # we save loss here
|
|
||||||
# self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
|
|
||||||
# else:
|
|
||||||
# # we save output_tensor_grad here
|
|
||||||
# self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
|
|
||||||
# the_output_obj_grad = []
|
|
||||||
# if isinstance(output_obj, dict):
|
|
||||||
# for (k, v) in output_obj.items():
|
|
||||||
# the_output_obj_grad.append(v.requires_grad)
|
|
||||||
# else:
|
|
||||||
# the_output_obj_grad.append(output_obj.requires_grad)
|
|
||||||
|
|
||||||
input_object_grad = self.backward_b_step(
|
input_object_grad = self.backward_b_step(
|
||||||
model_chunk=model_chunk,
|
model_chunk=model_chunk,
|
||||||
model_chunk_id=model_chunk_id,
|
model_chunk_id=model_chunk_id,
|
||||||
@ -800,20 +816,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
Returns:
|
Returns:
|
||||||
Nothing.
|
Nothing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# get y & dy from buffer
|
|
||||||
# output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
|
|
||||||
# output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
|
|
||||||
WeightGradStore.pop(chunk=model_chunk_id)
|
WeightGradStore.pop(chunk=model_chunk_id)
|
||||||
|
|
||||||
# self.backward_w_step(
|
|
||||||
# model_chunk=model_chunk,
|
|
||||||
# model_chunk_id=model_chunk_id,
|
|
||||||
# optimizer=optimizer,
|
|
||||||
# output_obj=output_obj,
|
|
||||||
# output_obj_grad=output_obj_grad,
|
|
||||||
# )
|
|
||||||
|
|
||||||
def run_forward_only(
|
def run_forward_only(
|
||||||
self,
|
self,
|
||||||
model_chunk: Union[ModuleList, Module],
|
model_chunk: Union[ModuleList, Module],
|
||||||
@ -890,7 +894,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# communication
|
# communication
|
||||||
communication_func = self.communication_map[scheduled_node.type]
|
communication_func = self.communication_map[scheduled_node.type]
|
||||||
wait_handle = communication_func(scheduled_node.chunk)
|
wait_handle = communication_func(scheduled_node.chunk)
|
||||||
self.wait_handles.append(wait_handle)
|
# We wait recv handle in fwd step and bwd step. Here only need to wait for send handle
|
||||||
|
if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}:
|
||||||
|
self.wait_handles.append(wait_handle)
|
||||||
elif scheduled_node.type == "F":
|
elif scheduled_node.type == "F":
|
||||||
self.schedule_f(
|
self.schedule_f(
|
||||||
scheduled_node=scheduled_node,
|
scheduled_node=scheduled_node,
|
||||||
@ -914,10 +920,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
model_chunk_id=scheduled_node.chunk,
|
model_chunk_id=scheduled_node.chunk,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
)
|
)
|
||||||
|
# wait here to ensure all communication is done
|
||||||
for h in self.wait_handles:
|
for h in self.wait_handles:
|
||||||
for hh in h:
|
for hh in h:
|
||||||
hh.wait()
|
hh.wait()
|
||||||
|
|
||||||
# return loss & output
|
# return loss & output
|
||||||
if outputs is not None:
|
if outputs is not None:
|
||||||
outputs = merge_batch(outputs)
|
outputs = merge_batch(outputs)
|
||||||
|
@ -6,6 +6,7 @@ import torch.distributed
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from colossalai.kernel.kernel_loader import (
|
from colossalai.kernel.kernel_loader import (
|
||||||
FlashAttentionDaoLoader,
|
FlashAttentionDaoLoader,
|
||||||
@ -642,9 +643,7 @@ class RingAttention(torch.autograd.Function):
|
|||||||
max_seqlen_q = max_seqlen_kv = max_seqlen
|
max_seqlen_q = max_seqlen_kv = max_seqlen
|
||||||
cu_seqlens_half = cu_seqlens // 2
|
cu_seqlens_half = cu_seqlens // 2
|
||||||
max_seqlen_half = max_seqlen // 2
|
max_seqlen_half = max_seqlen // 2
|
||||||
|
|
||||||
misc_kwargs = {
|
misc_kwargs = {
|
||||||
"window_size": (-1, -1),
|
|
||||||
"alibi_slopes": None,
|
"alibi_slopes": None,
|
||||||
"softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
|
"softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
|
||||||
"dropout_p": dropout_p,
|
"dropout_p": dropout_p,
|
||||||
@ -652,6 +651,13 @@ class RingAttention(torch.autograd.Function):
|
|||||||
"softcap": 0.0,
|
"softcap": 0.0,
|
||||||
"return_softmax": False,
|
"return_softmax": False,
|
||||||
}
|
}
|
||||||
|
import flash_attn
|
||||||
|
|
||||||
|
if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
|
||||||
|
misc_kwargs["window_size_left"] = -1
|
||||||
|
misc_kwargs["window_size_right"] = -1
|
||||||
|
else:
|
||||||
|
misc_kwargs["window_size"] = (-1, -1)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
RingAttention.HALF_INDICES is not None
|
RingAttention.HALF_INDICES is not None
|
||||||
@ -707,26 +713,39 @@ class RingAttention(torch.autograd.Function):
|
|||||||
|
|
||||||
# Helper to pass args to FA
|
# Helper to pass args to FA
|
||||||
def _forward(q, k, v, causal):
|
def _forward(q, k, v, causal):
|
||||||
(
|
if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
|
||||||
_,
|
(out, softmax_lse, S_dmask, rng_state) = _flash_attn_forward(
|
||||||
_,
|
q,
|
||||||
_,
|
k,
|
||||||
_,
|
v,
|
||||||
out,
|
cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
|
||||||
softmax_lse,
|
cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
|
||||||
_,
|
max_seqlen_q if q.shape[0] == t else max_seqlen_half,
|
||||||
rng_state,
|
max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
|
||||||
) = _flash_attn_forward(
|
causal=causal,
|
||||||
q,
|
**misc_kwargs,
|
||||||
k,
|
)
|
||||||
v,
|
else:
|
||||||
cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
|
(
|
||||||
cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
|
_,
|
||||||
max_seqlen_q if q.shape[0] == t else max_seqlen_half,
|
_,
|
||||||
max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
|
_,
|
||||||
causal=causal,
|
_,
|
||||||
**misc_kwargs,
|
out,
|
||||||
)
|
softmax_lse,
|
||||||
|
_,
|
||||||
|
rng_state,
|
||||||
|
) = _flash_attn_forward(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
|
||||||
|
cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
|
||||||
|
max_seqlen_q if q.shape[0] == t else max_seqlen_half,
|
||||||
|
max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
|
||||||
|
causal=causal,
|
||||||
|
**misc_kwargs,
|
||||||
|
)
|
||||||
return out, softmax_lse, rng_state
|
return out, softmax_lse, rng_state
|
||||||
|
|
||||||
def _kv_comm(i):
|
def _kv_comm(i):
|
||||||
|
@ -191,7 +191,6 @@ class LlamaPipelineForwards:
|
|||||||
num_model_chunks=stage_manager.num_model_chunks,
|
num_model_chunks=stage_manager.num_model_chunks,
|
||||||
)
|
)
|
||||||
assert num_ckpt_layers <= end_idx - start_idx
|
assert num_ckpt_layers <= end_idx - start_idx
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
@ -381,7 +381,6 @@ class MixtralPipelineForwards:
|
|||||||
output_router_logits,
|
output_router_logits,
|
||||||
use_cache,
|
use_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
|
@ -75,6 +75,8 @@ class BertPolicy(Policy):
|
|||||||
|
|
||||||
sp_partial_derived = sp_mode == "split_gather"
|
sp_partial_derived = sp_mode == "split_gather"
|
||||||
|
|
||||||
|
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
assert (
|
assert (
|
||||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||||
@ -97,6 +99,7 @@ class BertPolicy(Policy):
|
|||||||
kwargs={
|
kwargs={
|
||||||
"seq_parallel_mode": sp_mode,
|
"seq_parallel_mode": sp_mode,
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
@ -105,6 +108,7 @@ class BertPolicy(Policy):
|
|||||||
kwargs={
|
kwargs={
|
||||||
"seq_parallel_mode": sp_mode,
|
"seq_parallel_mode": sp_mode,
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
@ -113,6 +117,7 @@ class BertPolicy(Policy):
|
|||||||
kwargs={
|
kwargs={
|
||||||
"seq_parallel_mode": sp_mode,
|
"seq_parallel_mode": sp_mode,
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
@ -125,6 +130,7 @@ class BertPolicy(Policy):
|
|||||||
kwargs={
|
kwargs={
|
||||||
"seq_parallel_mode": sp_mode,
|
"seq_parallel_mode": sp_mode,
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
@ -138,6 +144,7 @@ class BertPolicy(Policy):
|
|||||||
"seq_parallel_mode": sp_mode,
|
"seq_parallel_mode": sp_mode,
|
||||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
@ -146,6 +153,97 @@ class BertPolicy(Policy):
|
|||||||
kwargs={
|
kwargs={
|
||||||
"seq_parallel_mode": sp_mode,
|
"seq_parallel_mode": sp_mode,
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="output.dropout",
|
||||||
|
target_module=col_nn.DropoutForParallelInput,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
policy[BertEmbeddings] = ModulePolicyDescription(
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="dropout",
|
||||||
|
target_module=col_nn.DropoutForReplicatedInput,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if self.enable_bias_gelu_fused:
|
||||||
|
self.append_or_create_method_replacement(
|
||||||
|
description={
|
||||||
|
"forward": get_jit_fused_bert_intermediate_forward(),
|
||||||
|
},
|
||||||
|
policy=policy,
|
||||||
|
target_key=BertIntermediate,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif use_zbv:
|
||||||
|
policy[BertLayer] = ModulePolicyDescription(
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="attention.self.query",
|
||||||
|
target_module=col_nn.LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"seq_parallel_mode": sp_mode,
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="attention.self.key",
|
||||||
|
target_module=col_nn.LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"seq_parallel_mode": sp_mode,
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="attention.self.value",
|
||||||
|
target_module=col_nn.LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"seq_parallel_mode": sp_mode,
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="attention.self.dropout",
|
||||||
|
target_module=col_nn.DropoutForParallelInput,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="attention.output.dense",
|
||||||
|
target_module=col_nn.LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"seq_parallel_mode": sp_mode,
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="attention.output.dropout",
|
||||||
|
target_module=col_nn.DropoutForParallelInput,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="intermediate.dense",
|
||||||
|
target_module=col_nn.LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"seq_parallel_mode": sp_mode,
|
||||||
|
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="output.dense",
|
||||||
|
target_module=col_nn.LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"seq_parallel_mode": sp_mode,
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
|
@ -9,6 +9,7 @@ from colossalai.shardformer.layer import (
|
|||||||
FusedRMSNorm,
|
FusedRMSNorm,
|
||||||
Linear1D_Col,
|
Linear1D_Col,
|
||||||
Linear1D_Row,
|
Linear1D_Row,
|
||||||
|
LinearWithGradAccum,
|
||||||
PaddingEmbedding,
|
PaddingEmbedding,
|
||||||
PaddingLMHead,
|
PaddingLMHead,
|
||||||
RMSNorm,
|
RMSNorm,
|
||||||
@ -104,7 +105,7 @@ class LlamaPolicy(Policy):
|
|||||||
policy=policy,
|
policy=policy,
|
||||||
target_key=LlamaModel,
|
target_key=LlamaModel,
|
||||||
)
|
)
|
||||||
|
# enable tp, replace layer to tp Linear1D_Col,Linear1D_Row,
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
assert (
|
assert (
|
||||||
num_q_heads % tp_size == 0
|
num_q_heads % tp_size == 0
|
||||||
@ -191,6 +192,76 @@ class LlamaPolicy(Policy):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# not enable tp, replace layer to LinearWithGradAccum
|
||||||
|
elif use_zbv:
|
||||||
|
policy[LlamaDecoderLayer] = ModulePolicyDescription(
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.q_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.k_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.v_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.o_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.gate_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.up_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.down_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
if embedding_cls is not None:
|
if embedding_cls is not None:
|
||||||
self.append_or_create_submodule_replacement(
|
self.append_or_create_submodule_replacement(
|
||||||
description=SubModuleReplacementDescription(
|
description=SubModuleReplacementDescription(
|
||||||
@ -416,6 +487,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
|||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||||
|
|
||||||
|
# enable tp, replace layer to tp Linear1D_Col,Linear1D_Row,
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
# add a new item for sequence classification
|
# add a new item for sequence classification
|
||||||
new_item = {
|
new_item = {
|
||||||
@ -434,6 +506,25 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
policy.update(new_item)
|
policy.update(new_item)
|
||||||
|
# enable tp, replace layer to LinearWithGradAccum
|
||||||
|
elif use_zbv:
|
||||||
|
# add a new item for sequence classification
|
||||||
|
new_item = {
|
||||||
|
LlamaForSequenceClassification: ModulePolicyDescription(
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="score",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs=dict(
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
policy.update(new_item)
|
||||||
|
|
||||||
# to be confirmed
|
# to be confirmed
|
||||||
if self.pipeline_stage_manager:
|
if self.pipeline_stage_manager:
|
||||||
# set None as default
|
# set None as default
|
||||||
|
@ -10,6 +10,7 @@ from colossalai.shardformer.layer import (
|
|||||||
FusedRMSNorm,
|
FusedRMSNorm,
|
||||||
Linear1D_Col,
|
Linear1D_Col,
|
||||||
Linear1D_Row,
|
Linear1D_Row,
|
||||||
|
LinearWithGradAccum,
|
||||||
PaddingEmbedding,
|
PaddingEmbedding,
|
||||||
PaddingLMHead,
|
PaddingLMHead,
|
||||||
VocabParallelEmbedding1D,
|
VocabParallelEmbedding1D,
|
||||||
@ -62,6 +63,8 @@ class MistralPolicy(Policy):
|
|||||||
if self.tie_weight:
|
if self.tie_weight:
|
||||||
embedding_cls = PaddingEmbedding
|
embedding_cls = PaddingEmbedding
|
||||||
|
|
||||||
|
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||||
|
|
||||||
if self.shard_config.enable_sequence_parallelism:
|
if self.shard_config.enable_sequence_parallelism:
|
||||||
self.shard_config.enable_sequence_parallelism = False
|
self.shard_config.enable_sequence_parallelism = False
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
@ -90,6 +93,7 @@ class MistralPolicy(Policy):
|
|||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
@ -97,6 +101,7 @@ class MistralPolicy(Policy):
|
|||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
@ -104,6 +109,7 @@ class MistralPolicy(Policy):
|
|||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
@ -111,6 +117,7 @@ class MistralPolicy(Policy):
|
|||||||
target_module=Linear1D_Row,
|
target_module=Linear1D_Row,
|
||||||
kwargs={
|
kwargs={
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
@ -118,6 +125,7 @@ class MistralPolicy(Policy):
|
|||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
@ -125,6 +133,7 @@ class MistralPolicy(Policy):
|
|||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
@ -132,6 +141,68 @@ class MistralPolicy(Policy):
|
|||||||
target_module=Linear1D_Row,
|
target_module=Linear1D_Row,
|
||||||
kwargs={
|
kwargs={
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
elif use_zbv:
|
||||||
|
policy[MistralDecoderLayer] = ModulePolicyDescription(
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.q_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.k_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.v_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.o_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.gate_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.up_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.down_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
@ -7,9 +7,18 @@ from torch import Tensor
|
|||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
|
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
|
||||||
|
|
||||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
from colossalai.shardformer.layer import (
|
||||||
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
|
FusedRMSNorm,
|
||||||
from colossalai.shardformer.layer.linear import Linear1D_Row
|
Linear1D_Col,
|
||||||
|
Linear1D_Row,
|
||||||
|
LinearWithGradAccum,
|
||||||
|
PaddingEmbedding,
|
||||||
|
VocabParallelEmbedding1D,
|
||||||
|
)
|
||||||
|
|
||||||
|
# from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||||
|
# from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
|
||||||
|
# from colossalai.shardformer.layer.linear import Linear1D_Row
|
||||||
from colossalai.shardformer.modeling.mixtral import (
|
from colossalai.shardformer.modeling.mixtral import (
|
||||||
EPMixtralSparseMoeBlock,
|
EPMixtralSparseMoeBlock,
|
||||||
MixtralPipelineForwards,
|
MixtralPipelineForwards,
|
||||||
@ -166,6 +175,51 @@ class MixtralPolicy(Policy):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif use_zbv:
|
||||||
|
policy[MixtralDecoderLayer] = ModulePolicyDescription(
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.q_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.k_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.v_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.o_proj",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="block_sparse_moe.gate",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
if embedding_cls is not None:
|
if embedding_cls is not None:
|
||||||
self.append_or_create_submodule_replacement(
|
self.append_or_create_submodule_replacement(
|
||||||
description=SubModuleReplacementDescription(
|
description=SubModuleReplacementDescription(
|
||||||
@ -351,6 +405,22 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
policy.update(new_item)
|
policy.update(new_item)
|
||||||
|
elif use_zbv:
|
||||||
|
new_item = {
|
||||||
|
MixtralForCausalLM: ModulePolicyDescription(
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="lm_head",
|
||||||
|
target_module=LinearWithGradAccum,
|
||||||
|
kwargs=dict(
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
policy.update(new_item)
|
||||||
|
|
||||||
if self.pipeline_stage_manager:
|
if self.pipeline_stage_manager:
|
||||||
# set None as default
|
# set None as default
|
||||||
|
@ -163,8 +163,6 @@ def main():
|
|||||||
enable_async_reduce=not args.disable_async_reduce,
|
enable_async_reduce=not args.disable_async_reduce,
|
||||||
use_fp8=args.use_fp8,
|
use_fp8=args.use_fp8,
|
||||||
fp8_communication=args.use_fp8_comm,
|
fp8_communication=args.use_fp8_comm,
|
||||||
use_fp8=args.use_fp8,
|
|
||||||
fp8_communication=args.use_fp8_comm,
|
|
||||||
)
|
)
|
||||||
elif args.plugin == "gemini_auto":
|
elif args.plugin == "gemini_auto":
|
||||||
plugin = GeminiPlugin(
|
plugin = GeminiPlugin(
|
||||||
@ -179,8 +177,6 @@ def main():
|
|||||||
enable_flash_attention=args.xformers,
|
enable_flash_attention=args.xformers,
|
||||||
use_fp8=args.use_fp8,
|
use_fp8=args.use_fp8,
|
||||||
fp8_communication=args.use_fp8_comm,
|
fp8_communication=args.use_fp8_comm,
|
||||||
use_fp8=args.use_fp8,
|
|
||||||
fp8_communication=args.use_fp8_comm,
|
|
||||||
)
|
)
|
||||||
elif args.plugin == "fsdp":
|
elif args.plugin == "fsdp":
|
||||||
if use_empty_init:
|
if use_empty_init:
|
||||||
@ -192,7 +188,6 @@ def main():
|
|||||||
),
|
),
|
||||||
param_init_fn=empty_init(),
|
param_init_fn=empty_init(),
|
||||||
fp8_communication=args.use_fp8_comm,
|
fp8_communication=args.use_fp8_comm,
|
||||||
fp8_communication=args.use_fp8_comm,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
plugin = TorchFSDPPlugin(
|
plugin = TorchFSDPPlugin(
|
||||||
@ -214,7 +209,6 @@ def main():
|
|||||||
cpu_offload=CPUOffload(offload_params=True),
|
cpu_offload=CPUOffload(offload_params=True),
|
||||||
param_init_fn=empty_init(),
|
param_init_fn=empty_init(),
|
||||||
fp8_communication=args.use_fp8_comm,
|
fp8_communication=args.use_fp8_comm,
|
||||||
fp8_communication=args.use_fp8_comm,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
plugin = TorchFSDPPlugin(
|
plugin = TorchFSDPPlugin(
|
||||||
@ -225,7 +219,6 @@ def main():
|
|||||||
),
|
),
|
||||||
cpu_offload=CPUOffload(offload_params=True),
|
cpu_offload=CPUOffload(offload_params=True),
|
||||||
fp8_communication=args.use_fp8_comm,
|
fp8_communication=args.use_fp8_comm,
|
||||||
fp8_communication=args.use_fp8_comm,
|
|
||||||
)
|
)
|
||||||
elif args.plugin == "3d":
|
elif args.plugin == "3d":
|
||||||
if args.pp_style == "zbv":
|
if args.pp_style == "zbv":
|
||||||
|
@ -122,7 +122,7 @@ def main():
|
|||||||
num_ckpt_layers_per_stage=[19, 19, 19, 13],
|
num_ckpt_layers_per_stage=[19, 19, 19, 13],
|
||||||
),
|
),
|
||||||
"num_layers_per_stage": [19, 20, 20, 21],
|
"num_layers_per_stage": [19, 20, 20, 21],
|
||||||
# "pp_style": "interleaved",
|
"pp_style": "interleaved",
|
||||||
}
|
}
|
||||||
if args.custom_ckpt
|
if args.custom_ckpt
|
||||||
else {}
|
else {}
|
||||||
|
@ -749,24 +749,17 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|||||||
assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)
|
assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)
|
||||||
|
|
||||||
|
|
||||||
# TODO:3) support booster & Hybrid base 2)
|
|
||||||
def run_with_hybridplugin(test_config):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# TODO:4) support booster & MoEHybrid base 2)
|
|
||||||
@parameterize(
|
@parameterize(
|
||||||
"config",
|
"config",
|
||||||
[
|
[
|
||||||
# (0, 1, 4, 1, 1),
|
(1, 2, 1, 1, 2),
|
||||||
# (1, 2, 2, 1, 1),
|
|
||||||
(1, 1, 2, 2, 1),
|
(1, 1, 2, 2, 1),
|
||||||
# (1, 2, 1, 2, 1),
|
(1, 2, 1, 2, 1),
|
||||||
# (1, 2, 1, 1, 2),
|
(1, 2, 2, 1, 1),
|
||||||
|
(1, 1, 4, 1, 1),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||||
test_config = config
|
|
||||||
stage, ep_size, pp_size, tp_size, sp_size = config
|
stage, ep_size, pp_size, tp_size, sp_size = config
|
||||||
num_microbatches = pp_size
|
num_microbatches = pp_size
|
||||||
dist.get_world_size()
|
dist.get_world_size()
|
||||||
@ -876,7 +869,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
|||||||
return_outputs=True,
|
return_outputs=True,
|
||||||
)
|
)
|
||||||
# stage 0 chunk 0
|
# stage 0 chunk 0
|
||||||
parallel_output = None
|
|
||||||
if (
|
if (
|
||||||
booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
|
booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
|
||||||
and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
|
and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
|
||||||
@ -910,9 +902,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
|||||||
p.grad /= dp_size
|
p.grad /= dp_size
|
||||||
torch_optimizer.step()
|
torch_optimizer.step()
|
||||||
torch_optimizer.zero_grad()
|
torch_optimizer.zero_grad()
|
||||||
|
|
||||||
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||||
print(f"rank {dist.get_rank()} config {test_config} test passed")
|
|
||||||
clear_layout_converter()
|
clear_layout_converter()
|
||||||
Randomizer.reset_index()
|
Randomizer.reset_index()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -921,11 +911,11 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
|||||||
@parameterize(
|
@parameterize(
|
||||||
"config",
|
"config",
|
||||||
[
|
[
|
||||||
(1, 2, 2, 1), # Pass
|
# Pass
|
||||||
# TODO: only support pp + tp accleration; Will support fully pp and None tp Hybrid in furture;
|
(1, 2, 2, 1),
|
||||||
# (0, 4, 1, 1),
|
(1, 2, 1, 2),
|
||||||
# (1, 2, 1, 2),
|
(1, 1, 2, 2),
|
||||||
# (1, 1, 2, 2),
|
(1, 4, 1, 1),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
||||||
@ -1034,7 +1024,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
|||||||
return_outputs=True,
|
return_outputs=True,
|
||||||
)
|
)
|
||||||
# stage 0 chunk 0
|
# stage 0 chunk 0
|
||||||
parallel_output = None
|
|
||||||
if (
|
if (
|
||||||
booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
|
booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
|
||||||
and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
|
and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
|
||||||
@ -1068,9 +1057,8 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
|||||||
p.grad /= dp_size
|
p.grad /= dp_size
|
||||||
torch_optimizer.step()
|
torch_optimizer.step()
|
||||||
torch_optimizer.zero_grad()
|
torch_optimizer.zero_grad()
|
||||||
|
|
||||||
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||||
print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
|
|
||||||
clear_layout_converter()
|
clear_layout_converter()
|
||||||
Randomizer.reset_index()
|
Randomizer.reset_index()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
Loading…
Reference in New Issue
Block a user