mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[devops] remove post commit ci (#5566)
* [devops] remove post commit ci * [misc] run pre-commit on all files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -174,16 +174,20 @@ class ProcessGroupMesh:
|
||||
List[Tuple[int, ...]]: Coordinates along the axis.
|
||||
"""
|
||||
if isinstance(axis, int):
|
||||
axis = [axis,]
|
||||
axis = [
|
||||
axis,
|
||||
]
|
||||
assert isinstance(indices_at_axis[0], int)
|
||||
indices_at_axis = [indices_at_axis,]
|
||||
indices_at_axis = [
|
||||
indices_at_axis,
|
||||
]
|
||||
|
||||
def add_index(base_coord, axis, indices_at_axis):
|
||||
coords_in_group = []
|
||||
for idx in indices_at_axis:
|
||||
coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
|
||||
return coords_in_group
|
||||
|
||||
|
||||
coords_in_group = [base_coord]
|
||||
for ax, indices_at_ax in zip(axis, indices_at_axis):
|
||||
new_coords_in_group = []
|
||||
@@ -194,7 +198,10 @@ class ProcessGroupMesh:
|
||||
return coords_in_group
|
||||
|
||||
def create_group_along_axis(
|
||||
self, axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None
|
||||
self,
|
||||
axis: Union[int, List[int]],
|
||||
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
backend: Optional[str] = None,
|
||||
) -> ProcessGroup:
|
||||
"""Create all process groups along the given axis, and return the one which the current process belongs to.
|
||||
|
||||
@@ -207,11 +214,15 @@ class ProcessGroupMesh:
|
||||
ProcessGroup: The process group along the given axis which the current process belongs to.
|
||||
"""
|
||||
if isinstance(axis, int):
|
||||
axis = [axis,]
|
||||
axis = [
|
||||
axis,
|
||||
]
|
||||
if indices_at_axis is not None:
|
||||
assert isinstance(indices_at_axis[0], int)
|
||||
indices_at_axis = [indices_at_axis,]
|
||||
|
||||
indices_at_axis = [
|
||||
indices_at_axis,
|
||||
]
|
||||
|
||||
indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis]
|
||||
reduced_shape = list(self._shape)
|
||||
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
|
||||
|
Reference in New Issue
Block a user