[autoparallel] mix gather (#1977)

* Add mix-gather

* Add comments

* Add comments

* Polish comments

* Change the global rank assumption

* Add tests

* Add two-step tests

* Fix 10 and 01

* Skip test because of the number of GPUs
Genghan Zhang authored 2022-11-23 21:49:17 +08:00, committed by GitHub
parent 7242bffc5f
commit d655eea515
5 changed files with 617 additions and 4 deletions


@@ -90,6 +90,31 @@ def shard_simulator(target_pair, legal_sharding_dims):
    return shard_list_list


def mix_gather_simulator(f_target_pair, b_target_pair):
    '''
    Assume the indices of the f and b target pairs are 'f' and 'b':
    S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0)
    S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1)
    S01R => Input: (f, [0, 1]), (b, []) Output: [f], (1, 1)
    RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1)
    S10R => Input: (f, [0, 1]), (b, []) Output: [f], (0, 0)
    RS10 => Input: (f, []), (b, [0, 1]) Output: [b], (0, 0)
    '''
    if f_target_pair[1] and b_target_pair[1]:
        # Both tensor dims are sharded, each on one mesh axis (S0S1 / S1S0);
        # comparing the shard-dim lists decides which mesh axis leads the gather.
        leading_dim = b_target_pair[1] > f_target_pair[1]
        return [b_target_pair[0], f_target_pair[0]], [int(leading_dim), int(leading_dim ^ 1)]
    if f_target_pair[1]:
        # Only the f dim is sharded, across both mesh axes (S01R / S10R);
        # the axis order [0, 1] vs. [1, 0] fixes the gather order.
        leading_dim = f_target_pair[1][0] < f_target_pair[1][1]
        return [f_target_pair[0]], [int(leading_dim), int(leading_dim)]
    if b_target_pair[1]:
        # Only the b dim is sharded, across both mesh axes (RS01 / RS10).
        leading_dim = b_target_pair[1][0] < b_target_pair[1][1]
        return [b_target_pair[0]], [int(leading_dim), int(leading_dim)]


# The function is credited to PyTorch Team
def named_params_with_colotensor(
    module: nn.Module,
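
For orientation, here is a minimal sketch (not part of the diff) checking mix_gather_simulator above against its docstring cases; the tensor dim indices f=0 and b=1 are illustrative assumptions, not values fixed by the commit.

# Illustrative check of mix_gather_simulator against its docstring,
# assuming tensor dim indices f=0 and b=1.
assert mix_gather_simulator((0, [0]), (1, [1])) == ([1, 0], [1, 0])    # S0S1
assert mix_gather_simulator((0, [1]), (1, [0])) == ([1, 0], [0, 1])    # S1S0
assert mix_gather_simulator((0, [0, 1]), (1, [])) == ([0], [1, 1])     # S01R
assert mix_gather_simulator((0, []), (1, [0, 1])) == ([1], [1, 1])     # RS01
assert mix_gather_simulator((0, [1, 0]), (1, [])) == ([0], [0, 0])     # S10R
assert mix_gather_simulator((0, []), (1, [1, 0])) == ([1], [0, 0])     # RS10

Each call returns (dims to gather, mesh-axis gather order): when both tensor dims are sharded the backward dim is listed first, and when one dim is sharded on both mesh axes the pair (1, 1) vs. (0, 0) encodes whether mesh axis 0 or 1 leads.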