Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 02:51:59 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
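The last message item is a configuration change rather than a change to the Python sources below. A minimal sketch of what excluding CUDA sources from a clang-format hook can look like in .pre-commit-config.yaml; the mirror repo, rev pin, and exclude pattern here are illustrative assumptions, not the repository's actual settings:

repos:
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v13.0.1            # illustrative pin, not this commit's actual rev
    hooks:
      - id: clang-format
        exclude: \.cu$      # skip CUDA sources, per the commit message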
@@ -14,8 +14,11 @@ from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 __all__ = [
-    'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
-    'comm_actions_for_oprands'
+    "BroadcastType",
+    "is_broadcastable",
+    "get_broadcast_shape",
+    "recover_sharding_spec_for_broadcast_shape",
+    "comm_actions_for_oprands",
 ]
 
 
@@ -41,7 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
     """
     Compute the broadcast shape given two shapes.
     """
-    assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable'
+    assert is_broadcastable(shape1, shape2), f"{shape1} and {shape2} are not broadcastable"
     shape1_reverse = shape1[::-1]
     shape2_reverse = shape2[::-1]
     min_common_dim = min(len(shape1), len(shape2))
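For context on the function this hunk touches: get_broadcast_shape pairs dimensions from the trailing end, the same rule numpy and torch broadcasting use. A standalone sketch of that rule, written here for illustration and not taken from ColossalAI's implementation:

import torch

def broadcast_shape(shape1, shape2):
    # Pair dims from the trailing end; each pair must match or contain a 1.
    result = []
    for dim1, dim2 in zip(reversed(shape1), reversed(shape2)):
        assert dim1 == dim2 or dim1 == 1 or dim2 == 1, f"{shape1} and {shape2} are not broadcastable"
        result.append(max(dim1, dim2))
    # The longer shape keeps its extra leading dims unchanged.
    longer = shape1 if len(shape1) >= len(shape2) else shape2
    result.extend(reversed(longer[: len(longer) - len(result)]))
    return list(reversed(result))

assert broadcast_shape(torch.Size([4, 1, 6]), torch.Size([8, 1])) == [4, 8, 6]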
@@ -60,8 +63,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
     logical_num_dims = len(logical_shape)
     physical_num_dims = len(physical_shape)
 
-    assert logical_num_dims >= physical_num_dims, \
-        'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!'
+    assert (
+        logical_num_dims >= physical_num_dims
+    ), "The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!"
 
     # track the dim and its broadcasting type
     logical_dim_broadcast_info = {}
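A concrete instance of the invariant this assertion enforces: broadcasting only ever adds leading dimensions, so the logical (post-broadcast) view cannot have fewer dims than the physical tensor. In plain torch, independent of this module:

import torch

physical = torch.zeros(8, 1)        # physical shape: (8, 1)
logical = physical.expand(4, 8, 6)  # logical shape after broadcast: (4, 8, 6)
assert logical.dim() >= physical.dim()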
@@ -85,8 +89,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
     return logical_dim_broadcast_info
 
 
-def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
-                                              physical_shape: torch.Size) -> ShardingSpec:
+def recover_sharding_spec_for_broadcast_shape(
+    logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, physical_shape: torch.Size
+) -> ShardingSpec:
     """
     This function computes the sharding spec for the physical shape of a broadcast tensor.
 
@@ -124,15 +129,18 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
             physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
             physical_dim_partition[physical_dim] = mesh_dim
 
-    physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh,
-                                          entire_shape=physical_shape,
-                                          dim_partition_dict=physical_dim_partition)
+    physical_sharding_spec = ShardingSpec(
+        device_mesh=logical_sharding_spec.device_mesh,
+        entire_shape=physical_shape,
+        dim_partition_dict=physical_dim_partition,
+    )
 
     return physical_sharding_spec, removed_dims
 
 
-def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
-                             sharding_spec: ShardingSpec) -> CommAction:
+def comm_actions_for_oprands(
+    node: Node, removed_dims: List[int], op_data: OperationData, sharding_spec: ShardingSpec
+) -> CommAction:
     """
     This method is used to generate communication actions for oprands which lose information
     during convert logical shape to physical shape.
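The index arithmetic kept by this hunk, physical_dim = physical_num_dims - (logical_num_dims - shape_dim), follows from trailing-dim alignment: a logical dim maps to the physical dim at the same offset from the end. A quick worked check with assumed example shapes:

# logical (4, 8, 6) vs physical (8, 6): logical dims 1 and 2 exist physically
logical_num_dims, physical_num_dims = 3, 2
for shape_dim in (1, 2):
    physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
    print(f"logical dim {shape_dim} -> physical dim {physical_dim}")  # 1 -> 0, 2 -> 1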
@@ -140,9 +148,11 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
     if len(removed_dims) == 1:
         # if list length is 1, extract element from list to avoid using flatten device mesh
        removed_dims = removed_dims[0]
-    comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                         sharding_spec=sharding_spec,
-                         logical_process_axis=removed_dims)
+    comm_spec = CommSpec(
+        comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+        sharding_spec=sharding_spec,
+        logical_process_axis=removed_dims,
+    )
     if op_data.type == OperationDataType.PARAM:
         comm_type = CommType.HOOK
     else:
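The single-axis unwrap preserved by this hunk matters because, per the inline comment, logical_process_axis can take either one mesh axis or a list of axes, and a one-element list would force a flattened device mesh. In isolation:

removed_dims = [0]
if len(removed_dims) == 1:
    removed_dims = removed_dims[0]  # pass a scalar axis instead of a 1-element list
assert removed_dims == 0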
@@ -151,7 +161,7 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
     for index, arg in enumerate(node.args):
         if op_data.name == str(arg):
             arg_index = index
-    assert arg_index >= 0, f'op_data should be an argument of node.'
+    assert arg_index >= 0, f"op_data should be an argument of node."
     comm_action = CommAction(
         comm_spec=comm_spec,
         comm_type=comm_type,