[Doc] add more doc for ColoTensor. (#1458)

This commit is contained in:
Jiarui Fang
2022-08-16 10:38:41 +08:00
committed by GitHub
parent a1476ea882
commit 36824a304c
4 changed files with 46 additions and 18 deletions

View File

@@ -1,7 +1,7 @@
from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
from .distspec import shard as ShardSpec
from .distspec import replicate as ReplicaSpec
from .distspec import ShardSpec
from .distspec import ReplicaSpec
from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor
@@ -13,6 +13,6 @@ from . import distspec
__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
'ShardSpec', 'ReplicaSpec'
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
'ReplicaSpec'
]

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import List
__all__ = ['replicate', 'shard']
__all__ = ['ReplicaSpec', 'ShardSpec']
class DistPlacementPattern(Enum):
@@ -10,15 +10,22 @@ class DistPlacementPattern(Enum):
class _DistSpec:
"""_DistSpec
A class indicates Distributed Specification.
The DistSpec is only works for the tensor parallel process groups.
Because the dist spec of data parallel process group can be automatically deduced.
This is an internal data structrue.
The API for users should be `ShardSpec` and `ReplicaSpec`.
Args:
dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
"""
def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
"""_DistSpec, Distributed Specification
Args:
dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
"""
self.placement = dist_placement_pattern
for k, v in meta_info.items():
setattr(self, k, v)
@@ -39,11 +46,32 @@ class _DistSpec:
return ''.join(res_list)
def replicate() -> _DistSpec:
def ReplicaSpec() -> _DistSpec:
"""ReplicaSpec
A distributed specification represents the tensor is replicated among the tensor parallel process group.
Returns:
_DistSpec: an replicated dist spec instance.
"""
return _DistSpec(DistPlacementPattern.REPLICATE)
def shard(dims: List[int], num_partitions: List[int]) -> _DistSpec:
def ShardSpec(dims: List[int], num_partitions: List[int]) -> _DistSpec:
"""ShardSpec
A distributed specification represents the tensor is sharded among the tensor parallel process group.
Note:
Currently, only shard on one dimension is valid. In another word, dims should be of size 1.
Args:
dims (List[int]): a list of dimensions
num_partitions (List[int]): a list of partition number of each dimensions.
Returns:
_DistSpec: an shard dist spec instance.
"""
assert isinstance(dims, list) and isinstance(num_partitions, list)
assert len(dims) == len(num_partitions)
return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions))