Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 20:40:34 +00:00)
[hotfix]: modify create_ep_hierarchical_group and add test (#5032)
* feat: modify create_ep_hierarchical_group args
* test: add ep tests
* fix: remove get_process_group_ranks
* fix: fix src_rank
@@ -1,3 +1,5 @@
+from typing import List
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -124,7 +126,7 @@ def get_dp_rank(tensor: torch.Tensor) -> int:
     return dist.get_rank(get_dp_group(tensor))
 
 
-def get_ep_group_ranks(tensor: torch.Tensor) -> int:
+def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the expert parallel group ranks of the given tensor.
 
@@ -137,7 +139,7 @@ def get_ep_group_ranks(tensor: torch.Tensor) -> int:
     return tensor.moe_info.ep_group_ranks
 
 
-def get_dp_group_ranks(tensor: torch.Tensor) -> int:
+def get_dp_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the data parallel group ranks of the given tensor.
 
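The annotation fix matters to callers and type checkers: both helpers return the list of global ranks stored on tensor.moe_info, not a single int. A minimal usage sketch, assuming a tensor that ColossalAI's MoE setup has tagged with moe_info; the helper log_moe_groups is hypothetical, and the import of get_ep_group_ranks/get_dp_group_ranks is omitted because the diff excerpt does not show which module they live in:

from typing import List

import torch

def log_moe_groups(tensor: torch.Tensor) -> None:
    # With the corrected annotations, both helpers advertise List[int],
    # matching the rank lists they actually return from tensor.moe_info.
    ep_ranks: List[int] = get_ep_group_ranks(tensor)
    dp_ranks: List[int] = get_dp_group_ranks(tensor)
    print(f"EP group ranks: {ep_ranks}; DP group ranks: {dp_ranks}")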