Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2026-01-29 21:49:54 +00:00)
[autoparallel] fix bugs caused by negative dim key (#1808)
* [autoparallel] fix bugs caused by negative dim key
* fix import error
* fix matmul test issue
* fix unit test issue
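To make the bug concrete: a dim_partition_dict maps a tensor dimension to the device-mesh axes that shard it, and a negative key such as -1 names the same physical dimension as its positive counterpart. Before this fix the two spellings were treated as distinct entries. A minimal sketch of the failure mode with a plain dict (the tensor rank and mesh axes below are made up for illustration):

# Hypothetical 2-D tensor: key 1 and key -1 refer to the same physical dimension.
dim_size = 2
dim_partition_dict = {1: [0], -1: [1]}

# A lookup by the positive index silently misses the mesh axis registered under -1.
print(dim_partition_dict.get(1))     # [0]
print(dim_partition_dict.get(-1))    # [1]  -- same dimension, separate entry

# The fix normalizes keys to their positive form and merges duplicates,
# which is what the merge_same_dim_mesh_list helper added in this commit does.
merged = {}
for dim, mesh_list in dim_partition_dict.items():
    if dim < 0:
        dim = dim_size + dim
    merged.setdefault(dim, []).extend(mesh_list)
print(merged)                        # {1: [0, 1]}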
@@ -1,19 +1,17 @@
+from . import distspec
+from .colo_parameter import ColoParameter
+from .colo_tensor import ColoTensor
+from .comm_spec import CollectiveCommPattern, CommSpec
+from .compute_spec import ComputePattern, ComputeSpec
+from .dist_spec_mgr import DistSpecManager
+from .distspec import ReplicaSpec, ShardSpec
+from .param_op_hook import ParamOpHook, ParamOpHookManager
 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
-from .distspec import ShardSpec
-from .distspec import ReplicaSpec
-
-from .compute_spec import ComputeSpec, ComputePattern
-from .colo_tensor import ColoTensor
-from .colo_parameter import ColoParameter
-from .utils import convert_parameter, named_params_with_colotensor
-from .dist_spec_mgr import DistSpecManager
-from .param_op_hook import ParamOpHook, ParamOpHookManager
-from .comm_spec import CollectiveCommPattern, CommSpec
-from . import distspec
+from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor

 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
     'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
-    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern'
+    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
 ]

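With the re-export above, the two new helpers are available from the package root as well as from colossalai.tensor.utils. A minimal sketch, assuming a ColossalAI build that contains this commit:

from colossalai.tensor import convert_dim_partition_dict, merge_same_dim_mesh_list

# For a 2-D tensor, the negative key -1 is rewritten to 1.
print(convert_dim_partition_dict(dim_size=2, dim_partition_dict={-1: [0]}))    # {1: [0]}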
@@ -1,11 +1,11 @@
-import torch
-
 from typing import Optional
+
+import torch

 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
-from colossalai.tensor import ColoTensorSpec
 from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.tensor_spec import ColoTensorSpec


 def filter_args(func, *args):

@@ -4,9 +4,10 @@ from typing import Callable, Optional, Set

 import torch

-from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
-from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
+from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
+from colossalai.tensor.process_group import ProcessGroup
+from colossalai.tensor.tensor_spec import ColoTensorSpec

 from .const import TensorType
 from .op_wrapper import _COLOSSAL_OPS

@@ -1,12 +1,14 @@
-from colossalai.tensor.distspec import _DistSpec
-# from colossalai.nn.layer.utils import divide
-from numpy import prod
 from contextlib import contextmanager
+
 import torch
 import torch.distributed as dist
+# from colossalai.nn.layer.utils import divide
+from numpy import prod
 from packaging import version
+
 from colossalai.logging import get_dist_logger
-from colossalai.tensor import ProcessGroup
+from colossalai.tensor.distspec import _DistSpec
+from colossalai.tensor.process_group import ProcessGroup


 # TODO(jiaruifang) circle import, move the divide to colossalai.commons.

@@ -1,9 +1,11 @@
-import torch
-from contextlib import contextmanager
 from abc import ABC, abstractmethod
-from typing import List, Tuple, Any
+from contextlib import contextmanager
+from typing import Any, List, Tuple
+
+import torch
+
 from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor import ColoTensorSpec
+from colossalai.tensor.tensor_spec import ColoTensorSpec


 class ParamOpHook(ABC):

@@ -6,6 +6,8 @@ import torch

 from colossalai.device.device_mesh import DeviceMesh

+from .utils import merge_same_dim_mesh_list
+
 __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']

 ALLGATHER_COST = 20

@@ -181,8 +183,12 @@ class ShardingSpec:
         self.dim_partition_dict = dim_partition_dict
         self.sharding_sequence = sharding_sequence
         if self.sharding_sequence is None:
             assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
+            self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape),
+                                                               dim_partition_dict=self.dim_partition_dict)
             self.convert_dict_to_shard_sequence()
         elif self.dim_partition_dict is None:
             assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
             self.convert_shard_sequence_to_dict()
         self._sanity_check()

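A small sketch of what the added normalization does before ShardingSpec builds its sharding sequence. The tensor shape and mesh axes are invented for the example; the call mirrors the one added to __init__ above and uses the helper defined in the utils hunk further down:

from colossalai.tensor.utils import merge_same_dim_mesh_list

entire_shape = (64, 32)                   # hypothetical 2-D tensor shape
dim_partition_dict = {-1: [1], 1: [0]}    # -1 and 1 name the same dimension

normalized = merge_same_dim_mesh_list(dim_size=len(entire_shape),
                                      dim_partition_dict=dim_partition_dict)
print(normalized)                         # {1: [1, 0]}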
@@ -1,14 +1,16 @@
-from typing import Optional
-from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-from .compute_spec import ComputeSpec
-from colossalai.tensor import ProcessGroup
 from dataclasses import dataclass
+from typing import Optional
+
+from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
+from colossalai.tensor.process_group import ProcessGroup
+
+from .compute_spec import ComputeSpec


 @dataclass
 class ColoTensorSpec:
     """ ColoTensorSpec
-
+
     A data class for specifications of the `ColoTensor`.
     It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
     The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`.

@@ -1,7 +1,8 @@
-import torch
+from typing import Dict, Iterator, List, Tuple, Union

-from typing import Iterator, Tuple, Union
+import torch
 import torch.nn as nn
+
 from colossalai.tensor.colo_tensor import ColoTensor


@@ -12,7 +13,7 @@ def all_gather_simulator(target_pair):

     We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed.
     Therefore, all gather operation just remove the last element in shard list,
-    e.g.:
+    e.g.:
     all-gather(S01) -> S0

     Argument:

@@ -31,18 +32,18 @@ def all_to_all_simulator(f_target_pair, b_target_pair):
     and simulate the influence of the DimSpec.

     We BANNED all representations which shard_list in decreasing order,
-    such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
+    such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
     Therefore, if the behind shard_list is not None, we just extend it to the front shard_list.
     Argument:
         target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
         and the second element decribes which logical axis will be sharded in that dimension.
-    e.g.:
+    e.g.:
         all-to-all(S0, S1) -> [S01, R]
         all-to-all(S0, R) -> [R, S0]
     Otherwise, we extend the front shard_list to behind.
-    e.g.:
+    e.g.:
         all-to-all(R, S1) -> [S1, R]
-
+
     Argument:
         target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
         and the second element decribes which logical axis will be sharded in that dimension.

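The simulator docstrings above describe shardings as pairs (dim, shard_list): S01 on a dimension means that dimension is sharded over logical mesh axes 0 and 1, and R means replicated (an empty shard list). A minimal paraphrase of the documented all-to-all rule, for illustration only (this is not the library's own implementation, whose exact signature may differ):

from typing import List, Tuple

def all_to_all_sketch(front: Tuple[int, List[int]], behind: Tuple[int, List[int]]):
    # If the behind shard_list is non-empty, extend it onto the front shard_list;
    # otherwise the front shard_list moves to the behind position.
    _, f_shard = front
    _, b_shard = behind
    if b_shard:
        return [f_shard + b_shard, []]
    return [[], f_shard]

print(all_to_all_sketch((0, [0]), (1, [1])))    # [[0, 1], []]  i.e. all-to-all(S0, S1) -> [S01, R]
print(all_to_all_sketch((0, [0]), (1, [])))     # [[], [0]]     i.e. all-to-all(S0, R)  -> [R, S0]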
@@ -65,7 +66,7 @@ def shard_simulator(target_pair, legal_sharding_dims):
     and simulate the influence of the DimSpec.

     We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed.
-    In addition, We BANNED all representations which shard_list in decreasing order,
+    In addition, We BANNED all representations which shard_list in decreasing order,
     such as S10, so shard(S0) -> S10 is NOT allowed.
     Therefore, for the R dimension, we could just append any legal sharding dim on it.
     e.g.:

@@ -164,3 +165,37 @@ def convert_parameter(module: torch.nn.Module, param_name: str):

     # Now we can set the attribute appropriately.
     setattr(module, param_name, st)
+
+
+def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
+    '''
+    This method is used to convert the negative dim value to positive.
+    '''
+    dims_to_convert = []
+    for dim, mesh_list in dim_partition_dict.items():
+        if dim < 0:
+            dims_to_convert.append(dim)
+    for dim in dims_to_convert:
+        dim_partition_dict.pop(dim)
+        dim_partition_dict[dim_size + dim] = mesh_list
+    return dim_partition_dict
+
+
+def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
+    '''
+    This method is used to merge the different key value which points to same physical position.
+
+    For example:
+        dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position.
+        In this method, above dim_partition_dict will be converted to {1: [0, 1]}
+    '''
+    converted_dim_partition_dict = {}
+    for dim, mesh_list in dim_partition_dict.items():
+        if dim < 0:
+            dim = dim_size + dim
+        if dim not in converted_dim_partition_dict:
+            converted_dim_partition_dict[dim] = mesh_list
+        else:
+            converted_dim_partition_dict[dim].extend(mesh_list)
+
+    return converted_dim_partition_dict
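A short usage sketch of the two helpers added above; the values are invented, and the results follow directly from the code in the hunk (the functions are also re-exported from colossalai.tensor, per the __init__.py hunk at the top):

from colossalai.tensor.utils import convert_dim_partition_dict, merge_same_dim_mesh_list

# Rewrite a single negative key: for a 4-D tensor, -1 becomes 3.
print(convert_dim_partition_dict(dim_size=4, dim_partition_dict={-1: [0, 1]}))
# {3: [0, 1]}

# Merge keys that point at the same physical dimension of a 2-D tensor.
print(merge_same_dim_mesh_list(dim_size=2, dim_partition_dict={1: [0], -1: [1]}))
# {1: [0, 1]}

convert_dim_partition_dict rewrites keys on the dict it is given and returns that same dict, while merge_same_dim_mesh_list returns a new dict but reuses (and may extend) the original value lists, so pass copies if the input mapping should stay untouched. Note also that the second loop in convert_dim_partition_dict reuses the mesh_list variable bound by the first loop, which is why the sketch above sticks to a single negative key per call.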