ColossalAI/colossalai/cluster/process_group_mesh.py
Haze188 416580b314
[MoE/ZeRO] Moe refactor with zero refactor (#5821)
2024-06-28 14:00:08 +08:00


import gc
import itertools
from functools import reduce
from operator import mul
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup


def prod(nums: List[int]) -> int:
    """Product of a list of numbers.

    Args:
        nums (List[int]): A list of numbers.

    Returns:
        int: The product of the numbers.
    """
    return reduce(mul, nums)


class ProcessGroupMesh:
    """A helper class to manage the process group mesh. It only describes how process groups are organized and is
    decoupled from the parallelism method: it just initializes process groups and caches them, while the parallelism
    method is responsible for managing them and using them for parallel computation.

    We use an N-D tuple to represent the shape of the process group mesh, and an N-D coordinate to represent each
    process. For example, ``(0, 1, 0)`` is the process whose rank is 2 in a 3-D process group mesh of shape ``(2, 2, 2)``.

    Args:
        *size (int): The size of each dimension of the process group mesh. The product of the sizes must equal the world size.

    Attributes:
        shape (Tuple[int, ...]): The shape of the process group mesh.
        rank (int): The rank of the current process.
    """

    def __init__(self, *size: int) -> None:
        assert dist.is_initialized(), "Please initialize torch.distributed first."
        world_size = dist.get_world_size()
        prod_size = prod(size)
        assert (
            prod_size == world_size
        ), f"The product of the sizes ({prod_size}) must equal the world size ({world_size})."

        self._shape = size
        self._rank = dist.get_rank()
        self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
        self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
        self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
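
    # Illustrative sketch (added for clarity; not part of the original file). Constructing a mesh requires an already
    # initialized torch.distributed process group, e.g. a script launched with ``torchrun --nproc_per_node=8``.
    # Assuming a world size of 8, one possible layout is:
    #
    #   mesh = ProcessGroupMesh(2, 2, 2)    # e.g. pipeline x data x tensor parallel
    #   mesh.shape                          # -> (2, 2, 2)
    #   mesh.rank                           # global rank of this process, e.g. 5
    #   mesh.coordinate()                   # -> (1, 0, 1) for rank 5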

    def destroy_mesh_process_groups(self):
        r"""
        Destroy all process groups created by this ProcessGroupMesh.

        This method is not called automatically; invoke it explicitly once the mesh is no longer needed. It is
        responsible for cleaning up any process groups that were created during the lifetime of the object.

        Note:
            Process groups in PyTorch are held in global state, so they are not destroyed automatically when the
            ProcessGroupMesh object goes out of scope. This method destroys them manually to release system resources.
        """
        for group in self._ranks_to_group.values():
            dist.destroy_process_group(group)

        # Manually clear all process groups to save memory
        gc.collect()
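
    # Usage note (added for clarity; not part of the original file). A typical teardown at the end of a distributed
    # script, assuming ``mesh`` is a ProcessGroupMesh instance, would be:
    #
    #   mesh.destroy_mesh_process_groups()   # destroy the sub-groups created by this mesh
    #   dist.destroy_process_group()         # then tear down the remaining (default) group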

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._shape

    @property
    def rank(self) -> int:
        return self._rank

    def size(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
        """Get the size of the process group mesh.

        Args:
            dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None.

        Returns:
            Union[int, Tuple[int, ...]]: Size of the target dimension or the whole process group mesh.
        """
        if dim is None:
            return self._shape
        else:
            return self._shape[dim]

    def coordinate(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
        """Get the coordinate of the process group mesh.

        Args:
            dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None.

        Returns:
            Union[int, Tuple[int, ...]]: Coordinate of the target dimension or the whole process group mesh.
        """
        if dim is None:
            return self._coord
        else:
            return self._coord[dim]

    @staticmethod
    def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """Convert a rank to a coordinate.

        Args:
            rank (int): Rank to be converted.
            shape (Tuple[int, ...]): Shape of the process group mesh.

        Returns:
            Tuple[int, ...]: Coordinate of the rank.
        """
        return np.unravel_index(rank, shape)

    @staticmethod
    def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int:
        """Convert a coordinate to a rank.

        The ``mode`` argument is forwarded to ``numpy.ravel_multi_index`` (see
        https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html). With ``"wrap"``, indices that
        are out of range are wrapped around; for instance, ``ravel((0, i, 0), (1, 2, 1), "wrap")`` returns ``i % 2``.

        Args:
            coord (Tuple[int, ...]): Coordinate to be converted.
            shape (Tuple[int, ...]): Shape of the process group mesh.
            mode (Optional[str]): The mode for numpy.ravel_multi_index.

        Returns:
            int: Rank of the coordinate.
        """
        assert mode in ["raise", "wrap", "clip"]
        return int(np.ravel_multi_index(coord, shape, mode))
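
    # Illustrative sketch (added for clarity; not part of the original file). ``ravel`` and ``unravel`` are inverse
    # mappings between ranks and coordinates; for a 2 x 2 x 2 mesh:
    #
    #   ProcessGroupMesh.unravel(5, (2, 2, 2))                      # -> (1, 0, 1)
    #   ProcessGroupMesh.ravel((1, 0, 1), (2, 2, 2))                # -> 5
    #   ProcessGroupMesh.ravel((1, 2, 1), (2, 2, 2), mode="wrap")   # -> 5, since index 2 wraps to 0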

    def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
        """Get the process group with the given ranks. If the process group doesn't exist, it will be created.

        Args:
            ranks_in_group (List[int]): Ranks in the process group.
            backend (Optional[str], optional): Backend of the process group. Defaults to None.

        Returns:
            ProcessGroup: The process group with the given ranks.
        """
        ranks_in_group = sorted(ranks_in_group)
        if tuple(ranks_in_group) not in self._ranks_to_group:
            group = dist.new_group(ranks_in_group, backend=backend)
            self._ranks_to_group[tuple(ranks_in_group)] = group
            self._group_to_ranks[group] = tuple(ranks_in_group)
        return self._ranks_to_group[tuple(ranks_in_group)]
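
    # Usage note (added for clarity; not part of the original file). Groups are cached by their sorted rank tuple, so
    # repeated requests for the same set of ranks return the same ``ProcessGroup`` object. Every rank must issue the
    # same calls, since ``dist.new_group`` is a collective operation. For example, on a 4-process job:
    #
    #   g1 = mesh.get_group([0, 2])
    #   g2 = mesh.get_group([2, 0])   # same ranks after sorting, so ``g1 is g2``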

    def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
        """Get the ranks in the given process group. The process group must be created by this class.

        Args:
            group (ProcessGroup): The process group.

        Returns:
            List[int]: Ranks in the process group.
        """
        return list(self._group_to_ranks[group])

    @staticmethod
    def get_coords_along_axis(
        base_coord: Tuple[int, ...], axis: Union[int, List[int]], indices_at_axis: Union[List[int], List[List[int]]]
    ) -> List[Tuple[int, ...]]:
        """Get coordinates along the given axis.

        Args:
            base_coord (Tuple[int, ...]): Base coordinate which the coordinates along the axis are based on.
            axis (Union[int, List[int]]): Axis (or axes) along which the coordinates are generated.
            indices_at_axis (Union[List[int], List[List[int]]]): Indices at the axis (or axes).

        Returns:
            List[Tuple[int, ...]]: Coordinates along the axis.
        """
        if isinstance(axis, int):
            axis = [
                axis,
            ]
            assert isinstance(indices_at_axis[0], int), f"Expected int, but got {type(indices_at_axis[0])}."
            indices_at_axis = [
                indices_at_axis,
            ]

        def add_index(base_coord, axis, indices_at_axis):
            coords_in_group = []
            for idx in indices_at_axis:
                coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
            return coords_in_group

        coords_in_group = [base_coord]
        for ax, indices_at_ax in zip(axis, indices_at_axis):
            new_coords_in_group = []
            for coords in coords_in_group:
                new_coords_in_group += add_index(coords, ax, indices_at_ax)
            coords_in_group = new_coords_in_group

        return coords_in_group
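
    # Illustrative sketch (added for clarity; not part of the original file). ``get_coords_along_axis`` is a pure
    # helper, so its behavior can be checked without torch.distributed; for a 3-D mesh:
    #
    #   ProcessGroupMesh.get_coords_along_axis((0, 0, 0), 1, [0, 1])
    #   # -> [(0, 0, 0), (0, 1, 0)]
    #   ProcessGroupMesh.get_coords_along_axis((0, 0, 0), [0, 2], [[0, 1], [0, 1]])
    #   # -> [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 0, 1)]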

    def create_group_along_axis(
        self,
        axis: Union[int, List[int]],
        indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
        backend: Optional[str] = None,
    ) -> ProcessGroup:
        """Create all process groups along the given axis, and return the one which the current process belongs to.

        Args:
            axis (Union[int, List[int]]): Axis (or axes) along which the process groups are created.
            indices_at_axis (Optional[Union[List[int], List[List[int]]]], optional): Indices at the axis. Defaults to None.
            backend (Optional[str], optional): Backend of the process group. Defaults to None.

        Returns:
            ProcessGroup: The process group along the given axis which the current process belongs to.
        """
        if isinstance(axis, int):
            axis = [
                axis,
            ]
            if indices_at_axis is not None:
                assert isinstance(indices_at_axis[0], int)
                indices_at_axis = [
                    indices_at_axis,
                ]

        indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis]
        reduced_shape = list(self._shape)
        # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
        for ax in axis:
            reduced_shape[ax] = 1
        target_group = None
        # use Cartesian product to generate all combinations of coordinates; every rank enumerates the same groups
        # in the same order, since `dist.new_group` is a collective call over the whole world
        for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
            coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
            ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
            group = self.get_group(ranks_in_group, backend=backend)
            if self._rank in ranks_in_group:
                target_group = group
        return target_group
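
    # Illustrative sketch (added for clarity; not part of the original file). For a 2 x 2 mesh whose ranks are laid
    # out as coordinates (0, 0)=0, (0, 1)=1, (1, 0)=2, (1, 1)=3, calling the method on every rank creates the groups
    # {0, 1} and {2, 3}; each process receives the group that contains its own rank:
    #
    #   group = mesh.create_group_along_axis(1)   # ranks 0/1 -> group over {0, 1}; ranks 2/3 -> group over {2, 3}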

    def get_group_along_axis(
        self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
    ) -> ProcessGroup:
        """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.

        Args:
            axis (Union[int, List[int]]): Axis (or axes) along which the process groups are created.
            indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
            backend (Optional[str], optional): Backend of the process group. Defaults to None.

        Returns:
            ProcessGroup: The process group along the given axis which the current process belongs to.
        """
        if indices_at_axis is None:
            if isinstance(axis, (list, tuple)):
                indices_at_axis = [list(range(self._shape[ax])) for ax in axis]
            else:
                indices_at_axis = list(range(self._shape[axis]))

        coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
        ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
        if ranks_in_group not in self._ranks_to_group:
            # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
            return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
        return self._ranks_to_group[ranks_in_group]
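

if __name__ == "__main__":
    # Minimal runnable sketch (added for illustration; not part of the original module). It assumes the file is
    # launched with 4 processes, e.g. ``torchrun --nproc_per_node=4 process_group_mesh.py``, and uses the gloo
    # backend so it works without GPUs. The axis labels DP_AXIS/TP_AXIS are arbitrary names chosen for this example.
    import torch

    dist.init_process_group(backend="gloo")

    # The pure helpers can be checked directly.
    assert prod([2, 2]) == 4
    assert ProcessGroupMesh.ravel((1, 1), (2, 2)) == 3
    assert tuple(int(c) for c in ProcessGroupMesh.unravel(3, (2, 2))) == (1, 1)

    # Build a 2 x 2 mesh: axis 0 treated as "data parallel", axis 1 as "tensor parallel" (labels are illustrative).
    DP_AXIS, TP_AXIS = 0, 1
    mesh = ProcessGroupMesh(2, 2)
    dp_group = mesh.get_group_along_axis(DP_AXIS)
    tp_group = mesh.get_group_along_axis(TP_AXIS)

    # Each process reports its coordinate and the ranks of the groups it belongs to.
    print(
        f"rank={mesh.rank} coord={mesh.coordinate()} "
        f"dp_ranks={mesh.get_ranks_in_group(dp_group)} tp_ranks={mesh.get_ranks_in_group(tp_group)}"
    )

    # Quick functional check: all-reducing ones over the tensor-parallel group yields the group size.
    t = torch.ones(1)
    dist.all_reduce(t, group=tp_group)
    assert int(t.item()) == dist.get_world_size(tp_group)

    mesh.destroy_mesh_process_groups()
    dist.destroy_process_group()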