[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -44,7 +44,7 @@ def is_sharded(dtensor: torch.Tensor) -> bool:
Returns:
bool: True if the tensor is sharded, False otherwise.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return list(dtensor.shape) == list(dtensor.dist_layout.global_shape)
@@ -77,8 +77,10 @@ def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
return dtensor
def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
'''
def _construct_default_sharding_spec(
tensor: torch.Tensor,
) -> ShardingSpec:
"""
Construct the default sharding specification for the tensor.
Args:
@@ -86,14 +88,14 @@ def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
Returns:
A `ShardingSpec` object without any sharding specified.
'''
"""
return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
def _apply_layout(tensor, layout):
'''
"""
Apply the layout to the local tensor during initializing process.
'''
"""
# layout converter requires a source and target laytout
# we construct the source layer for an unsharded tensor
# and use self.dist_layer as the targer layout for the sharded tensor
@@ -115,7 +117,7 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
Returns:
torch.Tensor: The distributed tensor.
"""
assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)
# shard tensor
@@ -128,7 +130,7 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
'''
"""
Convert the layout of the tensor from source_spec to target_spec.
This will update the `local_tensor` and `dist_layout` in place.
@@ -136,13 +138,13 @@ def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec:
dtensor (torch.Tensor): the distributed tensor to be converted.
device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
target_layout (Layout): the target layout specification.
'''
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
"""
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
global_shape = get_global_shape(dtensor)
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
resharded_tensor = layout_converter.apply(tensor=dtensor,
source_layout=dtensor.dist_layout,
target_layout=target_layout)
resharded_tensor = layout_converter.apply(
tensor=dtensor, source_layout=dtensor.dist_layout, target_layout=target_layout
)
return resharded_tensor
@@ -157,7 +159,7 @@ def to_global(dtensor: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: the global tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
layout_converter = LayoutConverter()
global_sharding_spec = ShardingSpec(dtensor.dim(), {})
@@ -193,7 +195,7 @@ def shard_rowwise(
if isinstance(group_or_device_mesh, ProcessGroup):
device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
else:
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
@@ -222,7 +224,7 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
if isinstance(group_or_device_mesh, ProcessGroup):
device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
else:
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
@@ -230,7 +232,7 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
# make it distributed as well
@@ -241,7 +243,7 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
param.data = dtensor
# make it distributed as well
param.dist_layout = dtensor.dist_layout
@@ -258,7 +260,7 @@ def compute_global_numel(dtensor: torch.Tensor) -> int:
Returns:
int: The global number of elements in the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
numel = reduce(operator.mul, dtensor.dist_layout.global_shape)
return numel
@@ -274,7 +276,7 @@ def get_layout(dtensor: torch.Tensor) -> Layout:
Layout: The layout of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout
@@ -288,7 +290,7 @@ def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
Returns:
torch.Size: The global shape of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout.global_shape
@@ -302,7 +304,7 @@ def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
Returns:
DeviceMesh: The device mesh of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout.device_mesh
@@ -316,7 +318,7 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
Returns:
ShardingSpec: The sharding spec of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout.sharding_spec
@@ -335,7 +337,7 @@ def is_customized_distributed_tensor(tensor: torch.Tensor):
Returns:
bool: Whether the given tensor is a customized distributed tensor.
"""
return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn')
return hasattr(tensor, "shard_fn") and hasattr(tensor, "gather_fn")
def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
@@ -402,9 +404,9 @@ def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_
Returns:
torch.Tensor: The distributed tensor.
"""
assert callable(shard_fn), 'The shard_fn must be callable.'
assert callable(gather_fn), 'The gather_fn must be callable.'
assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
assert callable(shard_fn), "The shard_fn must be callable."
assert callable(gather_fn), "The gather_fn must be callable."
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
sharded_tensor = shard_fn(tensor)
@@ -428,7 +430,7 @@ def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.
Returns:
torch.Tensor: The global tensor.
"""
assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
return dtensor.gather_fn(dtensor)
@@ -436,7 +438,7 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad:
"""
Convert the given customized distributed tensor to a parameter.
"""
assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
@@ -451,7 +453,7 @@ def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param
"""
Convert the given customized distributed tensor to an existing parameter.
"""
assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
param.data = dtensor.data
param.shard_fn = dtensor.shard_fn