Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-16 14:41:53 +00:00
[feat] Support prompt level dynamic (#6300)
* adjust to dynamic prompt bs
* remove debug
* update pad seq (#6303)
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
* adjust to dynamic prompt bs
* remove debug
* fix dp issue
* fix
* fix default settings
---------
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
@@ -144,3 +144,29 @@ def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
     else:
         dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
     return tensor
+
+
+def all_gather_tensors(local_tensor_list: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
+    """
+    Gathers the lists of tensors held by all processes and flattens them into a single list.
+
+    Args:
+        local_tensor_list (list[torch.Tensor]): The list of tensors local to this process.
+
+    Returns:
+        list[torch.Tensor]: The flattened list of tensors gathered from all processes.
+    """
+    # Gather tensor lists across the DP group when a plugin is provided
+    if plugin is not None:
+        all_tensor_lists = [None] * plugin.dp_size
+        dist.all_gather_object(all_tensor_lists, local_tensor_list, group=plugin.dp_group)
+        gathered_tensor_list = []
+        for tensors in all_tensor_lists:
+            gathered_tensor_list.extend(tensors)
+    else:
+        all_tensor_lists = [None] * dist.get_world_size()
+        dist.all_gather_object(all_tensor_lists, local_tensor_list)
+        gathered_tensor_list = []
+        for tensors in all_tensor_lists:
+            gathered_tensor_list.extend(tensors)
+    return gathered_tensor_list
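For context, a minimal usage sketch of the new helper (not part of the commit): it assumes a torch.distributed process group is already initialized (e.g. launched via torchrun), that all_gather_tensors as defined above is in scope, and that no plugin is passed, so the fallback world-group branch is exercised. The uneven per-rank list lengths mimic the prompt-level dynamic batch sizes this PR targets.

import torch
import torch.distributed as dist

def demo():
    # Assumes launch with torchrun so RANK/WORLD_SIZE env vars are set.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    # Each rank holds a different number of tensors (prompt-level dynamic batch).
    local_tensor_list = [torch.full((2,), float(rank)) for _ in range(rank + 1)]

    # plugin=None exercises the default world-group branch of all_gather_tensors.
    gathered = all_gather_tensors(local_tensor_list, plugin=None)
    print(f"rank {rank}: gathered {len(gathered)} tensors in total")

    dist.destroy_process_group()

if __name__ == "__main__":
    demo()

Because the helper relies on dist.all_gather_object rather than tensor collectives, the per-rank lists may differ in length; the trade-off is the pickling overhead of object collectives.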