mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -8,8 +8,11 @@ import torch
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
__all__ = [
|
||||
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
|
||||
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
|
||||
"transpose_partition_dim",
|
||||
"update_partition_dim",
|
||||
"enumerate_all_possible_1d_sharding",
|
||||
"enumerate_all_possible_2d_sharding",
|
||||
"generate_sharding_size",
|
||||
]
|
||||
|
||||
|
||||
@@ -22,8 +25,7 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -
|
||||
dim1 (int): the tensor dimension to switch
|
||||
dim2 (int): the tensor dimension to switch
|
||||
"""
|
||||
assert len(sharding_spec.entire_shape) >= 2, \
|
||||
'The entire_shape of the sharding spec must have at least 2 dimensions'
|
||||
assert len(sharding_spec.entire_shape) >= 2, "The entire_shape of the sharding spec must have at least 2 dimensions"
|
||||
dim_partition_dict = sharding_spec.dim_partition_dict
|
||||
|
||||
# transpose the dim partition
|
||||
@@ -45,10 +47,9 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -
|
||||
return sharding_spec
|
||||
|
||||
|
||||
def update_partition_dim(sharding_spec: ShardingSpec,
|
||||
dim_mapping: Dict[int, int],
|
||||
physical_shape: torch.Size,
|
||||
inplace: bool = False):
|
||||
def update_partition_dim(
|
||||
sharding_spec: ShardingSpec, dim_mapping: Dict[int, int], physical_shape: torch.Size, inplace: bool = False
|
||||
):
|
||||
"""
|
||||
This method is used to update the partition dim dict from the logical one to the physical one.
|
||||
|
||||
@@ -78,9 +79,9 @@ def update_partition_dim(sharding_spec: ShardingSpec,
|
||||
new_dim_partition_dict[tensor_dim] = mesh_dims
|
||||
|
||||
# update sharding spec
|
||||
current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh,
|
||||
entire_shape=physical_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
current_sharding_spec.__init__(
|
||||
device_mesh=sharding_spec.device_mesh, entire_shape=physical_shape, dim_partition_dict=new_dim_partition_dict
|
||||
)
|
||||
return current_sharding_spec
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user