Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 19:40:28 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -44,7 +44,7 @@ def is_sharded(dtensor: torch.Tensor) -> bool:
     Returns:
         bool: True if the tensor is sharded, False otherwise.
     """
-    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
     return list(dtensor.shape) == list(dtensor.dist_layout.global_shape)


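A minimal usage sketch (not part of this patch) for the predicates touched above, assuming the script is launched via torchrun and that the helpers are exported from colossalai.tensor.d_tensor (an assumed import path that this diff does not show):

# Usage sketch only; gloo/CPU for brevity, real runs typically use NCCL + CUDA tensors.
import torch
import torch.distributed as dist
from colossalai.tensor.d_tensor import is_distributed_tensor, is_sharded, shard_rowwise  # assumed export path

dist.init_process_group(backend="gloo")  # torchrun supplies the rendezvous env vars
pg = dist.new_group(list(range(dist.get_world_size())))

x = torch.randn(8, 4)                    # dim 0 must be divisible by the world size
assert not is_distributed_tensor(x)      # a plain tensor carries no dist_layout

dtensor = shard_rowwise(x, pg)           # attaches a dist_layout recording the global shape
assert is_distributed_tensor(dtensor)
print(is_sharded(dtensor), tuple(dtensor.shape), tuple(dtensor.dist_layout.global_shape))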
@@ -77,8 +77,10 @@ def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
     return dtensor


-def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
-    '''
+def _construct_default_sharding_spec(
+    tensor: torch.Tensor,
+) -> ShardingSpec:
+    """
     Construct the default sharding specification for the tensor.

     Args:
@@ -86,14 +88,14 @@ def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:

     Returns:
         A `ShardingSpec` object without any sharding specified.
-    '''
+    """
     return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})


 def _apply_layout(tensor, layout):
-    '''
+    """
     Apply the layout to the local tensor during the initialization process.
-    '''
+    """
     # layout converter requires a source and a target layout
     # we construct the source layout for an unsharded tensor
     # and use self.dist_layout as the target layout for the sharded tensor
@@ -115,7 +117,7 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
     Returns:
         torch.Tensor: The distributed tensor.
     """
-    assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
+    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
     dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)

     # shard tensor
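For orientation (not part of this patch), the hunk above is typically driven as in the sketch below: a DeviceMesh plus a ShardingSpec decide which tensor dims are split over which mesh axes. The constructor keywords mirror those visible in this file; the import paths and launcher setup are assumptions.

# Usage sketch only; assumed import paths; run under torchrun, world size must divide 16.
import torch
import torch.distributed as dist
from colossalai.device.device_mesh import DeviceMesh                     # assumed path
from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor   # assumed export path

dist.init_process_group(backend="gloo")  # gloo/CPU for brevity
pg = dist.new_group(list(range(dist.get_world_size())))

device_mesh = DeviceMesh.from_process_group(pg)   # 1D mesh, as shard_rowwise builds internally
tensor = torch.randn(16, 32)
spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})  # split dim 0 over mesh axis 0
dtensor = distribute_tensor(tensor, device_mesh, spec)                   # local shard: (16 // world_size, 32)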
@@ -128,7 +130,7 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp


 def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
-    '''
+    """
     Convert the layout of the tensor from source_spec to target_spec.
     This will update the `local_tensor` and `dist_layout` in place.

@@ -136,13 +138,13 @@ def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec:
         dtensor (torch.Tensor): the distributed tensor to be converted.
         device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
         target_layout (Layout): the target layout specification.
-    '''
-    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    """
+    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
     global_shape = get_global_shape(dtensor)
     target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
-    resharded_tensor = layout_converter.apply(tensor=dtensor,
-                                              source_layout=dtensor.dist_layout,
-                                              target_layout=target_layout)
+    resharded_tensor = layout_converter.apply(
+        tensor=dtensor, source_layout=dtensor.dist_layout, target_layout=target_layout
+    )
     return resharded_tensor


@@ -157,7 +159,7 @@ def to_global(dtensor: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: the global tensor.
     """
-    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
     layout_converter = LayoutConverter()

     global_sharding_spec = ShardingSpec(dtensor.dim(), {})
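Taken together, the two hunks above form a round trip: redistribute moves an already-distributed tensor onto a new ShardingSpec (returning the resharded tensor, as the code above shows), and to_global gathers it back into a full tensor. A self-contained sketch under the same assumed setup and import path:

# Usage sketch only; assumed import path and launcher setup.
import torch
import torch.distributed as dist
from colossalai.tensor.d_tensor import (  # assumed export path
    ShardingSpec,
    get_device_mesh,
    redistribute,
    shard_rowwise,
    to_global,
)

dist.init_process_group(backend="gloo")  # gloo/CPU for brevity
pg = dist.new_group(list(range(dist.get_world_size())))

torch.manual_seed(0)                      # identical global tensor on every rank
full = torch.randn(8, 4)
dtensor = shard_rowwise(full, pg)         # sharded along dim 0

col_spec = ShardingSpec(dim_size=full.dim(), dim_partition_dict={-1: [0]})
dtensor = redistribute(dtensor, get_device_mesh(dtensor), col_spec)  # reshard onto dim -1

print(torch.equal(to_global(dtensor), full))  # expected True under these assumptions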
@@ -193,7 +195,7 @@ def shard_rowwise(
     if isinstance(group_or_device_mesh, ProcessGroup):
         device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
     else:
-        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
+        assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
         device_mesh = group_or_device_mesh

     sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
@@ -222,7 +224,7 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
     if isinstance(group_or_device_mesh, ProcessGroup):
         device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
     else:
-        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
+        assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
         device_mesh = group_or_device_mesh
     sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})

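As the two hunks above show, shard_rowwise and shard_colwise are thin wrappers: they accept either a ProcessGroup or a 1D DeviceMesh and build a ShardingSpec over dim 0 or dim -1 respectively. A short sketch (not part of this patch), with the same assumed setup as the earlier examples:

# Usage sketch only; assumed import path; both dims must be divisible by the world size.
import torch
import torch.distributed as dist
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise  # assumed export path

dist.init_process_group(backend="gloo")  # gloo/CPU for brevity
pg = dist.new_group(list(range(dist.get_world_size())))

weight = torch.randn(8, 8)
row_sharded = shard_rowwise(weight, pg)   # dim_partition_dict={0: [0]}: split dim 0
col_sharded = shard_colwise(weight, pg)   # dim_partition_dict={-1: [0]}: split dim -1
print(tuple(row_sharded.shape), tuple(col_sharded.shape))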
@@ -230,7 +232,7 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup


 def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
-    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
     param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)

     # make it distributed as well
@@ -241,7 +243,7 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):


 def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
-    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
     param.data = dtensor
     # make it distributed as well
     param.dist_layout = dtensor.dist_layout
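The two conversion helpers above wrap a sharded tensor into an nn.Parameter, or re-point an existing parameter at it, and copy dist_layout across so the result is still recognized as a distributed tensor. A sketch with the same assumed setup and import path:

# Usage sketch only; assumed import path and launcher setup.
import torch
import torch.distributed as dist
from colossalai.tensor.d_tensor import (  # assumed export path
    is_distributed_tensor,
    shard_rowwise,
    sharded_tensor_to_existing_param,
    sharded_tensor_to_param,
)

dist.init_process_group(backend="gloo")  # gloo/CPU for brevity
pg = dist.new_group(list(range(dist.get_world_size())))

linear = torch.nn.Linear(4, 8, bias=False)                  # weight shape (8, 4)
sharded_weight = shard_rowwise(linear.weight.detach(), pg)

# Option 1: build a fresh parameter from the sharded tensor.
param = sharded_tensor_to_param(sharded_weight, requires_grad=True)
print(is_distributed_tensor(param))                         # dist_layout was carried over

# Option 2: keep the existing parameter object and swap its data in place.
sharded_tensor_to_existing_param(sharded_weight, linear.weight)
print(is_distributed_tensor(linear.weight))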
@@ -258,7 +260,7 @@ def compute_global_numel(dtensor: torch.Tensor) -> int:
     Returns:
         int: The global number of elements in the distributed tensor.
     """
-    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
     numel = reduce(operator.mul, dtensor.dist_layout.global_shape)
     return numel

@@ -274,7 +276,7 @@ def get_layout(dtensor: torch.Tensor) -> Layout:
         Layout: The layout of the distributed tensor.

     """
-    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
     return dtensor.dist_layout


@@ -288,7 +290,7 @@ def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
     Returns:
         torch.Size: The global shape of the distributed tensor.
     """
-    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
     return dtensor.dist_layout.global_shape


@@ -302,7 +304,7 @@ def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
     Returns:
         DeviceMesh: The device mesh of the distributed tensor.
     """
-    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
     return dtensor.dist_layout.device_mesh


@@ -316,7 +318,7 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
     Returns:
         ShardingSpec: The sharding spec of the distributed tensor.
     """
-    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
     return dtensor.dist_layout.sharding_spec


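The accessors in the last few hunks (compute_global_numel, get_layout, get_global_shape, get_device_mesh, get_sharding_spec) all read fields off dtensor.dist_layout behind the same assertion. A quick inspection sketch (not part of this patch) with the same assumed setup:

# Usage sketch only; assumed import path and launcher setup.
import torch
import torch.distributed as dist
from colossalai.tensor.d_tensor import (  # assumed export path
    compute_global_numel,
    get_device_mesh,
    get_global_shape,
    get_layout,
    get_sharding_spec,
    shard_rowwise,
)

dist.init_process_group(backend="gloo")  # gloo/CPU for brevity
pg = dist.new_group(list(range(dist.get_world_size())))

dtensor = shard_rowwise(torch.randn(8, 4), pg)
print(get_global_shape(dtensor))       # torch.Size([8, 4]) regardless of the local shard
print(compute_global_numel(dtensor))   # 32, reduced from the global shape
print(get_sharding_spec(dtensor))      # the ShardingSpec stored on dist_layout
print(get_device_mesh(dtensor) is get_layout(dtensor).device_mesh)  # both read dist_layout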
@@ -335,7 +337,7 @@ def is_customized_distributed_tensor(tensor: torch.Tensor):
     Returns:
         bool: Whether the given tensor is a customized distributed tensor.
     """
-    return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn')
+    return hasattr(tensor, "shard_fn") and hasattr(tensor, "gather_fn")


 def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
@@ -402,9 +404,9 @@ def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_
     Returns:
         torch.Tensor: The distributed tensor.
     """
-    assert callable(shard_fn), 'The shard_fn must be callable.'
-    assert callable(gather_fn), 'The gather_fn must be callable.'
-    assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
+    assert callable(shard_fn), "The shard_fn must be callable."
+    assert callable(gather_fn), "The gather_fn must be callable."
+    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

     sharded_tensor = shard_fn(tensor)

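distribute_tensor_with_customization takes user-supplied shard_fn and gather_fn callables instead of a ShardingSpec; the asserts above only require that both are callable and that the input is not already distributed. The sketch below uses a hand-rolled row split plus all-gather; the split/gather logic is illustrative, not taken from the repository, and the import path is assumed.

# Usage sketch only; shard/gather callables are illustrative.
import torch
import torch.distributed as dist
from colossalai.tensor.d_tensor import (  # assumed export path
    distribute_tensor_with_customization,
    is_customized_distributed_tensor,
    to_global_for_customized_distributed_tensor,
)

dist.init_process_group(backend="gloo")  # gloo/CPU for brevity


def shard_fn(tensor: torch.Tensor) -> torch.Tensor:
    # keep only this rank's slice of dim 0 (dim 0 must be divisible by the world size)
    return tensor.chunk(dist.get_world_size(), dim=0)[dist.get_rank()].contiguous()


def gather_fn(dtensor: torch.Tensor) -> torch.Tensor:
    # reassemble the global tensor from every rank's shard
    shards = [torch.empty_like(dtensor) for _ in range(dist.get_world_size())]
    dist.all_gather(shards, dtensor.contiguous())
    return torch.cat(shards, dim=0)


torch.manual_seed(0)                                # identical full tensor on every rank
full = torch.randn(8, 4)
dtensor = distribute_tensor_with_customization(full, shard_fn, gather_fn)
print(is_customized_distributed_tensor(dtensor))    # shard_fn / gather_fn were attached
print(torch.equal(to_global_for_customized_distributed_tensor(dtensor), full))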
@@ -428,7 +430,7 @@ def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.
     Returns:
         torch.Tensor: The global tensor.
     """
-    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
+    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
     return dtensor.gather_fn(dtensor)


@@ -436,7 +438,7 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad:
     """
     Convert the given customized distributed tensor to a parameter.
     """
-    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
+    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."

     param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)

@@ -451,7 +453,7 @@ def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param
     """
     Convert the given customized distributed tensor to an existing parameter.
     """
-    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
+    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."

     param.data = dtensor.data
     param.shard_fn = dtensor.shard_fn
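The final hunks cover the parameter-conversion path for customized distributed tensors, mirroring sharded_tensor_to_param / sharded_tensor_to_existing_param but carrying shard_fn and gather_fn across instead of a dist_layout. A short sketch with the same illustrative split/gather callables and assumed import path as the previous example:

# Usage sketch only; shard/gather callables are illustrative, import path assumed.
import torch
import torch.distributed as dist
from colossalai.tensor.d_tensor import (  # assumed export path
    customized_distributed_tensor_to_existing_param,
    customized_distributed_tensor_to_param,
    distribute_tensor_with_customization,
)

dist.init_process_group(backend="gloo")  # gloo/CPU for brevity
world_size, rank = dist.get_world_size(), dist.get_rank()


def shard_fn(tensor):
    return tensor.chunk(world_size, dim=0)[rank].contiguous()   # this rank's row block


def gather_fn(dtensor):
    shards = [torch.empty_like(dtensor) for _ in range(world_size)]
    dist.all_gather(shards, dtensor.contiguous())
    return torch.cat(shards, dim=0)


linear = torch.nn.Linear(4, 8, bias=False)                      # weight shape (8, 4)
dtensor = distribute_tensor_with_customization(linear.weight.detach(), shard_fn, gather_fn)

# Either build a fresh parameter ...
param = customized_distributed_tensor_to_param(dtensor, requires_grad=True)
# ... or graft the shard onto the existing parameter, keeping shard_fn / gather_fn attached.
customized_distributed_tensor_to_existing_param(dtensor, linear.weight)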