mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 01:24:04 +00:00
[autoparallel] fix bugs caused by negative dim key (#1808)
* [autoparallel] fix bugs caused by negative dim key * fix import error * fix matmul test issue * fix unit test issue
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import torch
|
||||
from typing import Dict, Iterator, List, Tuple, Union
|
||||
|
||||
from typing import Iterator, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.tensor.colo_tensor import ColoTensor
|
||||
|
||||
|
||||
@@ -12,7 +13,7 @@ def all_gather_simulator(target_pair):
|
||||
|
||||
We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed.
|
||||
Therefore, all gather operation just remove the last element in shard list,
|
||||
e.g.:
|
||||
e.g.:
|
||||
all-gather(S01) -> S0
|
||||
|
||||
Argument:
|
||||
@@ -31,18 +32,18 @@ def all_to_all_simulator(f_target_pair, b_target_pair):
|
||||
and simulate the influence of the DimSpec.
|
||||
|
||||
We BANNED all representations which shard_list in decreasing order,
|
||||
such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
|
||||
such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
|
||||
Therefore, if the behind shard_list is not None, we just extend it to the front shard_list.
|
||||
Argument:
|
||||
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
|
||||
and the second element decribes which logical axis will be sharded in that dimension.
|
||||
e.g.:
|
||||
e.g.:
|
||||
all-to-all(S0, S1) -> [S01, R]
|
||||
all-to-all(S0, R) -> [R, S0]
|
||||
Otherwise, we extend the front shard_list to behind.
|
||||
e.g.:
|
||||
e.g.:
|
||||
all-to-all(R, S1) -> [S1, R]
|
||||
|
||||
|
||||
Argument:
|
||||
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
|
||||
and the second element decribes which logical axis will be sharded in that dimension.
|
||||
@@ -65,7 +66,7 @@ def shard_simulator(target_pair, legal_sharding_dims):
|
||||
and simulate the influence of the DimSpec.
|
||||
|
||||
We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed.
|
||||
In addition, We BANNED all representations which shard_list in decreasing order,
|
||||
In addition, We BANNED all representations which shard_list in decreasing order,
|
||||
such as S10, so shard(S0) -> S10 is NOT allowed.
|
||||
Therefore, for the R dimension, we could just append any legal sharding dim on it.
|
||||
e.g.:
|
||||
@@ -164,3 +165,37 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
|
||||
|
||||
# Now we can set the attribute appropriately.
|
||||
setattr(module, param_name, st)
|
||||
|
||||
|
||||
def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
|
||||
'''
|
||||
This method is used to convert the negative dim value to positive.
|
||||
'''
|
||||
dims_to_convert = []
|
||||
for dim, mesh_list in dim_partition_dict.items():
|
||||
if dim < 0:
|
||||
dims_to_convert.append(dim)
|
||||
for dim in dims_to_convert:
|
||||
dim_partition_dict.pop(dim)
|
||||
dim_partition_dict[dim_size + dim] = mesh_list
|
||||
return dim_partition_dict
|
||||
|
||||
|
||||
def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
|
||||
'''
|
||||
This method is used to merge the different key value which points to same physical position.
|
||||
|
||||
For example:
|
||||
dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position.
|
||||
In this method, above dim_partition_dict will be converted to {1: [0, 1]}
|
||||
'''
|
||||
converted_dim_partition_dict = {}
|
||||
for dim, mesh_list in dim_partition_dict.items():
|
||||
if dim < 0:
|
||||
dim = dim_size + dim
|
||||
if dim not in converted_dim_partition_dict:
|
||||
converted_dim_partition_dict[dim] = mesh_list
|
||||
else:
|
||||
converted_dim_partition_dict[dim].extend(mesh_list)
|
||||
|
||||
return converted_dim_partition_dict
|
||||
|
Reference in New Issue
Block a user