ColossalAI/colossalai/tensor/sharding_spec.py
flybird11111 0c10afd372
[FP8] rebase main (#5963)
* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtra pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redunant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (#5932)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update low_level_optim.py

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
2024-08-06 16:29:37 +08:00

319 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import operator
from functools import reduce
import torch
from colossalai.device.device_mesh import DeviceMesh
from .utils import merge_same_dim_mesh_list
__all__ = ["_DimSpec", "ShardingException", "ShardingSpec"]
ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
NAN = "nan"
class _DimSpec:
"""
Sharding spec for single dimension of the sharded tensor describe the sharding dimension of
logical device mesh and give a method to compute the difference between them.
This class is used internally in ShardingSpec.
Argument:
shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
Otherwise, the element in shard_list means the data will be sharded in that dimension.
"""
_DIFFERENCE_DICT = None
def __init__(self, shard_list):
self.is_replica = len(shard_list) == 0
self.shard_list = shard_list
def __eq__(self, other):
return str(self) == str(other)
def __repr__(self):
if self.is_replica:
return "R"
target = "S"
for dim in self.shard_list:
target += str(dim)
return target
@property
def difference_dict(self):
"""
Returns the difference dict, and lazily initializes it when needed
Return:
difference_dict(Dict[Tuple[int, int], Union[int, float, str]]):
difference dict
"""
if self._DIFFERENCE_DICT is None:
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
return self._DIFFERENCE_DICT
def difference(self, other):
"""
The difference between two _DimSpec.
Argument:
other(_DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two _DimSpec.
Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
"""
difference = self.difference_dict[(str(self), str(other))]
return difference
@classmethod
def _build_difference_2d_dict(cls):
"""
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between _DimSpec pairs.
"""
source_spec_list = ["R", "S0", "S1", "S01"]
target_spec_list = ["R", "S0", "S1", "S01"]
difference_dict = {}
for source_spec in source_spec_list:
for target_spec in target_spec_list:
source_shard_list = cls._convert_str_to_shard_list(source_spec)
target_shard_list = cls._convert_str_to_shard_list(target_spec)
# source same as target
if source_shard_list == target_shard_list:
difference = 0
# all_gather(source) -> target
elif (
len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list
):
difference = ALLGATHER_COST
# shard(source) -> target
elif (
len(source_shard_list) == len(target_shard_list) - 1
and source_shard_list == target_shard_list[:-1]
and target_shard_list[-1] not in source_shard_list
):
difference = SHARD_COST
# S1 -> S0 or S0 -> S1
elif len(source_shard_list) == len(target_shard_list):
# source -> R -> target
difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST
# R -> S01
elif len(source_shard_list) == len(target_shard_list) - 2:
difference = SHARD_COST + STEP_PENALTY + SHARD_COST
# S01 -> R
elif len(source_shard_list) == len(target_shard_list) + 2:
difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
# S1 -> S01
elif len(source_shard_list) == len(target_shard_list) - 1:
difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST
# S01 -> S1
elif len(source_shard_list) == len(target_shard_list) + 1:
difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST
else:
difference = NAN
difference_dict[(source_spec, target_spec)] = difference
return difference_dict
@staticmethod
def _convert_str_to_shard_list(str_spec):
"""
Convert str_spec into shard_list.
Argument:
str_spec(str): dim spec in str type.
"""
if str_spec == "R":
return []
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
class ShardingSpecException(Exception):
pass
class ShardingOutOfIndexError(ShardingSpecException):
pass
class DuplicatedShardingDimensionError(ShardingSpecException):
pass
class ShardingNotDivisibleError(ShardingSpecException):
pass
class ShardingSpec:
"""
Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong
to, the entire shape of the tensor before sharded, and the sharding sequence looks like
[R, R, S0, S1].
Argument:
device_mesh(DeviceMesh): A logical view of a physical mesh.
entire_shape(torch.Size): The entire shape of tensor before sharded.
dim_partition_dict(Dict[int, List[int]] optional): The key is the dimension of tensor to be sharded,
and the value of the key describe which logical axis will be sharded in that dimension.
sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
"""
def __init__(
self, device_mesh: DeviceMesh, entire_shape: torch.Size, dim_partition_dict=None, sharding_sequence=None
):
self.device_mesh = device_mesh
if isinstance(entire_shape, (list, tuple)):
entire_shape = torch.Size(entire_shape)
self.entire_shape = entire_shape
self.dim_partition_dict = dim_partition_dict
self.sharding_sequence = sharding_sequence
if self.sharding_sequence is None:
assert (
self.dim_partition_dict is not None
), f"dim_partition_dict should not be None, if sharding_sequence is NoneType object."
self.dim_partition_dict = merge_same_dim_mesh_list(
dim_size=len(entire_shape), dim_partition_dict=self.dim_partition_dict
)
self.convert_dict_to_shard_sequence()
elif self.dim_partition_dict is None:
assert (
self.sharding_sequence is not None
), f"sharding_sequence should not be None, if dim_partition_dict is NoneType object."
self.convert_shard_sequence_to_dict()
self._sanity_check()
def __repr__(self):
res_list = ["DistSpec:"]
res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}")
return " ".join(res_list)
def _sanity_check(self):
# make sure all axes in logical device mesh only be used once
dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
for dim, shard_list in self.dim_partition_dict.items():
for element in shard_list:
if element in dim_check_list:
dim_check_list.remove(element)
else:
raise DuplicatedShardingDimensionError(
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}."
)
# make sure that the dimension is not out of index
for dim in self.dim_partition_dict.keys():
if dim >= len(self.entire_shape):
raise ShardingOutOfIndexError(
f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions"
)
# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in self.dim_partition_dict.items():
tensor_dim_size = self.entire_shape[dim]
num_devices = 1
for element in shard_list:
num_devices *= self.device_mesh.shape[element]
if tensor_dim_size % num_devices != 0:
raise ShardingNotDivisibleError(
f"The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices."
)
def convert_dict_to_shard_sequence(self):
"""
Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence.
"""
sharding_sequence = [_DimSpec([])] * len(self.entire_shape)
for dim, shard_list in self.dim_partition_dict.items():
sharding_sequence[dim] = _DimSpec(shard_list)
self.sharding_sequence = sharding_sequence
def convert_shard_sequence_to_dict(self):
"""
Convert sharding_sequence into dim_partition_dict.
"""
new_dim_partition_dict = {}
for index, dim_spec in enumerate(self.sharding_sequence):
if not dim_spec.is_replica:
if index not in new_dim_partition_dict:
new_dim_partition_dict[index] = []
new_dim_partition_dict[index].extend(dim_spec.shard_list)
self.dim_partition_dict = new_dim_partition_dict
def sharding_sequence_difference(self, other):
"""
This function is a naive version of difference computation. It just simply accumulates difference every dimension between the
pair of sharding sequence.
Example:
dim_partition_dict = {0: [0, 1]}
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
dim_partition_dict_to_compare = {0: [0], 1: [1]}
# DistSpec:
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4)
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
Output:
25
Argument:
other(ShardingSpec): The ShardingSpec to compared with.
Return:
difference(int): Difference between two ShardingSpec.
"""
assert len(self.sharding_sequence) == len(
other.sharding_sequence
), f"Cannot compare difference for two sharding specs with different length."
difference = 0
for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence):
difference += orig_dim_spec.difference(other_dim_spec)
return difference
def get_sharded_shape_per_device(self):
sharded_shape = list(self.entire_shape)
for dim, shard_list in self.dim_partition_dict.items():
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
assert (
sharded_shape[dim] % shard_partitions == 0
), f"Cannot shard dimension {dim} into {shard_partitions} partitions."
sharded_shape[dim] //= shard_partitions
return torch.Size(sharded_shape)