mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[autoparallel] handled illegal sharding strategy in shape consistency (#1744)
* [autoparallel] handled illegal sharding strategy in shape consistency * polish code
This commit is contained in:
@@ -1,16 +1,19 @@
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
|
||||
from enum import Enum
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
import torch.distributed as dist
|
||||
import math
|
||||
from functools import reduce
|
||||
import operator
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec
|
||||
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
|
||||
|
||||
from .comm_spec import *
|
||||
|
||||
__all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options']
|
||||
@@ -62,10 +65,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
|
||||
def get_all_all_gather_spec(self, source_spec, orig_cost_dict):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single all-gather operation, and
|
||||
Get all valid sharding specs from source_spec with single all-gather operation, and
|
||||
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
|
||||
For the all-gather operation, we just care about the S dimension.
|
||||
|
||||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(float): the original communication cost before this operation.
|
||||
@@ -82,12 +85,12 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
|
||||
print(rst_dict)
|
||||
|
||||
|
||||
Output:
|
||||
{DistSpec:
|
||||
shard_sequence: R,S1,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,R,R
|
||||
{DistSpec:
|
||||
shard_sequence: R,S1,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,R,R
|
||||
device_mesh_shape: (4, 4): 0}
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
@@ -120,20 +123,23 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
cost_dict = comm_spec.get_comm_cost()
|
||||
|
||||
# generate new sharding spec
|
||||
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
|
||||
source_spec.entire_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
for phase, cost in cost_dict.items():
|
||||
cost_dict[phase] = cost + orig_cost_dict[phase]
|
||||
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
|
||||
try:
|
||||
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
|
||||
source_spec.entire_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
for phase, cost in cost_dict.items():
|
||||
cost_dict[phase] = cost + orig_cost_dict[phase]
|
||||
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
|
||||
except ShardingSpecException:
|
||||
pass
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single all-to-all operation, and
|
||||
Get all valid sharding specs from source_spec with single all-to-all operation, and
|
||||
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
|
||||
For the all-to-all operation, we just care about the pairs containing S dimension.
|
||||
|
||||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(float): the original communication cost before this operation.
|
||||
@@ -150,14 +156,14 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0)
|
||||
print(rst_dict)
|
||||
|
||||
|
||||
Output:
|
||||
{DistSpec:
|
||||
shard_sequence: S01,R,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: R,S1,S0
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,R,S1
|
||||
{DistSpec:
|
||||
shard_sequence: S01,R,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: R,S1,S0
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,R,S1
|
||||
device_mesh_shape: (4, 4): 0}
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
@@ -223,20 +229,24 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
new_dim_partition_dict.pop(b_index)
|
||||
|
||||
# generate new sharding spec
|
||||
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
|
||||
source_spec.entire_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
for phase, cost in cost_dict.items():
|
||||
cost_dict[phase] = cost + orig_cost_dict[phase]
|
||||
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
|
||||
try:
|
||||
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
|
||||
source_spec.entire_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
for phase, cost in cost_dict.items():
|
||||
cost_dict[phase] = cost + orig_cost_dict[phase]
|
||||
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
|
||||
except ShardingSpecException:
|
||||
pass
|
||||
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_shard_spec(self, source_spec, orig_cost_dict):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single shard operation, and
|
||||
Get all valid sharding specs from source_spec with single shard operation, and
|
||||
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
|
||||
For the sharding operation, we just care about legal sharding dimensions.
|
||||
|
||||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(float): the original communication cost before this operation.
|
||||
@@ -253,14 +263,14 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0)
|
||||
print(rst_dict)
|
||||
|
||||
|
||||
Output:
|
||||
{DistSpec:
|
||||
shard_sequence: S01,R,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,S1,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,R,S1
|
||||
{DistSpec:
|
||||
shard_sequence: S01,R,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,S1,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,R,S1
|
||||
device_mesh_shape: (4, 4): 0}
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
@@ -275,6 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
return valid_spec_dict
|
||||
|
||||
tensor_dims = len(source_spec.entire_shape)
|
||||
|
||||
for index in range(tensor_dims):
|
||||
if index not in source_spec.dim_partition_dict:
|
||||
shard_list_list = shard_simulator((index, []), legal_sharding_dims)
|
||||
@@ -300,23 +311,26 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
cost_dict = comm_spec.get_comm_cost()
|
||||
|
||||
# generate new sharding spec
|
||||
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
|
||||
source_spec.entire_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
for phase, cost in cost_dict.items():
|
||||
cost_dict[phase] = cost + orig_cost_dict[phase]
|
||||
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
|
||||
try:
|
||||
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
|
||||
source_spec.entire_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
for phase, cost in cost_dict.items():
|
||||
cost_dict[phase] = cost + orig_cost_dict[phase]
|
||||
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
|
||||
except ShardingSpecException:
|
||||
pass
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with one step transform, and
|
||||
Get all valid sharding specs from source_spec with one step transform, and
|
||||
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
|
||||
Note:
|
||||
all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before,
|
||||
and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive,
|
||||
we could safely put them together.
|
||||
|
||||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(float): the original communication cost before this operation.
|
||||
@@ -343,7 +357,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
Repeat above steps until the source spec transform to target spec.
|
||||
|
||||
During finding the transform path, commucation cost will be accumulated, and it
|
||||
will be finally used in auto parallel solver.
|
||||
will be finally used in auto parallel solver.
|
||||
|
||||
Additionally, to avoid repeating the path search in runtime, we cached all solved path
|
||||
in auto parallel strategy building time, which could handle most of cases in runtime.
|
||||
@@ -361,30 +375,30 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
Example:
|
||||
dim_partition_source = {1: [0, 1]}
|
||||
dim_partition_target = {0: [0, 1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: R,S01,R
|
||||
# DistSpec:
|
||||
# shard_sequence: R,S01,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
|
||||
# DistSpec:
|
||||
# shard_sequence: S01,R,R
|
||||
# DistSpec:
|
||||
# shard_sequence: S01,R,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
|
||||
transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(sharding_spec_source, sharding_spec_target)
|
||||
print(f'transform_path: {transform_path}')
|
||||
print(f'comm_action_sequence: {comm_action_sequence}')
|
||||
print(f'total_cost: {total_cost}')
|
||||
|
||||
|
||||
output:
|
||||
transform_path: [DistSpec:
|
||||
shard_sequence: R,S01,R
|
||||
device_mesh_shape: (4, 4), DistSpec:
|
||||
shard_sequence: R,S0,R
|
||||
device_mesh_shape: (4, 4), DistSpec:
|
||||
shard_sequence: S0,R,R
|
||||
device_mesh_shape: (4, 4), DistSpec:
|
||||
shard_sequence: S01,R,R
|
||||
transform_path: [DistSpec:
|
||||
shard_sequence: R,S01,R
|
||||
device_mesh_shape: (4, 4), DistSpec:
|
||||
shard_sequence: R,S0,R
|
||||
device_mesh_shape: (4, 4), DistSpec:
|
||||
shard_sequence: S0,R,R
|
||||
device_mesh_shape: (4, 4), DistSpec:
|
||||
shard_sequence: S01,R,R
|
||||
device_mesh_shape: (4, 4)]
|
||||
comm_action_sequence: [CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1),
|
||||
comm_action_sequence: [CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1),
|
||||
CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0),
|
||||
CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)]
|
||||
total_cost: 12294.402000000002
|
||||
@@ -403,6 +417,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
return (transform_path, comm_action_sequence, total_cost_dict)
|
||||
|
||||
temp_sharding_spec = source_spec
|
||||
|
||||
transform_path.append(temp_sharding_spec)
|
||||
# To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
|
||||
while total_steps <= MAX_TRANSFORM_STEPS:
|
||||
@@ -437,13 +452,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
|
||||
def apply(self, tensor_with_sharding_spec, target_spec):
|
||||
'''
|
||||
Apply target_spec to tensor with source sharding spec, the transform path is generated by the
|
||||
Apply target_spec to tensor with source sharding spec, the transform path is generated by the
|
||||
shape_consistency method.
|
||||
|
||||
|
||||
Argument:
|
||||
tensor_with_sharding_spec (torch.Tensor): a tensor with source sharding spec to be transformed to the target spec.
|
||||
target_spec (ShardingSpec): The tensor transform processes will be directed by the target_spec.
|
||||
|
||||
|
||||
Example:
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
@@ -459,7 +474,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
# shard_sequence: S0,R
|
||||
# device_mesh_shape: (2, 2)
|
||||
sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
|
||||
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: R,S0
|
||||
# device_mesh_shape: (2, 2)
|
||||
@@ -481,13 +496,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
tensor_to_comm.sharding_spec = sharding_spec_source
|
||||
shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
|
||||
print(tensor_to_comm)
|
||||
|
||||
|
||||
Output in rank0 and rank2:
|
||||
tensor([[0.],
|
||||
[0.],
|
||||
[2.],
|
||||
[2.]])
|
||||
|
||||
|
||||
Output in rank1 and rank3:
|
||||
tensor([[1.],
|
||||
[1.],
|
||||
@@ -505,4 +520,4 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
for comm_spec in comm_action_sequence:
|
||||
comm_spec.covert_spec_to_action(tensor)
|
||||
tensor.sharding_spec = target_spec
|
||||
return tensor
|
||||
return tensor
|
||||
|
Reference in New Issue
Block a user