diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py
index 4a1b8ab26..d2bad06bb 100644
--- a/colossalai/fx/passes/shard_1d_pass.py
+++ b/colossalai/fx/passes/shard_1d_pass.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import operator
 from colossalai.tensor import ProcessGroup
-from colossalai.tensor.distspec import shard
+from colossalai.tensor.distspec import ShardSpec
 from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
 
 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
@@ -85,13 +85,13 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
     for shard_type, module in annotation_record.items():
         # add row sharding spec
         if shard_type == 'row':
-            dist_spec = shard(dims=[-1], num_partitions=[world_size])
+            dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
             comp_spec = ComputeSpec(ComputePattern.TP1D)
             setattr(module.weight, 'pg', process_group)
             setattr(module.weight, 'dist_spec', dist_spec)
             setattr(module.weight, 'comp_spec', comp_spec)
         elif shard_type == 'col':
-            weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
+            weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
             weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
             weight_comp_spec.output_replicate = False
             setattr(module.weight, 'pg', process_group)
@@ -99,7 +99,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
             setattr(module.weight, 'comp_spec', weight_comp_spec)
 
             if module.bias is not None:
-                bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
+                bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
                 bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
                 bias_comp_spec.output_replicate = False
                 setattr(module.bias, 'pg', process_group)
diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
index 3ad15a436..bf0e4bf34 100644
--- a/colossalai/tensor/__init__.py
+++ b/colossalai/tensor/__init__.py
@@ -1,7 +1,7 @@
 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
-from .distspec import shard as ShardSpec
-from .distspec import replicate as ReplicaSpec
+from .distspec import ShardSpec
+from .distspec import ReplicaSpec
 from .compute_spec import ComputeSpec, ComputePattern
 
 from .colo_tensor import ColoTensor
@@ -13,6 +13,6 @@ from . import distspec
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
-    'ShardSpec', 'ReplicaSpec'
+    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
+    'ReplicaSpec'
 ]
diff --git a/colossalai/tensor/distspec.py b/colossalai/tensor/distspec.py
index 4796d420c..0b62cbdda 100644
--- a/colossalai/tensor/distspec.py
+++ b/colossalai/tensor/distspec.py
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import List
 
-__all__ = ['replicate', 'shard']
+__all__ = ['ReplicaSpec', 'ShardSpec']
 
 
 class DistPlacementPattern(Enum):
@@ -10,15 +10,22 @@ class DistPlacementPattern(Enum):
 
 
 class _DistSpec:
+    """_DistSpec
+
+    A class that describes a distributed specification.
+    A DistSpec only applies to tensor parallel process groups,
+    because the dist spec of a data parallel process group can be automatically deduced.
+    This is an internal data structure.
+    The user-facing APIs are `ShardSpec` and `ReplicaSpec`.
+
+    Args:
+        dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
+            The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
+        process_group (Optional[ProcessGroup], optional): the process group that contains the processes. Defaults to None.
+    """
 
     def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
-        """_DistSpec, Distributed Specification
-        Args:
-            dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
-                The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
-            process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
-        """
         self.placement = dist_placement_pattern
         for k, v in meta_info.items():
             setattr(self, k, v)
 
@@ -39,11 +46,32 @@ class _DistSpec:
         return ''.join(res_list)
 
 
-def replicate() -> _DistSpec:
+def ReplicaSpec() -> _DistSpec:
+    """ReplicaSpec
+
+    A distributed specification indicating that the tensor is replicated among the tensor parallel process group.
+
+    Returns:
+        _DistSpec: a replicated dist spec instance.
+    """
    return _DistSpec(DistPlacementPattern.REPLICATE)
 
 
-def shard(dims: List[int], num_partitions: List[int]) -> _DistSpec:
+def ShardSpec(dims: List[int], num_partitions: List[int]) -> _DistSpec:
+    """ShardSpec
+
+    A distributed specification indicating that the tensor is sharded among the tensor parallel process group.
+
+    Note:
+        Currently, sharding along only one dimension is valid. In other words, dims should be of size 1.
+
+    Args:
+        dims (List[int]): a list of dimensions to shard along.
+        num_partitions (List[int]): a list of the number of partitions along each dimension.
+
+    Returns:
+        _DistSpec: a sharded dist spec instance.
+    """
     assert isinstance(dims, list) and isinstance(num_partitions, list)
     assert len(dims) == len(num_partitions)
     return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions))
diff --git a/tests/test_utils/test_norm_gradient_clipping.py b/tests/test_utils/test_norm_gradient_clipping.py
index 7690dbb38..259286663 100644
--- a/tests/test_utils/test_norm_gradient_clipping.py
+++ b/tests/test_utils/test_norm_gradient_clipping.py
@@ -19,7 +19,7 @@ def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
 
 def shard_param(p: ColoParameter) -> None:
     pg = p.get_process_group()
-    p._redistribute(distspec.shard([0], [pg.tp_world_size()]))
+    p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()]))
     p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()
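
For context, a minimal usage sketch of the renamed public constructors after this patch is applied; the 4-way partition count is an illustrative assumption and not part of the change.

# Illustrative sketch only: constructs the renamed specs directly.
# The 4-way partition count below is an assumed example value.
from colossalai.tensor import ShardSpec, ReplicaSpec
from colossalai.tensor.distspec import DistPlacementPattern

# Shard dimension 0 into 4 partitions across the tensor parallel group.
shard_spec = ShardSpec(dims=[0], num_partitions=[4])
assert shard_spec.placement == DistPlacementPattern.SHARD
assert shard_spec.dims == (0,) and shard_spec.num_partitions == (4,)

# Replicate the tensor across the tensor parallel group.
replica_spec = ReplicaSpec()
assert replica_spec.placement == DistPlacementPattern.REPLICATE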