[shardformer] support interleaved pipeline (#4448)

* support interleaved pipeline

* fix unit test

* remove virtual stage test in stage mgr

* add droped type hint and updated bwd
This commit is contained in:
LuGY
2023-08-16 19:29:03 +08:00
committed by GitHub
parent 26e29d58f0
commit a78daf6180
7 changed files with 642 additions and 109 deletions

View File

@@ -94,17 +94,23 @@ class ProcessGroupMesh:
return np.unravel_index(rank, shape)
@staticmethod
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int:
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int:
"""Convert a coordinate to a rank.
mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
with wrap, index out of range would be wrapped around.
For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns (i % 2)
Args:
coords (Tuple[int, ...]): Coordinate to be converted.
shape (Tuple[int, ...]): Shape of the process group mesh.
mode (Optional[str]): The mode for numpy.ravel_multi_index.
Returns:
int: Rank of the coordinate.
"""
return np.ravel_multi_index(coord, shape)
assert mode in ["raise", "wrap", "clip"]
return np.ravel_multi_index(coord, shape, mode)
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.