[tensor] shape consistency generate transform path and communication cost (#1435)
* [tensor] shape consistency output transform path and communication cost
* polish code
@@ -5,6 +5,90 @@ import torch.nn as nn
from colossalai.tensor.colo_tensor import ColoTensor


def all_gather_simulator(target_pair):
    '''
    Simulate the all-gather operation: analyze the communication cost
    and the influence on the DimSpec.

    We don't allow non-contiguous layouts, so all-gather(S012) -> S02 is NOT allowed.
    Therefore, the all-gather operation simply removes the last element of the shard list,
    e.g.:
        all-gather(S01) -> S0

    Argument:
        target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
        and the second element describes which logical axes will be sharded in that dimension.
    '''
    _, shard_list = target_pair
    new_shard_list = shard_list[:-1]

    return new_shard_list

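# Illustrative usage (added here for explanation, not part of the original commit).
# Following the docstring's convention, S01 means a dimension sharded over logical
# mesh axes 0 and 1, i.e. a shard list of [0, 1]; all-gather drops the trailing axis:
#
#     >>> all_gather_simulator((0, [0, 1]))
#     [0]    # all-gather(S01) -> S0

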
def all_to_all_simulator(f_target_pair, b_target_pair):
    '''
    Simulate the all-to-all operation: analyze the communication cost
    and the influence on the DimSpec.

    We ban all representations whose shard_list is in decreasing order,
    such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
    Therefore, if the behind shard_list is not empty, we extend the front shard_list with it,
    e.g.:
        all-to-all(S0, S1) -> [S01, R]
        all-to-all(R, S1) -> [S1, R]
    Otherwise, we move the front shard_list onto the behind one,
    e.g.:
        all-to-all(S0, R) -> [R, S0]

    Argument:
        f_target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
        and the second element describes which logical axes will be sharded in that dimension.
        b_target_pair(Tuple[int, List[int]]): The same as f_target_pair, for the other dimension
        taking part in the all-to-all.
    '''
    _, f_shard_list = f_target_pair
    _, b_shard_list = b_target_pair
    if not len(b_shard_list):
        b_shard_list.extend(f_shard_list)
        f_shard_list = []
    else:
        f_shard_list.extend(b_shard_list)
        b_shard_list = []

    return f_shard_list, b_shard_list

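# Illustrative usage (added here for explanation, not part of the original commit).
# The pairs below are (tensor dim, shard list); the simulator only returns the new
# shard lists, it does not move any real data:
#
#     >>> all_to_all_simulator((0, [0]), (1, [1]))
#     ([0, 1], [])    # all-to-all(S0, S1) -> [S01, R]
#     >>> all_to_all_simulator((0, [0]), (1, []))
#     ([], [0])       # all-to-all(S0, R) -> [R, S0]

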
def shard_simulator(target_pair, legal_sharding_dims):
    '''
    Simulate the shard operation: analyze the communication cost (always ZERO)
    and the influence on the DimSpec.

    We don't allow non-contiguous layouts, so shard(S0) -> S02 is NOT allowed.
    In addition, we ban all representations whose shard_list is in decreasing order,
    such as S10, so shard(S0) -> S10 is NOT allowed.
    Therefore, for an R dimension, we can append any legal sharding dim to it,
    e.g.:
        shard(R) -> S0
    For an S dimension, we need to make sure the shard_list stays in ascending order after sharding,
    e.g.:
        shard(S0) -> S01

    Argument:
        target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
        and the second element describes which logical axes will be sharded in that dimension.
        legal_sharding_dims(List[int]): The logical axes that may be appended to the shard list.
    '''
    _, shard_list = target_pair
    shard_list_list = []
    for dim in legal_sharding_dims:
        if len(shard_list) != 0 and dim <= shard_list[-1]:
            continue
        new_shard_list = shard_list + [dim]
        shard_list_list.append(new_shard_list)

    return shard_list_list

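# Illustrative usage (added here for explanation, not part of the original commit),
# assuming logical axes 0 and 1 are the legal sharding dims; every candidate result
# keeps the shard list in ascending order:
#
#     >>> shard_simulator((0, []), [0, 1])
#     [[0], [1]]      # shard(R) -> S0 or S1
#     >>> shard_simulator((0, [0]), [0, 1])
#     [[0, 1]]        # shard(S0) -> S01

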
# The function is credited to PyTorch Team
def named_params_with_colotensor(
    module: nn.Module,