Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-21 17:40:33 +00:00
[autoparallel] mix gather (#1977)
* Add mix-gather
* Add comments
* Add comments
* Polish comments
* Change the global rank assumption
* Add tests
* Add two-step tests
* Fix 10 and 01
* Skip test because the number of GPUs
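As we read the commit message, "mix-gather" gathers a tensor that is sharded over both axes of a 2-D device mesh in a single collective over the flattened mesh, instead of two successive per-axis gathers. The toy sketch below is not part of the commit: it uses plain Python (no torch.distributed) and an invented 2x2 mesh layout purely to illustrate that reading. For an S01 layout the one-step and two-step results coincide; an S10 layout would need the devices visited in column-major order, which appears to be what the leading-dim flags returned by mix_gather_simulator (in the diff below) keep track of.

    # Toy sketch of the mix-gather idea; plain Python, no distributed runtime.
    chunks = list(range(8))                                   # logical tensor as 4 chunks of 2
    shards = [chunks[i * 2:(i + 1) * 2] for i in range(4)]

    # S01 layout: device (i, j) of a 2x2 mesh holds shard i * 2 + j (mesh axis 0 leads)
    mesh = {(i, j): shards[i * 2 + j] for i in range(2) for j in range(2)}

    # Two-step gather: along mesh axis 1 first, then along mesh axis 0
    rows = [sum((mesh[(i, j)] for j in range(2)), []) for i in range(2)]
    two_step = sum(rows, [])

    # One-step "mix-gather": concatenate over all devices in flat rank order i * 2 + j
    one_step = sum((mesh[(i, j)] for i in range(2) for j in range(2)), [])

    assert one_step == two_step == chunks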
@@ -90,6 +90,31 @@ def shard_simulator(target_pair, legal_sharding_dims):
    return shard_list_list


def mix_gather_simulator(f_target_pair, b_target_pair):
    '''
    Assume index of f and b target pairs are 'f' and 'b'
    S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0)
    S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1)
    S01R => Input: (f, [0, 1]), (b, []) Output: [f], (1, 1)
    RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1)
    S10R => Input: (f, [0, 1]), (b, []) Output: [f], (0, 0)
    RS10 => Input: (f, []), (b, [0, 1]) Output: [b], (0, 0)
    '''
    if f_target_pair[1] and b_target_pair[1]:
        leading_dim = b_target_pair[1] > f_target_pair[1]
        return [b_target_pair[0], f_target_pair[0]], [int(leading_dim), int(leading_dim ^ 1)]
    if f_target_pair[1]:
        leading_dim = f_target_pair[1][0] < f_target_pair[1][1]
        return [
            f_target_pair[0],
        ], [int(leading_dim), int(leading_dim)]
    if b_target_pair[1]:
        leading_dim = b_target_pair[1][0] < b_target_pair[1][1]
        return [
            b_target_pair[0],
        ], [int(leading_dim), int(leading_dim)]


# The function is credited to PyTorch Team
def named_params_with_colotensor(
    module: nn.Module,
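For reference, a minimal usage sketch of mix_gather_simulator, not part of the commit. It assumes the function is importable from the patched module (whose path is not shown in this hunk); f and b are arbitrary tensor-dimension indices chosen for illustration, and the S10-style case passes the mesh axes as [1, 0], i.e. with mesh axis 1 leading.

    # Each target pair is (tensor_dim, mesh_axes); f/b are hypothetical tensor dims.
    f, b = 0, 1

    # S0S1: dim f sharded along mesh axis 0, dim b along mesh axis 1
    assert mix_gather_simulator((f, [0]), (b, [1])) == ([b, f], [1, 0])
    # S1S0: mesh axes swapped, so the leading order flips
    assert mix_gather_simulator((f, [1]), (b, [0])) == ([b, f], [0, 1])
    # S01R: dim f sharded over both mesh axes, axis 0 leading
    assert mix_gather_simulator((f, [0, 1]), (b, [])) == ([f], [1, 1])
    # RS01: dim b sharded over both mesh axes, axis 0 leading
    assert mix_gather_simulator((f, []), (b, [0, 1])) == ([b], [1, 1])
    # S10-style: both mesh axes on dim f with axis 1 leading
    assert mix_gather_simulator((f, [1, 0]), (b, [])) == ([f], [0, 0])

The first returned list holds the tensor dimensions to gather (the b pair first when both are sharded); the second pair of flags is derived from which mesh axis leads.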