[hotfix]: modify create_ep_hierarchical_group and add test (#5032)

* feat: modify create_ep_hierarchical_group args

* test: add ep tests

* fix: remove get_process_group_ranks

* fix: fix src_rank
This commit is contained in:
Wenhao Chen
2023-11-17 10:53:00 +08:00
committed by GitHub
parent 97cd0cd559
commit 3c08f17348
5 changed files with 38 additions and 28 deletions

View File

@@ -1,3 +1,5 @@
from typing import List
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -124,7 +126,7 @@ def get_dp_rank(tensor: torch.Tensor) -> int:
return dist.get_rank(get_dp_group(tensor))
def get_ep_group_ranks(tensor: torch.Tensor) -> int:
def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
"""
Get the expert parallel group ranks of the given tensor.
@@ -137,7 +139,7 @@ def get_ep_group_ranks(tensor: torch.Tensor) -> int:
return tensor.moe_info.ep_group_ranks
def get_dp_group_ranks(tensor: torch.Tensor) -> int:
def get_dp_group_ranks(tensor: torch.Tensor) -> List[int]:
"""
Get the data parallel group ranks of the given tensor.