[moe] support optimizer checkpoint (#5015)

* Refactor MoE Manager setup method

* unshard optim ckpt

* optim io

* update transformer version

* update requirements

* update ckpt

* update ckpt

* update ckpt

* fix engine

* fix engine
This commit is contained in:
Xuanlei Zhao
2023-11-08 23:07:03 +08:00
committed by GitHub
parent 67f5331754
commit f71e63b0f3
20 changed files with 738 additions and 150 deletions

View File

@@ -72,6 +72,19 @@ def get_ep_size(tensor: torch.Tensor) -> int:
return tensor.moe_info.ep_size
def get_dp_size(tensor: torch.Tensor) -> int:
    """Return the size of the data parallel group attached to the tensor.

    Args:
        tensor (torch.Tensor): Tensor carrying MoE parallel metadata on
            its ``moe_info`` attribute.

    Returns:
        int: Number of ranks in the tensor's data parallel group.
    """
    # The MoE setup stores parallel-layout metadata directly on the tensor.
    moe_info = tensor.moe_info
    return moe_info.dp_size
def get_dp_group(tensor: torch.Tensor) -> ProcessGroup:
"""
Get the data parallel group of the given tensor.