mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 07:00:37 +00:00
[Doc] add more doc for ColoTensor. (#1458)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from .process_group import ProcessGroup
|
||||
from .tensor_spec import ColoTensorSpec
|
||||
from .distspec import shard as ShardSpec
|
||||
from .distspec import replicate as ReplicaSpec
|
||||
from .distspec import ShardSpec
|
||||
from .distspec import ReplicaSpec
|
||||
|
||||
from .compute_spec import ComputeSpec, ComputePattern
|
||||
from .colo_tensor import ColoTensor
|
||||
@@ -13,6 +13,6 @@ from . import distspec
|
||||
|
||||
__all__ = [
|
||||
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
|
||||
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
|
||||
'ShardSpec', 'ReplicaSpec'
|
||||
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
|
||||
'ReplicaSpec'
|
||||
]
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
__all__ = ['replicate', 'shard']
|
||||
__all__ = ['ReplicaSpec', 'ShardSpec']
|
||||
|
||||
|
||||
class DistPlacementPattern(Enum):
|
||||
@@ -10,15 +10,22 @@ class DistPlacementPattern(Enum):
|
||||
|
||||
|
||||
class _DistSpec:
|
||||
"""_DistSpec
|
||||
|
||||
A class indicates Distributed Specification.
|
||||
The DistSpec only works for tensor parallel process groups.
|
||||
This is because the dist spec of a data parallel process group can be deduced automatically.
|
||||
This is an internal data structure.
|
||||
The API for users should be `ShardSpec` and `ReplicaSpec`.
|
||||
|
||||
Args:
|
||||
dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
|
||||
The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
|
||||
process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
|
||||
"""_DistSpec, Distributed Specification
|
||||
|
||||
Args:
|
||||
dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
|
||||
The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
|
||||
process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
|
||||
"""
|
||||
self.placement = dist_placement_pattern
|
||||
for k, v in meta_info.items():
|
||||
setattr(self, k, v)
|
||||
@@ -39,11 +46,32 @@ class _DistSpec:
|
||||
return ''.join(res_list)
|
||||
|
||||
|
||||
def replicate() -> _DistSpec:
|
||||
def ReplicaSpec() -> _DistSpec:
    """Build a distributed specification for a replicated tensor.

    The returned spec states that the tensor is replicated among the
    tensor parallel process group.

    Returns:
        _DistSpec: a replicated dist spec instance, i.e. a ``_DistSpec``
        whose placement is ``DistPlacementPattern.REPLICATE`` and which
        carries no extra meta information.
    """
    replicated_spec = _DistSpec(DistPlacementPattern.REPLICATE)
    return replicated_spec
|
||||
|
||||
|
||||
def shard(dims: List[int], num_partitions: List[int]) -> _DistSpec:
|
||||
def ShardSpec(dims: List[int], num_partitions: List[int]) -> _DistSpec:
    """Build a distributed specification for a sharded tensor.

    The returned spec states that the tensor is sharded among the tensor
    parallel process group.

    Note:
        Currently, only sharding on one dimension is valid. In other words,
        ``dims`` should be of size 1.

    Args:
        dims (List[int]): a list of the dimensions to shard along.
        num_partitions (List[int]): a list with the number of partitions for
            each dimension; it must have the same length as ``dims``.

    Returns:
        _DistSpec: a shard dist spec instance that stores ``dims`` and
        ``num_partitions`` as tuples.

    Raises:
        AssertionError: if either argument is not a ``list`` or the two
            lists differ in length.
    """
    # Both arguments must be real lists; other sequences are rejected.
    assert isinstance(dims, list)
    assert isinstance(num_partitions, list)
    assert len(dims) == len(num_partitions)

    # Freeze the inputs so the spec does not alias the caller's lists.
    shard_meta = {'dims': tuple(dims), 'num_partitions': tuple(num_partitions)}
    return _DistSpec(DistPlacementPattern.SHARD, **shard_meta)
|
||||
|
Reference in New Issue
Block a user