[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -13,7 +13,7 @@ from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator,
from .comm_spec import *
__all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options']
__all__ = ["ShapeConsistencyManager", "ShapeConsistencyOptions", "set_shape_consistency_options"]
@dataclass
@@ -21,16 +21,17 @@ class ShapeConsistencyOptions:
"""
ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency.
"""
# TODO: shape consistency option is not implemented yet
pass
def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor:
shape_consistency_manager = ShapeConsistencyManager()
global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
with torch.no_grad():
global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec,
global_sharding_spec)
global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
distributed_tensor, sharding_spec, global_sharding_spec
)
return global_tensor
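
For context, `to_global` is a thin wrapper around the singleton manager: it builds a fully replicated target spec (the empty dim-partition dict passed to `global_sharding_spec`) and converts the local shard to that layout. Below is a minimal usage sketch; it assumes the shard and its `ShardingSpec` were produced elsewhere (e.g. by the auto-parallel runtime) and that `to_global` from this module is importable in scope.

import torch

def gather_to_full(local_shard: torch.Tensor, spec) -> torch.Tensor:
    # Hedged sketch: `spec` is the ShardingSpec describing how `local_shard`
    # is partitioned; to_global (shown in the hunk above) gathers every
    # sharded dimension, so the result follows spec.entire_shape on every rank.
    return to_global(local_shard, spec)
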
@@ -43,7 +44,6 @@ def set_shape_consistency_options(options: ShapeConsistencyOptions):
class ShapeConsistencyManager(metaclass=SingletonMeta):
def __init__(self):
self._options = None
self._forward_only = False
@@ -69,9 +69,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
assert isinstance(value, bool)
self._forward_only = value
def get_all_all_gather_spec(self, source_spec: ShardingSpec,
orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
'''
def get_all_all_gather_spec(
self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]
) -> Dict[ShardingSpec, float]:
"""
Get all valid sharding specs from source_spec with single all-gather operation, and
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
For the all-gather operation, we just care about the S dimension.
@@ -99,7 +100,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
device_mesh_shape: (4, 4): 0, DistSpec:
shard_sequence: S0,R,R
device_mesh_shape: (4, 4): 0}
'''
"""
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
for target_pair in source_spec.dim_partition_dict.items():
@@ -121,19 +122,20 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
comm_pattern,
sharding_spec=source_spec,
gather_dim=gather_dim,
# shard_dim will be used during backward
# shard_dim will be used during backward
shard_dim=gather_dim,
logical_process_axis=logical_process_axis,
forward_only=self.forward_only)
forward_only=self.forward_only,
)
# compute the communication cost with CommSpec
cost_dict = comm_spec.get_comm_cost()
# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
new_sharding_spec = ShardingSpec(
source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict
)
for phase, cost in cost_dict.items():
cost_dict[phase] = cost + orig_cost_dict[phase]
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
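
The enumeration helpers all follow the pattern in this hunk: build a `CommSpec`, price it with `get_comm_cost()`, fold the original costs in, and key the result by the new `ShardingSpec`. A hedged call-site sketch follows; `source_spec` is assumed to be an existing `ShardingSpec`, and the zeroed cost dict mirrors the keys used by `shape_consistency` further down.

manager = ShapeConsistencyManager()  # SingletonMeta: every call returns the same instance
zero_cost = {"forward": 0, "backward": 0, "total": 0}

candidates = manager.get_all_all_gather_spec(source_spec, zero_cost)
for new_spec, (comm_spec, cost_dict) in candidates.items():
    # each value pairs the CommSpec that realizes the transition with the
    # accumulated per-phase cost, exactly as stored in the hunk above
    print(new_spec, cost_dict["total"])
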
@@ -141,9 +143,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
pass
return valid_spec_dict
def get_all_all_to_all_spec(self, source_spec: ShardingSpec,
orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
'''
def get_all_all_to_all_spec(
self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]
) -> Dict[ShardingSpec, float]:
"""
Get all valid sharding specs from source_spec with single all-to-all operation, and
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
For the all-to-all operation, we just care about the pairs containing S dimension.
@@ -173,7 +176,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
device_mesh_shape: (4, 4): 0, DistSpec:
shard_sequence: S0,R,S1
device_mesh_shape: (4, 4): 0}
'''
"""
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
tensor_dims = len(source_spec.entire_shape)
@@ -214,12 +217,14 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
gather_dim = b_index
shard_dim = f_index
logical_process_axis = b_target_pair[1][-1]
comm_spec = CommSpec(comm_pattern,
sharding_spec=source_spec,
gather_dim=gather_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis,
forward_only=self.forward_only)
comm_spec = CommSpec(
comm_pattern,
sharding_spec=source_spec,
gather_dim=gather_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis,
forward_only=self.forward_only,
)
# compute the communication cost with CommSpec
cost_dict = comm_spec.get_comm_cost()
@@ -238,9 +243,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
new_sharding_spec = ShardingSpec(
source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict
)
for phase, cost in cost_dict.items():
cost_dict[phase] = cost + orig_cost_dict[phase]
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
@@ -250,7 +255,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
return valid_spec_dict
def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict):
'''
"""
Get all valid sharding specs from source_spec with single shard operation, and
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
For the sharding operation, we just care about legal sharding dimensions.
@@ -280,7 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
device_mesh_shape: (4, 4): 0, DistSpec:
shard_sequence: S0,R,S1
device_mesh_shape: (4, 4): 0}
'''
"""
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
@@ -308,21 +313,23 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
shard_dim = index
logical_process_axis = shard_list[-1]
comm_spec = CommSpec(comm_pattern,
sharding_spec=source_spec,
gather_dim=shard_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis,
forward_only=self.forward_only)
comm_spec = CommSpec(
comm_pattern,
sharding_spec=source_spec,
gather_dim=shard_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis,
forward_only=self.forward_only,
)
# compute the communication cost with CommSpec
cost_dict = comm_spec.get_comm_cost()
# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
new_sharding_spec = ShardingSpec(
source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict
)
for phase, cost in cost_dict.items():
cost_dict[phase] = cost + orig_cost_dict[phase]
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
@@ -330,14 +337,15 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
pass
return valid_spec_dict
def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
'''
def get_all_mix_gather_spec(
self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]
) -> Dict[ShardingSpec, float]:
"""
S0S1 -> RR
S1S0 -> RR
S01R -> RR
RS01 -> RR
'''
"""
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
tensor_dims = len(source_spec.entire_shape)
@@ -362,19 +370,21 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
b_target_pair = (b_index, [])
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
comm_spec = CommSpec(comm_pattern,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axes,
forward_only=self.forward_only,
mix_gather=True)
comm_spec = CommSpec(
comm_pattern,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axes,
forward_only=self.forward_only,
mix_gather=True,
)
cost_dict = comm_spec.get_comm_cost()
new_dim_partition_dict = {}
# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
new_sharding_spec = ShardingSpec(
source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict
)
for phase, cost in cost_dict.items():
cost_dict[phase] = cost + orig_cost_dict[phase]
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
@@ -384,7 +394,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
return valid_spec_dict
def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
'''
"""
Get all valid sharding specs from source_spec with one step transform, and
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
Note:
@@ -398,7 +408,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
Return:
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
'''
"""
valid_spec_dict = {}
valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost_dict))
valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost_dict))
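
Since the shard-variant update falls outside this hunk, here is a hedged restatement of what `get_all_one_step_transform_spec` assembles: the union of the three single-collective enumerations, with later entries overwriting earlier ones if two collectives ever reach the same spec.

# Illustrative only; mirrors the two updates shown above plus the shard
# variant that presumably follows them in the full file.
one_step = {}
one_step.update(manager.get_all_all_gather_spec(current_spec, zero_cost))
one_step.update(manager.get_all_all_to_all_spec(current_spec, zero_cost))
one_step.update(manager.get_all_shard_spec(current_spec, zero_cost))
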
@@ -545,18 +555,22 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
# the first forward comm action will not discard input
fwd_action, comm_spec = action_spec_pair
fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel,
fwd_peak_numel) if idx == 0 else fwd_action(
comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
fwd_alloc_numel, fwd_peak_numel = (
fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
if idx == 0
else fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
)
# analyze memory footprint for backward comm actions sequence
bwd_alloc_numel = 0
bwd_peak_numel = 0
for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
bwd_action, comm_spec = action_spec_pair
bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel,
bwd_peak_numel) if idx == 0 else bwd_action(
comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
bwd_alloc_numel, bwd_peak_numel = (
bwd_action(comm_spec, False, bwd_alloc_numel, bwd_peak_numel)
if idx == 0
else bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
)
fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
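
The two `MemoryCost` lines split the tracked element counts into what is still allocated after the comm sequence and what was only live at the peak. A small worked example with assumed numbers:

# Assumed numbers, for illustration only.
fwd_alloc_numel, fwd_peak_numel = 4096, 6144

activation = fwd_alloc_numel             # 4096 elements remain allocated
temp = fwd_peak_numel - fwd_alloc_numel  # 2048 elements were transient peak usage
# i.e. MemoryCost(activation=4096, temp=2048), matching fwd_mem above
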
@@ -564,9 +578,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
return TrainCycleItem(fwd_mem, bwd_mem, total_mem)
def shape_consistency(self, source_spec: ShardingSpec,
target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
'''
def shape_consistency(
self, source_spec: ShardingSpec, target_spec: ShardingSpec
) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
"""
This method will find a path to transform source_spec to target_spec with
a greedy algorithm.
The basic idea is:
@@ -623,9 +638,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0),
CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)]
total_cost: 12294.402000000002
'''
"""
MAX_TRANSFORM_STEPS = 20
total_cost_dict = {'forward': 0, 'backward': 0, 'total': 0}
total_cost_dict = {"forward": 0, "backward": 0, "total": 0}
total_steps = 0
transform_path = []
comm_action_sequence = []
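
A hedged usage sketch of the greedy search described in the docstring above; `source_spec` and `target_spec` are assumed pre-built `ShardingSpec`s, and the third return value is printed as-is rather than assuming its exact type.

transform_path, comm_action_sequence, total_cost = manager.shape_consistency(source_spec, target_spec)

print(transform_path)        # the intermediate ShardingSpecs, as in the docstring example
print(comm_action_sequence)  # the CommSpecs that realize each hop
print(total_cost)            # e.g. 12294.402000000002 in the docstring example
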
@@ -672,7 +687,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")
def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor:
'''
"""
Apply target_spec to tensor with source sharding spec, the transform path is generated by the
shape_consistency method.
@@ -729,7 +744,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
[1.],
[3.],
[3.]])
'''
"""
_, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec)
for comm_spec in comm_action_sequence:
tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
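
Finally, a hedged sketch of calling `apply`, whose conversion loop this hunk shows: the input tensor must already carry a `.sharding_spec` attribute, and the return statement sits just past the end of the hunk (assumed here to hand back the converted tensor).

# Hypothetical call; `dist_tensor.sharding_spec` and `target_spec` are assumed to exist.
converted = manager.apply(dist_tensor, target_spec)
# each CommSpec in the generated sequence (covert_spec_to_action above) converts
# the tensor one step closer to target_spec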