ColossalAI/colossalai/tensor/distspec.py
digger-yu b9a8dff7e5
[doc] Fix typos under colossalai and doc (#3618)
* Fixed several spelling errors under colossalai

* Fix the spelling error in colossalai and docs directory

* Cautiously changed the spelling errors under the example folder

* Update runtime_preparation_pass.py

revert autograft to autograd

* Update search_chunk.py

utile to until

* Update check_installation.py

change misteach to mismatch in line 91

* Update 1D_tensor_parallel.md

revert to perceptron

* Update 2D_tensor_parallel.md

revert to perceptron in line 73

* Update 2p5D_tensor_parallel.md

revert to perceptron in line 71

* Update 3D_tensor_parallel.md

revert to perceptron in line 80

* Update README.md

revert to resnet in line 42

* Update reorder_graph.py

revert to indice in line 7

* Update p2p.py

revert to megatron in line 94

* Update initialize.py

revert to torchrun in line 198

* Update routers.py

change to detailed in line 63

* Update routers.py

change to detailed in line 146

* Update README.md

revert random number in line 402
2023-04-26 11:38:43 +08:00


from enum import Enum
from typing import List

__all__ = ['ReplicaSpec', 'ShardSpec']


class DistPlacementPattern(Enum):
    REPLICATE = 'r'
    SHARD = 's'


class _DistSpec:
    """_DistSpec

    A class that describes a distributed specification.

    The DistSpec only works for tensor parallel process groups, because the dist spec
    of a data parallel process group can be deduced automatically.
    This is an internal data structure.
    The APIs for users are `ShardSpec` and `ReplicaSpec`.

    Args:
        dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
            The dist_placement_pattern is picked from a limited set, currently including two patterns: replicate and shard.
        process_group (Optional[ProcessGroup], optional): the process group containing the involved processes. Defaults to None.
    """

    def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
        self.placement = dist_placement_pattern
        for k, v in meta_info.items():
            setattr(self, k, v)

    def __eq__(self, other: "_DistSpec") -> bool:
        if dir(self) != dir(other):
            return False
        for attr in dir(self):
            if not attr.startswith('__') and getattr(self, attr) != getattr(other, attr):
                return False
        return True

    def __repr__(self) -> str:
        attr_list = []
        for attr in dir(self):
            if not attr.startswith('__'):
                attr_list.append(f'{attr}={str(getattr(self, attr))}')
        attr_str = ", ".join(attr_list)
        return "DistSpec(" + attr_str + ")"

def ReplicaSpec() -> _DistSpec:
    """ReplicaSpec

    A distributed specification indicating that the tensor is replicated among the tensor parallel process group.

    Returns:
        _DistSpec: a replicated dist spec instance.
    """
    return _DistSpec(DistPlacementPattern.REPLICATE)
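
# Usage sketch (illustrative only): `ReplicaSpec()` carries no extra meta info,
# so its repr shows only the placement pattern:
#
#   r_spec = ReplicaSpec()
#   print(r_spec)  # prints something like: DistSpec(placement=DistPlacementPattern.REPLICATE)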

def ShardSpec(dims: List[int], num_partitions: List[int]) -> _DistSpec:
    """ShardSpec

    A distributed specification indicating that the tensor is sharded among the tensor parallel process group.

    Note:
        Currently, sharding along only one dimension is valid. In other words, `dims` should be of size 1.

    Args:
        dims (List[int]): a list of dimensions along which the tensor is sharded.
        num_partitions (List[int]): a list of partition counts, one for each dimension in `dims`.

    Returns:
        _DistSpec: a sharded dist spec instance.
    """
    assert isinstance(dims, list) and isinstance(num_partitions, list)
    assert len(dims) == len(num_partitions)
    return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions))
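

# A minimal usage sketch for sharding, assuming a 4-way tensor-parallel group
# (the partition count below is illustrative only):
if __name__ == "__main__":
    # Shard dimension 0 of a tensor across 4 tensor-parallel ranks.
    col_shard = ShardSpec([0], [4])
    print(col_shard)
    # -> DistSpec(dims=(0,), num_partitions=(4,), placement=DistPlacementPattern.SHARD)

    # Replicate the tensor across the whole tensor-parallel group.
    replica = ReplicaSpec()
    print(replica)
    # -> DistSpec(placement=DistPlacementPattern.REPLICATE)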