[feat] Support prompt-level dynamic batch size (#6300)

* Adjust to dynamic prompt batch size

* Remove debug code

* Update sequence padding (#6303)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>


* Fix DP (data-parallel) issue

* Fix

* Fix default settings

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
commit aca547623f (parent b920af427b)
Author: Tong Li
Date: 2025-05-14 16:40:35 +08:00
Committed by: GitHub

4 changed files with 123 additions and 93 deletions


@@ -144,3 +144,29 @@ def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
    else:
        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    return tensor


def all_gather_tensors(local_tensor_list: List[torch.Tensor], plugin: Plugin = None) -> List[torch.Tensor]:
    """
    Gathers lists of tensors from all processes and flattens them into a single list.

    Args:
        local_tensor_list (List[torch.Tensor]): The tensors held by the local process.
        plugin (Plugin, optional): If provided, gather across the plugin's data-parallel
            group; otherwise gather across the default (world) process group.

    Returns:
        List[torch.Tensor]: The tensors from all processes, in rank order.
    """
    # all_gather_object is used instead of all_gather because each rank may hold
    # a list of a different length when the prompt batch size is dynamic.
    if plugin is not None:
        all_tensor_lists = [None] * plugin.dp_size
        dist.all_gather_object(all_tensor_lists, local_tensor_list, group=plugin.dp_group)
    else:
        all_tensor_lists = [None] * dist.get_world_size()
        dist.all_gather_object(all_tensor_lists, local_tensor_list)
    # Flatten the per-rank lists into a single list.
    gathered_tensor_list = []
    for tensors in all_tensor_lists:
        gathered_tensor_list.extend(tensors)
    return gathered_tensor_list
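
For context, a minimal usage sketch that is not part of the commit: it assumes all_gather_tensors is importable from the module above, that the script is launched with torchrun, and it demonstrates the case the helper exists for, where each rank contributes a list of a different length.

# Usage sketch (hypothetical, not part of this commit). Launch with e.g.:
#   torchrun --nproc_per_node=2 demo.py
# Assumes all_gather_tensors (defined above) is importable from its module.
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = dist.get_rank()

# Rank r holds r + 1 tensors, so per-rank list lengths differ; a plain
# dist.all_gather would require equal shapes, but all_gather_object does not.
local_tensors = [torch.full((4,), float(rank)) for _ in range(rank + 1)]

# plugin=None: gather over the default (world) process group.
gathered = all_gather_tensors(local_tensors)

if rank == 0:
    # With 2 ranks: 1 + 2 = 3 tensors, ordered by rank.
    print(len(gathered), [t[0].item() for t in gathered])

dist.destroy_process_group()

Using all_gather_object trades the speed of tensor collectives for the flexibility of pickling arbitrary Python objects, which is what lets each DP rank contribute a variable number of prompts.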