mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * fix bugs when useing sp and flashattention together * fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
This commit is contained in:
@@ -161,7 +161,7 @@ class ProcessGroupMesh:
|
||||
|
||||
@staticmethod
|
||||
def get_coords_along_axis(
|
||||
base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
|
||||
base_coord: Tuple[int, ...], axis: Union[int, List[int]], indices_at_axis: Union[List[int], List[List[int]]]
|
||||
) -> List[Tuple[int, ...]]:
|
||||
"""Get coordinates along the given axis.
|
||||
|
||||
@@ -173,13 +173,28 @@ class ProcessGroupMesh:
|
||||
Returns:
|
||||
List[Tuple[int, ...]]: Coordinates along the axis.
|
||||
"""
|
||||
coords_in_group = []
|
||||
for idx in indices_at_axis:
|
||||
coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
|
||||
if isinstance(axis, int):
|
||||
axis = [axis,]
|
||||
assert isinstance(indices_at_axis[0], int)
|
||||
indices_at_axis = [indices_at_axis,]
|
||||
|
||||
def add_index(base_coord, axis, indices_at_axis):
|
||||
coords_in_group = []
|
||||
for idx in indices_at_axis:
|
||||
coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
|
||||
return coords_in_group
|
||||
|
||||
coords_in_group = [base_coord]
|
||||
for ax, indices_at_ax in zip(axis, indices_at_axis):
|
||||
new_coords_in_group = []
|
||||
for coords in coords_in_group:
|
||||
new_coords_in_group += add_index(coords, ax, indices_at_ax)
|
||||
coords_in_group = new_coords_in_group
|
||||
|
||||
return coords_in_group
|
||||
|
||||
def create_group_along_axis(
|
||||
self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
|
||||
self, axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None
|
||||
) -> ProcessGroup:
|
||||
"""Create all process groups along the given axis, and return the one which the current process belongs to.
|
||||
|
||||
@@ -191,10 +206,17 @@ class ProcessGroupMesh:
|
||||
Returns:
|
||||
ProcessGroup: The process group along the given axis which the current process belongs to.
|
||||
"""
|
||||
indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
|
||||
if isinstance(axis, int):
|
||||
axis = [axis,]
|
||||
if indices_at_axis is not None:
|
||||
assert isinstance(indices_at_axis[0], int)
|
||||
indices_at_axis = [indices_at_axis,]
|
||||
|
||||
indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis]
|
||||
reduced_shape = list(self._shape)
|
||||
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
|
||||
reduced_shape[axis] = 1
|
||||
for ax in axis:
|
||||
reduced_shape[ax] = 1
|
||||
target_group = None
|
||||
# use Cartesian product to generate all combinations of coordinates
|
||||
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
|
||||
@@ -225,4 +247,3 @@ class ProcessGroupMesh:
|
||||
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
|
||||
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
|
||||
return self._ranks_to_group[ranks_in_group]
|
||||
|
Reference in New Issue
Block a user