mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 17:40:33 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -7,7 +7,7 @@ from colossalai.tensor.colo_tensor import ColoTensor
|
||||
|
||||
|
||||
def all_gather_simulator(target_pair):
|
||||
'''
|
||||
"""
|
||||
Simulating all-gather operation, analyze the communication cost
|
||||
and simulate the influence of the DimSpec.
|
||||
|
||||
@@ -19,7 +19,7 @@ def all_gather_simulator(target_pair):
|
||||
Argument:
|
||||
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
|
||||
and the second element describes which logical axis will be sharded in that dimension.
|
||||
'''
|
||||
"""
|
||||
_, shard_list = target_pair
|
||||
new_shard_list = shard_list[:-1]
|
||||
|
||||
@@ -27,7 +27,7 @@ def all_gather_simulator(target_pair):
|
||||
|
||||
|
||||
def all_to_all_simulator(f_target_pair, b_target_pair):
|
||||
'''
|
||||
"""
|
||||
Simulating all-to-all operation, analyze the communication cost
|
||||
and simulate the influence of the DimSpec.
|
||||
|
||||
@@ -47,7 +47,7 @@ def all_to_all_simulator(f_target_pair, b_target_pair):
|
||||
Argument:
|
||||
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
|
||||
and the second element describes which logical axis will be sharded in that dimension.
|
||||
'''
|
||||
"""
|
||||
_, f_shard_list = f_target_pair
|
||||
_, b_shard_list = b_target_pair
|
||||
if not len(b_shard_list):
|
||||
@@ -61,7 +61,7 @@ def all_to_all_simulator(f_target_pair, b_target_pair):
|
||||
|
||||
|
||||
def shard_simulator(target_pair, legal_sharding_dims):
|
||||
'''
|
||||
"""
|
||||
Simulating shard operation, analyze the communication cost(always ZERO)
|
||||
and simulate the influence of the DimSpec.
|
||||
|
||||
@@ -78,7 +78,7 @@ def shard_simulator(target_pair, legal_sharding_dims):
|
||||
Argument:
|
||||
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
|
||||
and the second element describes which logical axis will be sharded in that dimension.
|
||||
'''
|
||||
"""
|
||||
_, shard_list = target_pair
|
||||
shard_list_list = []
|
||||
for dim in legal_sharding_dims:
|
||||
@@ -91,7 +91,7 @@ def shard_simulator(target_pair, legal_sharding_dims):
|
||||
|
||||
|
||||
def mix_gather_simulator(f_target_pair, b_target_pair):
|
||||
'''
|
||||
"""
|
||||
Assume index of f and b target pairs are 'f' and 'b'
|
||||
S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0)
|
||||
S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1)
|
||||
@@ -99,7 +99,7 @@ def mix_gather_simulator(f_target_pair, b_target_pair):
|
||||
RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1)
|
||||
S10R => Input: (f, [0, 1]), (b, []) Output: [f], (0, 0)
|
||||
RS10 => Input: (f, []), (b, [0, 1]) Output: [b], (0, 0)
|
||||
'''
|
||||
"""
|
||||
if f_target_pair[1] and b_target_pair[1]:
|
||||
leading_dim = b_target_pair[1] > f_target_pair[1]
|
||||
return [b_target_pair[0], f_target_pair[0]], [int(leading_dim), int(leading_dim ^ 1)]
|
||||
@@ -118,7 +118,7 @@ def mix_gather_simulator(f_target_pair, b_target_pair):
|
||||
# The function is credited to PyTorch Team
|
||||
def named_params_with_colotensor(
|
||||
module: nn.Module,
|
||||
prefix: str = '',
|
||||
prefix: str = "",
|
||||
recurse: bool = True,
|
||||
) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
|
||||
r"""Returns an iterator over module parameters (together with the
|
||||
@@ -154,7 +154,7 @@ def named_params_with_colotensor(
|
||||
for name, val in vars(mod).items():
|
||||
if isinstance(val, ColoTensor) and val not in memo:
|
||||
memo.add(val)
|
||||
name = mod_prefix + ('.' if mod_prefix else '') + name
|
||||
name = mod_prefix + ("." if mod_prefix else "") + name
|
||||
yield name, val
|
||||
|
||||
# find all nn.Parameters
|
||||
@@ -169,15 +169,16 @@ def _convert_tensor(tensor: torch.Tensor) -> ColoTensor:
|
||||
def convert_parameter(module: torch.nn.Module, param_name: str):
|
||||
# Perform some validation first.
|
||||
if not hasattr(module, param_name):
|
||||
raise ValueError(f'module: {module} does not have parameter with name: {param_name}')
|
||||
raise ValueError(f"module: {module} does not have parameter with name: {param_name}")
|
||||
|
||||
tensor = getattr(module, param_name)
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
raise ValueError(
|
||||
f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')
|
||||
f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}"
|
||||
)
|
||||
|
||||
if not tensor.is_contiguous():
|
||||
raise ValueError(f'param: {param_name} is not a contiguous Tensor')
|
||||
raise ValueError(f"param: {param_name} is not a contiguous Tensor")
|
||||
|
||||
st = _convert_tensor(tensor)
|
||||
|
||||
@@ -193,9 +194,9 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
|
||||
|
||||
|
||||
def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
|
||||
'''
|
||||
"""
|
||||
This method is used to convert the negative dim value to positive.
|
||||
'''
|
||||
"""
|
||||
dims_to_convert = []
|
||||
for dim, mesh_list in dim_partition_dict.items():
|
||||
if dim < 0:
|
||||
@@ -207,13 +208,13 @@ def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List
|
||||
|
||||
|
||||
def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
|
||||
'''
|
||||
"""
|
||||
This method is used to merge the different key value which points to same physical position.
|
||||
|
||||
For example:
|
||||
dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position.
|
||||
In this method, above dim_partition_dict will be converted to {1: [0, 1]}
|
||||
'''
|
||||
"""
|
||||
converted_dim_partition_dict = {}
|
||||
for dim, mesh_list in dim_partition_dict.items():
|
||||
if dim < 0:
|
||||
|
Reference in New Issue
Block a user