ColossalAI/colossalai/tensor/d_tensor/layout_converter.py
flybird11111 0c10afd372 [FP8] rebase main (#5963)
* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtra pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redunant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (#5932)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update low_level_optim.py

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
2024-08-06 16:29:37 +08:00

626 lines
27 KiB
Python

import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch
import torch.distributed as dist

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import *
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.misc import LayoutException
from colossalai.tensor.padded_tensor.api import init_as_padded_tensor, is_padded_tensor
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator

from .sharding_spec import ShardingSpec
from .utils import get_comm_cost

__all__ = ["LayoutConverter", "LayoutConverterOptions", "set_layout_converting_options"]


@dataclass
class LayoutConverterOptions:
    """
    LayoutConverterOptions is a dataclass which specifies the preferences for layout converting.
    """

    # TODO: layout converter option is not implemented yet


def set_layout_converting_options(options: LayoutConverterOptions):
    """
    Configure the layout converter via function call.
    """
    manager = LayoutConverter()
    manager.options = options


class LayoutConverter(metaclass=SingletonMeta):
    """
    LayoutConverter is a singleton class which converts the layout of a distributed tensor.
    """

    def __init__(self):
        self._options = None
        self._forward_only = False
        self.cached_solution = {}

    @property
    def options(self):
        return self._options

    @options.setter
    def options(self, options_: LayoutConverterOptions):
        assert isinstance(options_, LayoutConverterOptions)
        self._options = options_

    @property
    def forward_only(self):
        return self._forward_only

    @forward_only.setter
    def forward_only(self, value):
        assert isinstance(value, bool)
        self._forward_only = value

    def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
        """
        Get all valid layouts reachable from source_layout with a single all-gather operation.
        For the all-gather operation, we only care about the S (sharded) dimensions.

        Argument:
            source_layout(Layout): the layout to be transformed.

        Return:
            valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts reachable from source_layout with a single all-gather operation.

        Example:
            layout_converter = LayoutConverter()
            physical_mesh_id = torch.arange(0, 4)
            mesh_shape = (2, 2)
            # [[0, 1],
            #  [2, 3]]
            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
            global_shape = (4, 4, 4)
            dim_partition_dict = {0: [0], 1: [1]}

            # [S0,S1,R]
            sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
            layout = Layout(device_mesh=device_mesh,
                            sharding_spec=sharding_spec,
                            global_shape=global_shape)

            rst_dict = layout_converter.all_gather_transform_layouts(layout)
            for layout, comm_spec in rst_dict.items():
                print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}')

        Output:
            [R, S1, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:0, shard_dim:0, logical_process_axis:0)
            [S0, R, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1)
        """
        valid_spec_dict = {}
        comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
        source_spec = source_layout.sharding_spec

        # the key of the dict is the axis
        # the value is the process group
        current_rank = source_layout.device_mesh._global_rank_of_current_process
        process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]

        for target_pair in source_spec.dim_partition_dict.items():
            shard_list = all_gather_simulator(target_pair)
            index = target_pair[0]
            new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)

            # We won't add an empty list into dim_partition_dict
            # The key will be popped if the related shard_list is empty
            if shard_list:
                new_dim_partition_dict[index] = shard_list
            else:
                new_dim_partition_dict.pop(index)

            # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
            gather_dim = index
            logical_process_axis = target_pair[1][-1]
            comm_spec = CommSpec(
                comm_pattern,
                process_group_dict=process_group_dict,
                gather_dim=gather_dim,
                # shard_dim will be used during backward
                shard_dim=gather_dim,
                logical_process_axis=logical_process_axis,
            )

            # generate new sharding spec
            try:
                new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
                new_layout = Layout(
                    device_mesh=source_layout.device_mesh,
                    sharding_spec=new_sharding_spec,
                    global_shape=source_layout.global_shape,
                )

                valid_spec_dict[new_layout] = comm_spec
            except LayoutException:
                pass
        return valid_spec_dict

    def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
        """
        Get all valid layouts reachable from source_layout with a single all-to-all operation.
        For the all-to-all operation, we only care about the pairs containing an S dimension.

        Argument:
            source_layout(Layout): the layout to be transformed.

        Return:
            valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts reachable from source_layout with a single all-to-all operation.

        Example:
            layout_converter = LayoutConverter()
            physical_mesh_id = torch.arange(0, 4)
            mesh_shape = (2, 2)
            # [[0, 1],
            #  [2, 3]]
            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
            global_shape = (4, 4, 4)
            dim_partition_dict = {0: [0], 1: [1]}

            # [S0,S1,R]
            sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
            layout = Layout(device_mesh=device_mesh,
                            sharding_spec=sharding_spec,
                            global_shape=global_shape)
            rst_dict = layout_converter.all_to_all_transform_layout(layout)

            for layout, comm_spec in rst_dict.items():
                print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}')

        Output:
            [S01, R, R]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:0, logical_process_axis: 1)
            [R, S1, S0]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:0, shard_dim:2, logical_process_axis: 0)
            [S0, R, S1]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:2, logical_process_axis: 1)
        """
        valid_spec_dict = {}
        comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD

        # the key of the dict is the axis
        # the value is the process group
        current_rank = source_layout.device_mesh._global_rank_of_current_process
        process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]

        source_spec = source_layout.sharding_spec
        tensor_dims = source_spec.dims
        for f_index in range(tensor_dims - 1):
            for b_index in range(f_index + 1, tensor_dims):
                # skip (R, R) cases
                if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict:
                    continue
                else:
                    if f_index in source_spec.dim_partition_dict:
                        # skip: (S01, R) -> (R, S01) is NOT allowed
                        if len(source_spec.dim_partition_dict[f_index]) >= 2:
                            continue
                        f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
                    else:
                        f_target_pair = (f_index, [])
                    if b_index in source_spec.dim_partition_dict:
                        # skip: (R, S01) -> (S01, R) is NOT allowed
                        if len(source_spec.dim_partition_dict[b_index]) >= 2:
                            continue
                        b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
                    else:
                        b_target_pair = (b_index, [])

                # skip (S1, S0) -> S10
                if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][0] >= b_target_pair[1][0]:
                    continue
                f_shard_list, b_shard_list = all_to_all_simulator(f_target_pair, b_target_pair)
                f_index = f_target_pair[0]
                b_index = b_target_pair[0]

                # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
                if len(f_shard_list) < len(f_target_pair[1]):
                    gather_dim = f_index
                    shard_dim = b_index
                    logical_process_axis = f_target_pair[1][-1]
                else:
                    gather_dim = b_index
                    shard_dim = f_index
                    logical_process_axis = b_target_pair[1][-1]
                comm_spec = CommSpec(
                    comm_pattern,
                    process_group_dict=process_group_dict,
                    gather_dim=gather_dim,
                    shard_dim=shard_dim,
                    logical_process_axis=logical_process_axis,
                )

                new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)

                # We won't add an empty list into dim_partition_dict
                # The key will be popped if the related shard_list is empty
                if f_shard_list:
                    new_dim_partition_dict[f_index] = f_shard_list
                else:
                    new_dim_partition_dict.pop(f_index)
                if b_shard_list:
                    new_dim_partition_dict[b_index] = b_shard_list
                else:
                    new_dim_partition_dict.pop(b_index)

                # generate new sharding spec
                try:
                    new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
                    new_layout = Layout(
                        device_mesh=source_layout.device_mesh,
                        sharding_spec=new_sharding_spec,
                        global_shape=source_layout.global_shape,
                    )
                    valid_spec_dict[new_layout] = comm_spec
                except LayoutException:
                    pass

        return valid_spec_dict

    def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
        """
        Get all valid layouts reachable from source_layout with a single shard operation.
        For the sharding operation, we only care about legal sharding dimensions.

        Argument:
            source_layout(Layout): the layout to be transformed.

        Return:
            valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts reachable from source_layout with a single shard operation.

        Example:
            layout_converter = LayoutConverter()
            physical_mesh_id = torch.arange(0, 4)
            mesh_shape = (2, 2)
            # [[0, 1],
            #  [2, 3]]
            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
            global_shape = (4, 4, 4)

            dim_partition_dict = {0: [0]}

            # [S0,R,R]
            sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
            layout = Layout(device_mesh=device_mesh,
                            sharding_spec=sharding_spec,
                            global_shape=global_shape)
            rst_dict = layout_converter.shard_transform_layout(layout)

            for layout, comm_spec in rst_dict.items():
                print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}')

        Output:
            [S01, R, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:0, shard_dim:0, logical_process_axis:1)
            [S0, S1, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1)
            [S0, R, S1]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:2, shard_dim:2, logical_process_axis:1)
        """
        valid_spec_dict = {}
        comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
        source_spec = source_layout.sharding_spec

        # the key of the dict is the axis
        # the value is the process group
        current_rank = source_layout.device_mesh._global_rank_of_current_process
        process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]

        # a legal sharding dim is a device mesh axis that is not yet used by any tensor dim
        legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
        for dim, shard_list in source_spec.dim_partition_dict.items():
            for element in shard_list:
                legal_sharding_dims.remove(element)

        if len(legal_sharding_dims) == 0:
            return valid_spec_dict

        tensor_dims = source_spec.dims

        for index in range(tensor_dims):
            if index not in source_spec.dim_partition_dict:
                shard_list_list = shard_simulator((index, []), legal_sharding_dims)
            else:
                shard_list_list = shard_simulator((index, source_spec.dim_partition_dict[index]), legal_sharding_dims)
            if not shard_list_list:
                continue
            for shard_list in shard_list_list:
                new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
                new_dim_partition_dict[index] = shard_list

                # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
                shard_dim = index
                logical_process_axis = shard_list[-1]
                comm_spec = CommSpec(
                    comm_pattern,
                    process_group_dict=process_group_dict,
                    gather_dim=shard_dim,
                    shard_dim=shard_dim,
                    logical_process_axis=logical_process_axis,
                )

                # generate new sharding spec
                try:
                    new_sharding_spec = ShardingSpec(
                        dim_size=source_spec.dims, dim_partition_dict=new_dim_partition_dict
                    )
                    new_layout = Layout(
                        device_mesh=source_layout.device_mesh,
                        sharding_spec=new_sharding_spec,
                        global_shape=source_layout.global_shape,
                    )
                    valid_spec_dict[new_layout] = comm_spec
                except LayoutException:
                    pass
        return valid_spec_dict

    def get_all_one_step_transform_spec(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
        """
        Get all valid layouts reachable from source_layout with a one-step transform.

        Note:
            all-gather eliminates a sharding dimension, all-to-all keeps the number of sharding dimensions unchanged,
            and shard adds a sharding dimension. Therefore, the results of the above operations are mutually exclusive,
            so we can safely merge them into one dict.

        Argument:
            source_layout(Layout): the layout to be transformed.

        Return:
            valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts reachable from source_layout with a one-step transform.
        """
        valid_spec_dict = {}
        valid_spec_dict.update(self.all_gather_transform_layouts(source_layout))
        valid_spec_dict.update(self.all_to_all_transform_layout(source_layout))
        valid_spec_dict.update(self.shard_transform_layout(source_layout))
        return valid_spec_dict

    def layout_converting(
        self, source_layout: Layout, target_layout: Layout
    ) -> Tuple[List[Layout], List[CommSpec]]:
        """
        This method finds a path to transform source_layout into target_layout with
        a greedy algorithm.
        The basic idea is:
        Step1:
            Generate all one-step transform sequences from source_layout.
        Step2:
            Pick the 'best' layout following the heuristic function.
        Step3:
            Repeat the above steps until the source layout is transformed into the target layout.

        Additionally, to avoid repeating the path search at runtime, all solved paths are cached
        at auto-parallel strategy building time, which handles most cases at runtime.

        Args:
            source_layout(Layout): the layout to be transformed.
            target_layout(Layout): the layout to be achieved after a series of transforms.

        Return:
            transform_path(List[Layout]): The transform path from source_layout to target_layout,
                                                it contains the source_layout and target_layout.
            comm_action_sequence(List[CommSpec]): The communication operations, in order, that complete the layout converting.

        Example:
            layout_converter = LayoutConverter()
            physical_mesh_id = torch.arange(0, 4)
            mesh_shape = (2, 2)
            # [[0, 1],
            #  [2, 3]]
            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
            global_shape = (4, 4, 4)

            dim_partition_source = {1: [0, 1]}
            dim_partition_target = {0: [0, 1]}

            # [R,S01,R]
            sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
            source_layout = Layout(device_mesh=device_mesh,
                                   sharding_spec=sharding_spec_source,
                                   global_shape=global_shape)

            # [S01,R,R]
            sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
            target_layout = Layout(device_mesh=device_mesh,
                                   sharding_spec=sharding_spec_target,
                                   global_shape=global_shape)

            transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
            transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
            print(transform_path_str)

        Output:
            [R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]
        """
        source_spec = source_layout.sharding_spec
        target_spec = target_layout.sharding_spec
        MAX_TRANSFORM_STEPS = 20
        total_steps = 0
        transform_path = []
        comm_action_sequence: List[CommSpec] = []

        src_shape = source_layout.get_sharded_shape_per_device()
        dst_shape = target_layout.get_sharded_shape_per_device()
        spec_pairs = ((str(source_spec.sharding_sequence), src_shape), (str(target_spec.sharding_sequence), dst_shape))

        if spec_pairs in self.cached_solution:
            # Solution cache hit

            def _group_alive_check(cached_comm_action_sequence):
                r"""
                Check if the process groups required for sharding have been deleted by torch.distributed.destroy_process_group method.
                If not deleted, return True; otherwise, return False.

                Args:
                    cached_comm_action_sequence (List[CommSpec]): A list of communication specifications representing actions.

                Returns:
                    bool: True if all process groups are still registered, False if at least one has been deleted.

                Raises:
                    RuntimeError: If there is an error while checking the status of a process group.
                """

                # Collect all process groups used in communication actions from the cached sequence
                used_process_groups = [
                    pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values()
                ]

                # Check if each process group is still alive
                for process_group in used_process_groups:
                    try:
                        dist.get_rank(process_group)
                    except (ValueError, RuntimeError) as e:
                        # If the group is not registered, it means it has been deleted
                        if str(e) == (
                            f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"
                        ):
                            return False
                        elif str(e) == "The given group does not exist":
                            return False
                        else:
                            # Re-raise the exception if it's not related to group deletion
                            raise e
                # All process groups are alive
                return True

            cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs]

            if _group_alive_check(cached_comm_action_sequence):
                # If no process group has been deleted, the cache is valid
                return cached_transform_path, cached_comm_action_sequence
            else:
                # If at least one process group has been deleted, the cache is invalid, so delete it
                del self.cached_solution[spec_pairs]

        # Nothing to do if the source and target sharding specs are identical.
        if source_spec.spec_diff(target_spec) == 0:
            self.cached_solution[spec_pairs] = (transform_path, comm_action_sequence)
            return (
                transform_path,
                comm_action_sequence,
            )

        temp_sharding_layout = source_layout

        transform_path.append(temp_sharding_layout)
        # To avoid an infinite loop, the search stops after MAX_TRANSFORM_STEPS transforms
        while total_steps <= MAX_TRANSFORM_STEPS:
            valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_layout)
            best_difference_score = math.inf

            for layout, comm_spec in valid_transform_spec_dict.items():
                sharding_spec = layout.sharding_spec
                spec_difference = sharding_spec.spec_diff(target_spec)

                if spec_difference == 0:
                    transform_path.append(layout)
                    comm_action_sequence.append(comm_spec)
                    self.cached_solution[spec_pairs] = (transform_path, comm_action_sequence)
                    return (transform_path, comm_action_sequence)

                if spec_difference < best_difference_score:
                    temp_sharding_layout = layout
                    temp_comm_spec = comm_spec
                    best_difference_score = spec_difference

            transform_path.append(temp_sharding_layout)
            comm_action_sequence.append(temp_comm_spec)

            total_steps += 1

        raise RuntimeError(f"Could not find a valid transform path within {MAX_TRANSFORM_STEPS} steps.")

    def get_total_comm_cost(self, source_layout: Layout, target_layout: Layout) -> Dict[str, float]:
        """
        Get the total communication cost of the layout converting process.
        """
        transform_path, comm_action_sequence = self.layout_converting(source_layout, target_layout)
        total_cost = {"forward": 0.0, "backward": 0.0, "total": 0.0}
        for layout, comm_spec in zip(transform_path, comm_action_sequence):
            cost_dict = get_comm_cost(layout, comm_spec, self.forward_only)
            for key in total_cost:
                total_cost[key] += cost_dict[key]
        return total_cost

    def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layout) -> torch.Tensor:
        """
        Redistribute a tensor from source_layout to target_layout. The transform path is generated by the
        layout_converting method.

        Argument:
            tensor (torch.Tensor): The tensor to be redistributed.
            source_layout (Layout): The source layout of the tensor.
            target_layout (Layout): The layout the tensor will be redistributed to.

        Example:
            layout_converter = LayoutConverter()
            dim_partition_source = {0: [0]}
            dim_partition_target = {1: [0]}
            physical_mesh_id = torch.arange(0, 4)
            mesh_shape = (2, 2)
            # [[0, 1],
            #  [2, 3]]
            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
            global_shape = (4, 4, 4)

            # [S0,R,R]
            sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
            source_layout = Layout(device_mesh=device_mesh,
                                   sharding_spec=sharding_spec_source,
                                   global_shape=global_shape)

            # [R,S0,R]
            sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
            target_layout = Layout(device_mesh=device_mesh,
                                   sharding_spec=sharding_spec_target,
                                   global_shape=global_shape)

            if rank in (0, 1):
                sharded_tensor_0 = torch.zeros(2, 1)
                sharded_tensor_1 = torch.ones(2, 1)
                # tensor([[0., 1.],
                #         [0., 1.]])
                tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
            if rank in (2, 3):
                sharded_tensor_0 = torch.ones(2, 1) * 2
                sharded_tensor_1 = torch.ones(2, 1) * 3
                # tensor([[2., 3.],
                #         [2., 3.]])
                tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()

            # converted_tensor: [R, S0, R]
            converted_tensor = layout_converter.apply(tensor_to_comm, source_layout, target_layout)
            print(converted_tensor)

        Output in rank0 and rank1:
            tensor([[0.],
                    [0.],
                    [2.],
                    [2.]])

        Output in rank2 and rank3:
            tensor([[1.],
                    [1.],
                    [3.],
                    [3.]])
        """

        _, comm_action_sequence = self.layout_converting(source_layout, target_layout)

        target_tensor = tensor
        for comm_spec in comm_action_sequence:
            target_tensor = comm_spec.covert_spec_to_action(target_tensor)
        target_tensor.dist_layout = target_layout

        # restore the padding information
        if is_padded_tensor(tensor) and not is_padded_tensor(target_tensor):
            target_tensor = init_as_padded_tensor(
                target_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
            )

        return target_tensor
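
The greedy search in layout_converting can be tried without a process group or device mesh. The sketch below is a simplified stand-in, not the module's implementation: it models a sharding spec as a plain dim_partition_dict, supports only all-gather and shard steps (all-to-all is omitted, so paths can be longer than the real converter's), and uses a hypothetical prefix-based `distance` function in place of `ShardingSpec.spec_diff`. The names `to_sequence`, `distance`, `one_step` and `greedy_convert` are illustrative and do not exist in ColossalAI.

```python
from typing import Dict, List, Tuple

Spec = Dict[int, List[int]]  # tensor dim -> mesh axes that shard it


def to_sequence(spec: Spec, tensor_dims: int) -> List[str]:
    """Render a spec as a sharding sequence, e.g. ['R', 'S01', 'R']."""
    return ["S" + "".join(map(str, spec[d])) if d in spec else "R" for d in range(tensor_dims)]


def distance(spec: Spec, target: Spec, tensor_dims: int) -> int:
    """Assumed heuristic: gathers plus shards needed per tensor dim."""
    total = 0
    for d in range(tensor_dims):
        src, dst = spec.get(d, []), target.get(d, [])
        common = 0
        while common < min(len(src), len(dst)) and src[common] == dst[common]:
            common += 1
        total += (len(src) - common) + (len(dst) - common)
    return total


def one_step(spec: Spec, tensor_dims: int, mesh_axes: int) -> List[Tuple[str, Spec]]:
    """All specs reachable with one all-gather or one shard."""
    candidates = []
    # all-gather: drop the last mesh axis of a sharded tensor dim
    for d, axes in spec.items():
        new = {k: list(v) for k, v in spec.items()}
        new[d] = axes[:-1]
        if not new[d]:
            new.pop(d)  # empty shard lists are not kept in the dict
        candidates.append(("all_gather", new))
    # shard: append a still-unused mesh axis, keeping axes in increasing order
    used = {a for axes in spec.values() for a in axes}
    for d in range(tensor_dims):
        for axis in range(mesh_axes):
            if axis in used or (spec.get(d) and axis <= spec[d][-1]):
                continue
            new = {k: list(v) for k, v in spec.items()}
            new[d] = new.get(d, []) + [axis]
            candidates.append(("shard", new))
    return candidates


def greedy_convert(
    source: Spec, target: Spec, tensor_dims: int, mesh_axes: int, max_steps: int = 20
) -> Tuple[List[Spec], List[str]]:
    """Repeatedly move to the one-step neighbour closest to the target."""
    path, ops, current = [source], [], source
    if distance(current, target, tensor_dims) == 0:
        return path, ops
    for _ in range(max_steps):
        op, current = min(
            one_step(current, tensor_dims, mesh_axes),
            key=lambda cand: distance(cand[1], target, tensor_dims),
        )
        path.append(current)
        ops.append(op)
        if distance(current, target, tensor_dims) == 0:
            return path, ops
    raise RuntimeError(f"Could not find a valid transform path within {max_steps} steps.")


# [R, S01, R] -> [S01, R, R] on a 2-axis mesh
path, ops = greedy_convert({1: [0, 1]}, {0: [0, 1]}, tensor_dims=3, mesh_axes=2)
print(" -> ".join(str(to_sequence(p, 3)) for p in path))
print(ops)
```

Without all-to-all, this toy needs four steps (two all-gathers followed by two shards) for the conversion that the docstring example above completes in three. That gap is why the real converter also enumerates all-to-all transforms.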