Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 18:19:58 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -13,7 +13,7 @@ from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator,
 from .comm_spec import *
 
-__all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options']
+__all__ = ["ShapeConsistencyManager", "ShapeConsistencyOptions", "set_shape_consistency_options"]
 
 
 @dataclass
@@ -21,16 +21,17 @@ class ShapeConsistencyOptions:
     """
     ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency.
     """
 
     # TODO: shape consistency option is not implemented yet
     pass
 
 
 def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor:
     shape_consistency_manager = ShapeConsistencyManager()
     global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
     with torch.no_grad():
-        global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec,
-                                                                                 global_sharding_spec)
+        global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
+            distributed_tensor, sharding_spec, global_sharding_spec
+        )
     return global_tensor
 
 
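Note for readers: `to_global` above is the helper that gathers a local shard back to the full tensor. A minimal usage sketch follows; it assumes a distributed run with an already constructed `device_mesh` (construction omitted) and a made-up (8, 8) tensor sharded along dim 0, so the concrete shapes are illustrative rather than taken from this commit.

# Sketch only: assumes torch.distributed is initialized and `device_mesh`
# already describes the cluster; shapes and values are invented for the example.
import torch

from colossalai.tensor.shape_consistency import to_global
from colossalai.tensor.sharding_spec import ShardingSpec

# dim 0 sharded over logical mesh axis 0, everything else replicated, mirroring
# the ShardingSpec(device_mesh, entire_shape, dim_partition_dict) calls above
source_spec = ShardingSpec(device_mesh, torch.Size((8, 8)), {0: [0]})
local_shard = torch.randn(4, 8)  # this rank's slice of the (8, 8) global tensor

global_tensor = to_global(local_shard, source_spec)
assert global_tensor.shape == torch.Size((8, 8))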
@@ -43,7 +44,6 @@ def set_shape_consistency_options(options: ShapeConsistencyOptions):
 
 
 class ShapeConsistencyManager(metaclass=SingletonMeta):
-
     def __init__(self):
         self._options = None
         self._forward_only = False
@@ -69,9 +69,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         assert isinstance(value, bool)
         self._forward_only = value
 
-    def get_all_all_gather_spec(self, source_spec: ShardingSpec,
-                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
-        '''
+    def get_all_all_gather_spec(
+        self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]
+    ) -> Dict[ShardingSpec, float]:
+        """
         Get all valid sharding specs from source_spec with single all-gather operation, and
         accumulate communication cost on origin cost which will finally be used in auto sharding solver.
         For the all-gather operation, we just care about the S dimension.
@@ -99,7 +100,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                     device_mesh_shape: (4, 4): 0, DistSpec:
                     shard_sequence: S0,R,R
                     device_mesh_shape: (4, 4): 0}
-        '''
+        """
         valid_spec_dict = {}
         comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
         for target_pair in source_spec.dim_partition_dict.items():
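For orientation, the enumeration that `get_all_all_gather_spec` performs can be illustrated without any ColossalAI machinery: each sharded tensor dimension yields one candidate in which its last mesh axis is gathered away. The helper below is a hypothetical, self-contained sketch of that idea (plain dicts standing in for sharding specs), not the library implementation.

from typing import Dict, List

def enumerate_all_gather_targets(dim_partition: Dict[int, List[int]]) -> List[Dict[int, List[int]]]:
    """One candidate per sharded dim: all-gather drops that dim's last mesh axis (S0 -> R, S01 -> S0)."""
    candidates = []
    for dim in list(dim_partition):
        new_partition = {d: axes[:] for d, axes in dim_partition.items()}
        remaining = new_partition[dim][:-1]
        if remaining:
            new_partition[dim] = remaining
        else:
            del new_partition[dim]
        candidates.append(new_partition)
    return candidates

# {0: [0], 1: [1]} encodes S0,S1,R for a 3-D tensor, as in the docstring example above
print(enumerate_all_gather_targets({0: [0], 1: [1]}))
# -> [{1: [1]}, {0: [0]}]  i.e. R,S1,R and S0,R,R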
@@ -121,19 +122,20 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 comm_pattern,
                 sharding_spec=source_spec,
                 gather_dim=gather_dim,
-                # shard_dim will be used during backward
+                # shard_dim will be used during backward
                 shard_dim=gather_dim,
                 logical_process_axis=logical_process_axis,
-                forward_only=self.forward_only)
+                forward_only=self.forward_only,
+            )
 
             # compute the communication cost with CommSpec
             cost_dict = comm_spec.get_comm_cost()
 
             # generate new sharding spec
             try:
-                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
-                                                 source_spec.entire_shape,
-                                                 dim_partition_dict=new_dim_partition_dict)
+                new_sharding_spec = ShardingSpec(
+                    source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict
+                )
                 for phase, cost in cost_dict.items():
                     cost_dict[phase] = cost + orig_cost_dict[phase]
                 valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
@@ -141,9 +143,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 pass
         return valid_spec_dict
 
-    def get_all_all_to_all_spec(self, source_spec: ShardingSpec,
-                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
-        '''
+    def get_all_all_to_all_spec(
+        self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]
+    ) -> Dict[ShardingSpec, float]:
+        """
         Get all valid sharding specs from source_spec with single all-to-all operation, and
         accumulate communication cost on origin cost which will finally be used in auto sharding solver.
         For the all-to-all operation, we just care about the pairs containing S dimension.
@@ -173,7 +176,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                     device_mesh_shape: (4, 4): 0, DistSpec:
                     shard_sequence: S0,R,S1
                     device_mesh_shape: (4, 4): 0}
-        '''
+        """
         valid_spec_dict = {}
         comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
         tensor_dims = len(source_spec.entire_shape)
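Similarly, the pair enumeration behind `get_all_all_to_all_spec` can be sketched in isolation: an all-to-all relaxes the sharding of one tensor dimension and re-shards another, so candidates come from (gather_dim, shard_dim) pairs. The toy function below is a deliberate simplification (it ignores the merged-axis cases the real `all_to_all_simulator` handles), meant only to make the docstring example concrete.

from typing import Dict, List

def enumerate_all_to_all_targets(dim_partition: Dict[int, List[int]], tensor_dims: int) -> List[Dict[int, List[int]]]:
    """Move the last mesh axis of one sharded dim onto another dim (simplified)."""
    candidates = []
    for gather_dim in list(dim_partition):
        for shard_dim in range(tensor_dims):
            if shard_dim == gather_dim:
                continue
            new_partition = {d: axes[:] for d, axes in dim_partition.items()}
            moved_axis = new_partition[gather_dim].pop()                 # relax gather_dim
            if not new_partition[gather_dim]:
                del new_partition[gather_dim]
            new_partition.setdefault(shard_dim, []).append(moved_axis)   # re-shard shard_dim
            candidates.append(new_partition)
    return candidates

# S0,S1,R ({0: [0], 1: [1]}) on a 3-D tensor
for spec in enumerate_all_to_all_targets({0: [0], 1: [1]}, tensor_dims=3):
    print(spec)
# -> {1: [1, 0]}, {1: [1], 2: [0]}, {0: [0, 1]}, {0: [0], 2: [1]}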
@@ -214,12 +217,14 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                     gather_dim = b_index
                     shard_dim = f_index
                     logical_process_axis = b_target_pair[1][-1]
-                comm_spec = CommSpec(comm_pattern,
-                                     sharding_spec=source_spec,
-                                     gather_dim=gather_dim,
-                                     shard_dim=shard_dim,
-                                     logical_process_axis=logical_process_axis,
-                                     forward_only=self.forward_only)
+                comm_spec = CommSpec(
+                    comm_pattern,
+                    sharding_spec=source_spec,
+                    gather_dim=gather_dim,
+                    shard_dim=shard_dim,
+                    logical_process_axis=logical_process_axis,
+                    forward_only=self.forward_only,
+                )
 
                 # compute the communication cost with CommSpec
                 cost_dict = comm_spec.get_comm_cost()
@@ -238,9 +243,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
                 # generate new sharding spec
                 try:
-                    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
-                                                     source_spec.entire_shape,
-                                                     dim_partition_dict=new_dim_partition_dict)
+                    new_sharding_spec = ShardingSpec(
+                        source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict
+                    )
                     for phase, cost in cost_dict.items():
                         cost_dict[phase] = cost + orig_cost_dict[phase]
                     valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
@@ -250,7 +255,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         return valid_spec_dict
 
     def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict):
-        '''
+        """
         Get all valid sharding specs from source_spec with single shard operation, and
         accumulate communication cost on origin cost which will finally be used in auto sharding solver.
         For the sharding operation, we just care about legal sharding dimensions.
@@ -280,7 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                     device_mesh_shape: (4, 4): 0, DistSpec:
                     shard_sequence: S0,R,S1
                     device_mesh_shape: (4, 4): 0}
-        '''
+        """
         valid_spec_dict = {}
         comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
 
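The shard enumeration is the mirror image of the all-gather case: a split adds one currently unused mesh axis to one tensor dimension. Again a hypothetical, dict-based sketch rather than the library code:

from typing import Dict, List

def enumerate_shard_targets(dim_partition: Dict[int, List[int]], tensor_dims: int, mesh_axes: int) -> List[Dict[int, List[int]]]:
    """Pair every tensor dim with every mesh axis not already used by the current layout."""
    used_axes = {a for axes in dim_partition.values() for a in axes}
    candidates = []
    for dim in range(tensor_dims):
        for axis in range(mesh_axes):
            if axis in used_axes:
                continue
            new_partition = {d: axes[:] for d, axes in dim_partition.items()}
            new_partition.setdefault(dim, []).append(axis)
            candidates.append(new_partition)
    return candidates

# S0,R,R ({0: [0]}) on a 3-D tensor over a 2-axis mesh
for spec in enumerate_shard_targets({0: [0]}, tensor_dims=3, mesh_axes=2):
    print(spec)
# -> {0: [0, 1]} (S01,R,R), {0: [0], 1: [1]} (S0,S1,R), {0: [0], 2: [1]} (S0,R,S1)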
@@ -308,21 +313,23 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
                 shard_dim = index
                 logical_process_axis = shard_list[-1]
-                comm_spec = CommSpec(comm_pattern,
-                                     sharding_spec=source_spec,
-                                     gather_dim=shard_dim,
-                                     shard_dim=shard_dim,
-                                     logical_process_axis=logical_process_axis,
-                                     forward_only=self.forward_only)
+                comm_spec = CommSpec(
+                    comm_pattern,
+                    sharding_spec=source_spec,
+                    gather_dim=shard_dim,
+                    shard_dim=shard_dim,
+                    logical_process_axis=logical_process_axis,
+                    forward_only=self.forward_only,
+                )
 
                 # compute the communication cost with CommSpec
                 cost_dict = comm_spec.get_comm_cost()
 
                 # generate new sharding spec
                 try:
-                    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
-                                                     source_spec.entire_shape,
-                                                     dim_partition_dict=new_dim_partition_dict)
+                    new_sharding_spec = ShardingSpec(
+                        source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict
+                    )
                     for phase, cost in cost_dict.items():
                         cost_dict[phase] = cost + orig_cost_dict[phase]
                     valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
@@ -330,14 +337,15 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                     pass
         return valid_spec_dict
 
-    def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
-                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
-        '''
+    def get_all_mix_gather_spec(
+        self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]
+    ) -> Dict[ShardingSpec, float]:
+        """
         S0S1 -> RR
         S1S0 -> RR
         S01R -> RR
         RS01 -> RR
-        '''
+        """
         valid_spec_dict = {}
         comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
         tensor_dims = len(source_spec.entire_shape)
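The mix-gather docstring above is terse; the rule it encodes is that a mix-gather consumes every mesh axis in one collective and always lands on the fully replicated layout. A small sketch of that rule (the two-axis mesh size is an assumption for the example):

from typing import Dict, List, Optional

def mix_gather_target(dim_partition: Dict[int, List[int]], mesh_axes: int = 2) -> Optional[Dict[int, List[int]]]:
    """Layouts such as S0S1, S1S0, S01R, RS01 use every mesh axis and collapse to RR; others don't qualify."""
    used = sorted(a for axes in dim_partition.values() for a in axes)
    return {} if used == list(range(mesh_axes)) else None

print(mix_gather_target({0: [0], 1: [1]}))  # S0S1 -> {}   (RR)
print(mix_gather_target({0: [0, 1]}))       # S01R -> {}   (RR)
print(mix_gather_target({0: [0]}))          # S0R  -> None (not a mix-gather case)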
@@ -362,19 +370,21 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 b_target_pair = (b_index, [])
 
             gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
-            comm_spec = CommSpec(comm_pattern,
-                                 sharding_spec=source_spec,
-                                 gather_dim=gather_dim,
-                                 logical_process_axis=logical_process_axes,
-                                 forward_only=self.forward_only,
-                                 mix_gather=True)
+            comm_spec = CommSpec(
+                comm_pattern,
+                sharding_spec=source_spec,
+                gather_dim=gather_dim,
+                logical_process_axis=logical_process_axes,
+                forward_only=self.forward_only,
+                mix_gather=True,
+            )
             cost_dict = comm_spec.get_comm_cost()
             new_dim_partition_dict = {}
             # generate new sharding spec
             try:
-                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
-                                                 source_spec.entire_shape,
-                                                 dim_partition_dict=new_dim_partition_dict)
+                new_sharding_spec = ShardingSpec(
+                    source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict
+                )
                 for phase, cost in cost_dict.items():
                     cost_dict[phase] = cost + orig_cost_dict[phase]
                 valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
@@ -384,7 +394,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         return valid_spec_dict
 
     def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
-        '''
+        """
         Get all valid sharding specs from source_spec with one step transform, and
         accumulate communication cost on origin cost which will finally be used in auto sharding solver.
         Note:
@@ -398,7 +408,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
         Return:
             valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
-        '''
+        """
         valid_spec_dict = {}
         valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost_dict))
         valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost_dict))
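Usage-wise, the one-step neighbourhood is simply the union of the all-gather, all-to-all and shard candidate sets, keyed by the resulting ShardingSpec. A hedged sketch of how it is typically queried (assuming `source_spec` is an existing ShardingSpec; the zero-cost seed dict mirrors the keys used by `shape_consistency` further below):

# Sketch only: `source_spec` is assumed to exist already.
manager = ShapeConsistencyManager()
zero_cost = {"forward": 0, "backward": 0, "total": 0}

neighbours = manager.get_all_one_step_transform_spec(source_spec, zero_cost)
for spec, (comm_spec, cost_dict) in neighbours.items():
    # each entry records the collective that produces the spec and its accumulated cost
    print(spec, comm_spec, cost_dict["total"])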
@@ -545,18 +555,22 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
             # the first forward comm action will not discard input
             fwd_action, comm_spec = action_spec_pair
-            fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel,
-                                                         fwd_peak_numel) if idx == 0 else fwd_action(
-                                                             comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
+            fwd_alloc_numel, fwd_peak_numel = (
+                fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
+                if idx == 0
+                else fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
+            )
 
         # analyze memory footprint for backward comm actions sequence
         bwd_alloc_numel = 0
         bwd_peak_numel = 0
         for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
             bwd_action, comm_spec = action_spec_pair
-            bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel,
-                                                         bwd_peak_numel) if idx == 0 else bwd_action(
-                                                             comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
+            bwd_alloc_numel, bwd_peak_numel = (
+                bwd_action(comm_spec, False, bwd_alloc_numel, bwd_peak_numel)
+                if idx == 0
+                else bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
+            )
 
         fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
         bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
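The reformatted conditional above is the heart of the memory analysis: each comm action folds its buffers into a running (allocated, peak) pair, and only the first action in a sequence keeps its input alive. A toy stand-in follows; the `fake_action` bookkeeping and the sizes are invented purely to show the accumulation pattern.

def fake_action(comm_numel: int, discard_input: bool, alloc_numel: int, peak_numel: int):
    """Invented bookkeeping: grow the allocation, track the peak, optionally free the input."""
    alloc_numel += comm_numel
    peak_numel = max(peak_numel, alloc_numel)
    if discard_input:
        alloc_numel -= comm_numel // 2  # pretend the input buffer is released afterwards
    return alloc_numel, peak_numel

fwd_alloc_numel, fwd_peak_numel = 0, 0
for idx, numel in enumerate([1024, 2048, 512]):  # made-up sizes for three comm actions
    fwd_alloc_numel, fwd_peak_numel = (
        fake_action(numel, False, fwd_alloc_numel, fwd_peak_numel)
        if idx == 0
        else fake_action(numel, True, fwd_alloc_numel, fwd_peak_numel)
    )
print(fwd_alloc_numel, fwd_peak_numel)  # activation vs. peak numel, as consumed by MemoryCost above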
@@ -564,9 +578,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
         return TrainCycleItem(fwd_mem, bwd_mem, total_mem)
 
-    def shape_consistency(self, source_spec: ShardingSpec,
-                          target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
-        '''
+    def shape_consistency(
+        self, source_spec: ShardingSpec, target_spec: ShardingSpec
+    ) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
+        """
         This method will find a path to transform source_spec to target_spec with
         a greedy algorithm.
         The basic idea is:
@@ -623,9 +638,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0),
                 CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)]
             total_cost: 12294.402000000002
-        '''
+        """
         MAX_TRANSFORM_STEPS = 20
-        total_cost_dict = {'forward': 0, 'backward': 0, 'total': 0}
+        total_cost_dict = {"forward": 0, "backward": 0, "total": 0}
         total_steps = 0
         transform_path = []
         comm_action_sequence = []
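For context on the greedy search: at every step the manager expands the one-step neighbourhood of the current spec, picks the candidate that brings it closest to `target_spec`, and gives up after `MAX_TRANSFORM_STEPS`. A usage sketch, assuming `src_spec` and `dst_spec` are existing ShardingSpec objects on the same device mesh:

# Sketch only: src_spec and dst_spec are assumed to exist already.
manager = ShapeConsistencyManager()
transform_path, comm_action_sequence, total_cost = manager.shape_consistency(src_spec, dst_spec)

print(f"{len(transform_path)} specs on the path, {len(comm_action_sequence)} comm actions")
print("total cost:", total_cost)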
@@ -672,7 +687,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")
 
     def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor:
-        '''
+        """
         Apply target_spec to tensor with source sharding spec, the transform path is generated by the
         shape_consistency method.
 
@@ -729,7 +744,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                     [1.],
                     [3.],
                     [3.]])
-        '''
+        """
         _, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec)
         for comm_spec in comm_action_sequence:
             tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
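Finally, `apply` replays the comm actions found by `shape_consistency` on a tensor that carries a `sharding_spec` attribute, as in the docstring example above. A hedged usage sketch; the tensor, specs and mesh are assumed to exist, and attaching `sharding_spec` to a plain tensor follows the convention shown in that docstring.

# Sketch only: local_tensor, source_spec and target_spec are assumed to exist.
manager = ShapeConsistencyManager()

local_tensor.sharding_spec = source_spec       # tag the tensor with its current layout
resharded = manager.apply(local_tensor, target_spec)
print(resharded.shape)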