Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 19:40:28 +00:00
[devops] remove post commit ci (#5566)
* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -2,13 +2,13 @@ from .api import (
     compute_global_numel,
     customized_distributed_tensor_to_param,
     distribute_tensor,
-    init_as_dtensor,
     distribute_tensor_with_customization,
-    init_tensor_as_customization_distributed,
     get_device_mesh,
     get_global_shape,
     get_layout,
     get_sharding_spec,
+    init_as_dtensor,
+    init_tensor_as_customization_distributed,
     is_customized_distributed_tensor,
     is_distributed_tensor,
     is_sharded,
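For orientation, the re-sorted names above are the public d_tensor helpers re-exported from `.api` (presumably `colossalai/tensor/d_tensor/__init__.py`). A minimal usage sketch follows; the DeviceMesh/ShardingSpec import paths and constructor arguments are assumptions made for illustration and are not part of this diff:

# Hedged sketch: a typical flow through the re-exported helpers shown above.
# The DeviceMesh / ShardingSpec paths and constructor arguments are assumed,
# and a 4-rank process group is assumed to be initialized already.
import torch

from colossalai.device.device_mesh import DeviceMesh  # assumed path
from colossalai.tensor.d_tensor import (
    distribute_tensor,
    get_sharding_spec,
    is_distributed_tensor,
)
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec  # assumed path

# 2x2 logical mesh over ranks 0..3 (argument names/order assumed)
device_mesh = DeviceMesh(torch.arange(4), (2, 2), init_process_group=True)
# shard dim 0 of a 2-D tensor along mesh axis 0 (argument names assumed)
sharding_spec = ShardingSpec(dim_size=2, dim_partition_dict={0: [0]})

full = torch.randn(8, 8)
sharded = distribute_tensor(full, device_mesh, sharding_spec)  # signature per the hunks below
assert is_distributed_tensor(sharded)
print(get_sharding_spec(sharded))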
@@ -128,7 +128,10 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
     return sharded_tensor
 
 
-def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size) -> torch.Tensor:
+def init_as_dtensor(
+    tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size
+) -> torch.Tensor:
     assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
     dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
@@ -140,6 +143,7 @@ def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec
     return tensor
 
 
 def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
     """
     Convert the layout of the tensor from source_spec to target_spec.
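The visible change in the two hunks above is mechanical re-wrapping of the `init_as_dtensor` header, but the surrounding context documents two distinct entry points: `distribute_tensor` shards a full tensor and returns the result, while `init_as_dtensor` tags an existing local tensor with layout metadata for a given global shape, and `redistribute` converts a dtensor to a target layout and returns `None`. A sketch of how they relate, reusing the assumed `device_mesh`/`sharding_spec` objects from the previous sketch:

# Hedged sketch of the entry points whose context appears above.
# `device_mesh` and `sharding_spec` are assumed to be built as in the
# previous sketch; local/global shapes here are illustrative.
import torch

from colossalai.tensor.d_tensor import get_layout, init_as_dtensor  # re-exports per the first hunk
from colossalai.tensor.d_tensor.api import redistribute  # assumed path to the .api module

# Tag an already-sharded local chunk as a dtensor, supplying the global shape
# explicitly (this is the function whose header was re-wrapped above).
local_shard = torch.randn(4, 8)  # assumed local shape for a 2-way row shard of (8, 8)
dtensor = init_as_dtensor(local_shard, device_mesh, sharding_spec, torch.Size((8, 8)))

# Convert the layout toward a (possibly different) target spec; per the
# annotation above, redistribute returns None and mutates the dtensor.
redistribute(dtensor, device_mesh, sharding_spec)
print(get_layout(dtensor))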
@@ -468,7 +472,6 @@ def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gat
     assert callable(gather_fn), "The gather_fn must be callable."
     assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
 
-
     # set the shard_fn and gather_fn as attributes of the distributed tensor
     tensor.shard_fn = shard_fn
     tensor.gather_fn = gather_fn
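The final hunk appears to only drop an extra blank line inside `init_tensor_as_customization_distributed`, but its context shows the contract: the caller supplies callable `shard_fn`/`gather_fn` hooks, and the helper attaches them to the tensor as attributes. A sketch of that contract; the toy hooks below are illustrative assumptions, and a real `gather_fn` would typically all-gather across ranks:

# Hedged sketch of the shard_fn / gather_fn contract visible in the last hunk:
# both arguments must be callables, and they end up stored on the tensor.
# The toy hooks below are assumptions; real ones would communicate across ranks.
import torch

from colossalai.tensor.d_tensor import (  # re-exports per the first hunk, path assumed
    init_tensor_as_customization_distributed,
    is_customized_distributed_tensor,
)


def shard_fn(tensor: torch.Tensor) -> torch.Tensor:
    # keep this rank's half of dim 0 (toy stand-in for a scatter)
    return tensor.chunk(2, dim=0)[0]


def gather_fn(tensor: torch.Tensor) -> torch.Tensor:
    # rebuild the full dim 0 (toy stand-in for an all-gather)
    return torch.cat([tensor, tensor], dim=0)


t = torch.randn(8, 8)
init_tensor_as_customization_distributed(t, shard_fn, gather_fn)
assert is_customized_distributed_tensor(t)
assert t.gather_fn is gather_fn  # attached as attributes, per the diff context above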