Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -21,8 +21,23 @@ from .layout import Layout
from .sharding_spec import ShardingSpec

__all__ = [
'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
'redistribute', 'get_layout', 'is_customized_distributed_tensor', 'distribute_tensor_with_customization',
'to_global_for_customized_distributed_tensor', 'customized_distributed_tensor_to_param', 'Layout', 'ShardingSpec'
"is_distributed_tensor",
"distribute_tensor",
"to_global",
"is_sharded",
"shard_rowwise",
"shard_colwise",
"sharded_tensor_to_param",
"compute_global_numel",
"get_sharding_spec",
"get_global_shape",
"get_device_mesh",
"redistribute",
"get_layout",
"is_customized_distributed_tensor",
"distribute_tensor_with_customization",
"to_global_for_customized_distributed_tensor",
"customized_distributed_tensor_to_param",
"Layout",
"ShardingSpec",
]

@@ -44,7 +44,7 @@ def is_sharded(dtensor: torch.Tensor) -> bool:
Returns:
bool: True if the tensor is sharded, False otherwise.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return list(dtensor.shape) == list(dtensor.dist_layout.global_shape)


@@ -77,8 +77,10 @@ def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
return dtensor


def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
'''
def _construct_default_sharding_spec(
tensor: torch.Tensor,
) -> ShardingSpec:
"""
Construct the default sharding specification for the tensor.

Args:
@@ -86,14 +88,14 @@ def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:

Returns:
A `ShardingSpec` object without any sharding specified.
'''
"""
return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})


def _apply_layout(tensor, layout):
'''
"""
Apply the layout to the local tensor during initializing process.
'''
"""
# layout converter requires a source and target laytout
# we construct the source layer for an unsharded tensor
# and use self.dist_layer as the targer layout for the sharded tensor
@@ -115,7 +117,7 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
Returns:
torch.Tensor: The distributed tensor.
"""
assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)

# shard tensor
@@ -128,7 +130,7 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp


def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
'''
"""
Convert the layout of the tensor from source_spec to target_spec.
This will update the `local_tensor` and `dist_layout` in place.

@@ -136,13 +138,13 @@ def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec:
dtensor (torch.Tensor): the distributed tensor to be converted.
device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
target_layout (Layout): the target layout specification.
'''
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
"""
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
global_shape = get_global_shape(dtensor)
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
resharded_tensor = layout_converter.apply(tensor=dtensor,
source_layout=dtensor.dist_layout,
target_layout=target_layout)
resharded_tensor = layout_converter.apply(
tensor=dtensor, source_layout=dtensor.dist_layout, target_layout=target_layout
)
return resharded_tensor


@@ -157,7 +159,7 @@ def to_global(dtensor: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: the global tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
layout_converter = LayoutConverter()

global_sharding_spec = ShardingSpec(dtensor.dim(), {})
@@ -193,7 +195,7 @@ def shard_rowwise(
if isinstance(group_or_device_mesh, ProcessGroup):
device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
else:
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
device_mesh = group_or_device_mesh

sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
@@ -222,7 +224,7 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
if isinstance(group_or_device_mesh, ProcessGroup):
device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
else:
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})

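The shard_rowwise and shard_colwise helpers touched in the two hunks above are the simplest entry points to this API. A minimal usage sketch follows; it assumes torch.distributed has already been initialized (e.g. by the ColossalAI launcher), that the tensor dimensions divide the world size, and that the helpers are importable from colossalai.tensor.d_tensor, which this diff does not show.

import torch
import torch.distributed as dist

from colossalai.tensor.d_tensor import get_sharding_spec, shard_colwise, shard_rowwise  # import path assumed

# A ProcessGroup (here the default world group) or a 1D DeviceMesh is accepted.
weight = torch.randn(8, 4)
row_sharded = shard_rowwise(weight, dist.group.WORLD)                # dim 0 sharded -> {0: [0]}
col_sharded = shard_colwise(torch.randn(8, 4), dist.group.WORLD)     # dim -1 sharded -> {-1: [0]}
print(get_sharding_spec(row_sharded), get_sharding_spec(col_sharded))
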
@@ -230,7 +232,7 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup


def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)

# make it distributed as well
@@ -241,7 +243,7 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):


def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
param.data = dtensor
# make it distributed as well
param.dist_layout = dtensor.dist_layout
@@ -258,7 +260,7 @@ def compute_global_numel(dtensor: torch.Tensor) -> int:
Returns:
int: The global number of elements in the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
numel = reduce(operator.mul, dtensor.dist_layout.global_shape)
return numel

@@ -274,7 +276,7 @@ def get_layout(dtensor: torch.Tensor) -> Layout:
Layout: The layout of the distributed tensor.

"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout


@@ -288,7 +290,7 @@ def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
Returns:
torch.Size: The global shape of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout.global_shape


@@ -302,7 +304,7 @@ def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
Returns:
DeviceMesh: The device mesh of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout.device_mesh


@@ -316,7 +318,7 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
Returns:
ShardingSpec: The sharding spec of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout.sharding_spec


@@ -335,7 +337,7 @@ def is_customized_distributed_tensor(tensor: torch.Tensor):
Returns:
bool: Whether the given tensor is a customized distributed tensor.
"""
return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn')
return hasattr(tensor, "shard_fn") and hasattr(tensor, "gather_fn")


def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
@@ -402,9 +404,9 @@ def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_
Returns:
torch.Tensor: The distributed tensor.
"""
assert callable(shard_fn), 'The shard_fn must be callable.'
assert callable(gather_fn), 'The gather_fn must be callable.'
assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
assert callable(shard_fn), "The shard_fn must be callable."
assert callable(gather_fn), "The gather_fn must be callable."
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

sharded_tensor = shard_fn(tensor)

@@ -428,7 +430,7 @@ def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.
Returns:
torch.Tensor: The global tensor.
"""
assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
return dtensor.gather_fn(dtensor)


@@ -436,7 +438,7 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad:
"""
Convert the given customized distributed tensor to a parameter.
"""
assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."

param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)

@@ -451,7 +453,7 @@ def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param
"""
Convert the given customized distributed tensor to an existing parameter.
"""
assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."

param.data = dtensor.data
param.shard_fn = dtensor.shard_fn

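The distribute_tensor_with_customization hunks above only check that shard_fn and gather_fn are callables; the contract they are expected to satisfy is easier to see with a sketch. The example below is a hedged illustration, not library code: the import path is assumed, it requires an initialized torch.distributed process group, and it assumes the row count divides the world size.

import torch
import torch.distributed as dist

# import path assumed; these names appear in the __all__ list earlier in this diff
from colossalai.tensor.d_tensor import (
    distribute_tensor_with_customization,
    to_global_for_customized_distributed_tensor,
)


def shard_fn(tensor: torch.Tensor) -> torch.Tensor:
    # keep only this rank's block of rows
    return tensor.chunk(dist.get_world_size(), dim=0)[dist.get_rank()].clone()


def gather_fn(dtensor: torch.Tensor) -> torch.Tensor:
    # rebuild the global tensor by all-gathering the row blocks
    chunks = [torch.empty_like(dtensor) for _ in range(dist.get_world_size())]
    dist.all_gather(chunks, dtensor.contiguous())
    return torch.cat(chunks, dim=0)


if dist.is_initialized():
    dtensor = distribute_tensor_with_customization(torch.randn(8, 4), shard_fn, gather_fn)
    full = to_global_for_customized_distributed_tensor(dtensor)
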
@@ -6,22 +6,22 @@ import torch.distributed as dist
from torch.distributed import ReduceOp

__all__ = [
'CollectiveCommPattern',
'CommSpec',
"CollectiveCommPattern",
"CommSpec",
]


class CollectiveCommPattern(Enum):
GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd"
ALL2ALL_FWD_ALL2ALL_BWD = "all2all_fwd_all2all_bwd"
SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd"
ALLREDUCE_FWD_IDENTITY_BWD = "all_reduce_fwd_identity_bwd"
IDENTITY_FWD_ALLREDUCE_BWD = "identity_fwd_all_reduce_bwd"
MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd"


class CommSpec:
'''
"""
Communication spec is used to record the communication action. It converts the communication spec
to real action which will be used in runtime. It contains comm_pattern to determine the
communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
@@ -33,14 +33,16 @@ class CommSpec:
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
'''
"""

def __init__(self,
comm_pattern: CollectiveCommPattern,
process_group_dict: Dict,
gather_dim: int = None,
shard_dim: int = None,
logical_process_axis: int = None):
def __init__(
self,
comm_pattern: CollectiveCommPattern,
process_group_dict: Dict,
gather_dim: int = None,
shard_dim: int = None,
logical_process_axis: int = None,
):
self.comm_pattern = comm_pattern
self.gather_dim = gather_dim
self.shard_dim = shard_dim
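A CommSpec is just a record of which collective to run and along which dimensions, as the docstring above describes. A hedged construction sketch follows; the import path is assumed, and the None entry in process_group_dict stands in for a real torch.distributed ProcessGroup keyed by the logical mesh axis.

from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec  # path assumed

# "gather along dim 0 in the forward pass, split it back in the backward pass",
# carried out over logical mesh axis 0
spec = CommSpec(
    CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
    process_group_dict={0: None},  # axis -> ProcessGroup; None is only a placeholder here
    gather_dim=0,
    shard_dim=0,
    logical_process_axis=0,
)
print(spec)  # e.g. CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:0, shard_dim:0, logical_process_axis:0)
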
@@ -71,16 +73,16 @@ class CommSpec:
res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")

return ''.join(res_list)
return "".join(res_list)

def covert_spec_to_action(self, tensor):
'''
"""
Convert CommSpec into runtime action, implement real collection communication to target tensor.
The collection communication action is directed by the CommSpec.

Argument:
tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
'''
"""
if self.comm_pattern in pattern_to_func_dict:
tensor = pattern_to_func_dict[self.comm_pattern](tensor, self)
else:
@@ -89,9 +91,9 @@ class CommSpec:


def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
'''
"""
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
"""
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
@@ -103,9 +105,9 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):


def _split(tensor: torch.Tensor, comm_spec: CommSpec):
'''
"""
Implement shard operation on device mesh based on information provided by comm_spec.
'''
"""
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
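The _split helper above is the only collective-free primitive here: each rank simply narrows its copy of the tensor down to its own slice. Below is a single-process illustration of the same arithmetic, with world_size and rank hard-coded in place of the dist calls; the shard-length computation mirrors the hunk above, the use of narrow for the offset is an assumption.

import torch


def split_like_comm_spec(tensor: torch.Tensor, shard_dim: int, world_size: int, rank: int) -> torch.Tensor:
    # same length computation as in _split above; start is this rank's offset
    length = tensor.shape[shard_dim] // world_size
    start = length * rank
    return tensor.narrow(shard_dim, start, length).contiguous()


x = torch.arange(8).reshape(4, 2)
print(split_like_comm_spec(x, shard_dim=0, world_size=2, rank=1))  # rows 2 and 3
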
@@ -115,9 +117,9 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec):


def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
'''
"""
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
"""
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
new_shape = list(tensor.shape)
@@ -134,9 +136,9 @@ def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):


def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
'''
"""
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
"""
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
if not tensor.is_contiguous():
tensor = tensor.contiguous()
@@ -256,11 +258,13 @@ class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, comm_spec):
output = _all_to_all(input_, comm_spec)
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
process_group_dict=comm_spec.process_group_dict,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis)
comm_spec_for_backward = CommSpec(
comm_pattern=comm_spec.comm_pattern,
process_group_dict=comm_spec.process_group_dict,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis,
)
ctx.comm_spec = comm_spec_for_backward
return output


@@ -25,15 +25,16 @@ class Layout:
self._sanity_check()

def __hash__(self) -> int:
return hash(f'{self.sharding_spec}')
return hash(f"{self.sharding_spec}")

def get_sharded_shape_per_device(self):
sharded_shape = list(self.global_shape)
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
assert sharded_shape[
dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
assert (
sharded_shape[dim] % shard_partitions == 0
), f"Cannot shard dimension {dim} into {shard_partitions} partitions."
sharded_shape[dim] //= shard_partitions
return torch.Size(sharded_shape)

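get_sharded_shape_per_device above divides each sharded dimension by the product of the mesh-axis sizes it is sharded over. The following is a framework-free rerun of that arithmetic, assuming a hypothetical 2x2 device mesh; the shapes are made up for illustration.

import operator
from functools import reduce

global_shape = [8, 6]
mesh_shape = (2, 2)               # assumed 2x2 device mesh
dim_partition_dict = {0: [0, 1]}  # dim 0 sharded over both mesh axes

sharded_shape = list(global_shape)
for dim, shard_list in dim_partition_dict.items():
    shard_partitions = reduce(operator.mul, (mesh_shape[axis] for axis in shard_list), 1)
    assert sharded_shape[dim] % shard_partitions == 0, f"Cannot shard dimension {dim} into {shard_partitions} partitions."
    sharded_shape[dim] //= shard_partitions

print(sharded_shape)  # [2, 6]
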
@@ -49,7 +50,8 @@ class Layout:
dim_check_list.remove(element)
else:
raise DuplicatedShardingDimensionError(
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}."
)

# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in sharding_spec.dim_partition_dict.items():
@@ -61,5 +63,5 @@ class Layout:

if tensor_dim_size % num_devices != 0:
raise ShardingNotDivisibleError(
f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.'
f"The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices."
)

@@ -14,7 +14,7 @@ from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator,
from .sharding_spec import ShardingSpec
from .utils import get_comm_cost

__all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_options']
__all__ = ["LayoutConverter", "LayoutConverterOptions", "set_layout_converting_options"]


@dataclass
@@ -22,8 +22,8 @@ class LayoutConverterOptions:
"""
LayoutConverterOptions is a dataclass which specifies the preferences for layout converting.
"""

# TODO: layout converter option is not implemented yet
pass


def set_layout_converting_options(options: LayoutConverterOptions):
@@ -63,7 +63,7 @@ class LayoutConverter(metaclass=SingletonMeta):
self._forward_only = value

def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
'''
"""
Get all valid layouts from source_layout with single all-gather operation.
For the all-gather operation, we just care about the S dimension.

@@ -96,7 +96,7 @@ class LayoutConverter(metaclass=SingletonMeta):
Output:
[R, S1, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:0, shard_dim:0, logical_process_axis:0)
[S0, R, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1)
'''
"""
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
source_spec = source_layout.sharding_spec
@@ -125,16 +125,19 @@ class LayoutConverter(metaclass=SingletonMeta):
comm_pattern,
process_group_dict=process_group_dict,
gather_dim=gather_dim,
# shard_dim will be used during backward
# shard_dim will be used during backward
shard_dim=gather_dim,
logical_process_axis=logical_process_axis)
logical_process_axis=logical_process_axis,
)

# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
global_shape=source_layout.global_shape)
new_layout = Layout(
device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
global_shape=source_layout.global_shape,
)

valid_spec_dict[new_layout] = comm_spec
except LayoutException:
@@ -142,7 +145,7 @@ class LayoutConverter(metaclass=SingletonMeta):
return valid_spec_dict

def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
'''
"""
Get all valid layouts from source_layout with single all-to-all operation.
For the all-to-all operation, we just care about the pairs containing S dimension.

@@ -176,7 +179,7 @@ class LayoutConverter(metaclass=SingletonMeta):
[S01, R, R]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:0, logical_process_axis: 1)
[R, S1, S0]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:0, shard_dim:2, logical_process_axis: 0)
[S0, R, S1]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:2, logical_process_axis: 1)
'''
"""
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD

@@ -224,11 +227,13 @@ class LayoutConverter(metaclass=SingletonMeta):
gather_dim = b_index
shard_dim = f_index
logical_process_axis = b_target_pair[1][-1]
comm_spec = CommSpec(comm_pattern,
process_group_dict=process_group_dict,
gather_dim=gather_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
comm_spec = CommSpec(
comm_pattern,
process_group_dict=process_group_dict,
gather_dim=gather_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis,
)

new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)

@@ -246,9 +251,11 @@ class LayoutConverter(metaclass=SingletonMeta):
# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
global_shape=source_layout.global_shape)
new_layout = Layout(
device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
global_shape=source_layout.global_shape,
)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
@@ -256,7 +263,7 @@ class LayoutConverter(metaclass=SingletonMeta):
return valid_spec_dict

def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
'''
"""
Get all valid layouts from source_layout with single shard operation.
For the sharding operation, we just care about legal sharding dimensions.

@@ -291,7 +298,7 @@ class LayoutConverter(metaclass=SingletonMeta):
[S01, R, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:0, shard_dim:0, logical_process_axis:1)
[S0, S1, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1)
[S0, R, S1]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:2, shard_dim:2, logical_process_axis:1)
'''
"""
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
source_spec = source_layout.sharding_spec
@@ -326,26 +333,31 @@ class LayoutConverter(metaclass=SingletonMeta):
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
shard_dim = index
logical_process_axis = shard_list[-1]
comm_spec = CommSpec(comm_pattern,
process_group_dict=process_group_dict,
gather_dim=shard_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
comm_spec = CommSpec(
comm_pattern,
process_group_dict=process_group_dict,
gather_dim=shard_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis,
)

# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(dim_size=source_spec.dims,
dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
global_shape=source_layout.global_shape)
new_sharding_spec = ShardingSpec(
dim_size=source_spec.dims, dim_partition_dict=new_dim_partition_dict
)
new_layout = Layout(
device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
global_shape=source_layout.global_shape,
)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
return valid_spec_dict

def get_all_one_step_transform_spec(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
'''
"""
Get all valid layouts from source_layout with one step transform.

Note:
@@ -358,16 +370,17 @@ class LayoutConverter(metaclass=SingletonMeta):

Return:
valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with one step transform.
'''
"""
valid_spec_dict = {}
valid_spec_dict.update(self.all_gather_transform_layouts(source_layout))
valid_spec_dict.update(self.all_to_all_transform_layout(source_layout))
valid_spec_dict.update(self.shard_transform_layout(source_layout))
return valid_spec_dict

def layout_converting(self, source_layout: Layout,
target_layout: Layout) -> Tuple[List[Layout], List[CommSpec], float]:
'''
def layout_converting(
self, source_layout: Layout, target_layout: Layout
) -> Tuple[List[Layout], List[CommSpec], float]:
"""
This method will find a path to transform source_layout to target_layout with
a greedy algorithm.
The basic idea is:
@@ -419,7 +432,7 @@ class LayoutConverter(metaclass=SingletonMeta):

output:
[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]
'''
"""
source_spec = source_layout.sharding_spec
target_spec = target_layout.sharding_spec
MAX_TRANSFORM_STEPS = 20
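The layout_converting docstring above describes a greedy search: repeatedly enumerate all one-step transforms and keep the one whose sharding spec is closest to the target, giving up after MAX_TRANSFORM_STEPS. The following is a schematic of that loop, written against hypothetical callables rather than the real Layout/CommSpec objects, so it illustrates the idea without claiming to be the library implementation.

def greedy_layout_path(source, target, one_step_transforms, spec_diff, max_steps=20):
    """Greedy search schematic: one_step_transforms(layout) -> {layout: comm_spec},
    spec_diff(a, b) -> non-negative cost, 0 when the layouts match."""
    transform_path, comm_action_sequence = [source], []
    current = source
    for _ in range(max_steps):
        if spec_diff(current, target) == 0:
            return transform_path, comm_action_sequence
        candidates = one_step_transforms(current)
        best = min(candidates, key=lambda layout: spec_diff(layout, target))
        transform_path.append(best)
        comm_action_sequence.append(candidates[best])
        current = best
    raise RuntimeError(f"Could not find a valid transform path within {max_steps} steps.")
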
@@ -470,11 +483,11 @@ class LayoutConverter(metaclass=SingletonMeta):
raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")

def get_total_comm_cost(self, source_layout: Layout, target_layout: Layout) -> Dict[str, float]:
'''
"""
Get the total communication cost of the layout converting process.
'''
"""
transform_path, comm_action_sequence = self.layout_converting(source_layout, target_layout)
total_cost = {'forward': 0.0, 'backward': 0.0, 'total': 0.0}
total_cost = {"forward": 0.0, "backward": 0.0, "total": 0.0}
for layout, comm_spec in zip(transform_path, comm_action_sequence):
cost_dict = get_comm_cost(layout, comm_spec, self.forward_only)
for key in total_cost:
@@ -482,7 +495,7 @@ class LayoutConverter(metaclass=SingletonMeta):
return total_cost

def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layout) -> torch.Tensor:
'''
"""
Apply target_layout to tensor with source layout, the transform path is generated by the
layout_converting method.

@@ -542,7 +555,7 @@ class LayoutConverter(metaclass=SingletonMeta):
[1.],
[3.],
[3.]])
'''
"""
_, comm_action_sequence = self.layout_converting(source_layout, target_layout)
for comm_spec in comm_action_sequence:
tensor = comm_spec.covert_spec_to_action(tensor)

@@ -4,16 +4,16 @@ from typing import Dict, List
from ..utils import merge_same_dim_mesh_list
from .misc import ShardingOutOfIndexError

__all__ = ['DimSpec', 'ShardingException', 'ShardingSpec']
__all__ = ["DimSpec", "ShardingException", "ShardingSpec"]

ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
NAN = 'nan'
NAN = "nan"


class DimSpec:
'''
"""
Sharding spec for single dimension of the sharded tensor describe the sharding dimension of
logical device mesh and give a method to compute the difference between them.
This class is used internally in ShardingSpec.
@@ -21,7 +21,7 @@ class DimSpec:
Argument:
shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
Otherwise, the element in shard_list means the data will be sharded in that dimension.
'''
"""

def __init__(self, shard_list):
self.is_replica = len(shard_list) == 0
@@ -33,41 +33,40 @@ class DimSpec:

def __repr__(self):
if self.is_replica:
return 'R'
target = 'S'
return "R"
target = "S"
for dim in self.shard_list:
target += str(dim)
return target

def _convert_str_to_shard_list(self, str_spec):
'''
"""
Convert str_spec into shard_list.

Argument:
str_spec(str): dim spec in str type.
'''
"""

if str_spec == 'R':
if str_spec == "R":
return []
if str_spec == 'S0':
if str_spec == "S0":
return [0]
if str_spec == 'S1':
if str_spec == "S1":
return [1]
if str_spec == 'S01':
if str_spec == "S01":
return [0, 1]

def build_difference_2d_dict(self):
'''
"""
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
'''
"""

source_spec_list = ['R', 'S0', 'S1', 'S01']
target_spec_list = ['R', 'S0', 'S1', 'S01']
source_spec_list = ["R", "S0", "S1", "S01"]
target_spec_list = ["R", "S0", "S1", "S01"]
difference_dict = {}
for source_spec in source_spec_list:
for target_spec in target_spec_list:
legal_sharding_dims = []
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
source_shard_list = self._convert_str_to_shard_list(source_spec)
target_shard_list = self._convert_str_to_shard_list(target_spec)
@@ -77,14 +76,17 @@ class DimSpec:
difference = 0

# all_gather(source) -> target
elif len(source_shard_list
) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list:
elif (
len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list
):
difference = ALLGATHER_COST

# shard(source) -> target
elif len(source_shard_list) == len(
target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[
-1] not in source_shard_list:
elif (
len(source_shard_list) == len(target_shard_list) - 1
and source_shard_list == target_shard_list[:-1]
and target_shard_list[-1] not in source_shard_list
):
difference = SHARD_COST

# S1 -> S0 or S0 -> S1
@@ -115,7 +117,7 @@ class DimSpec:
self.difference_dict = difference_dict

def dim_diff(self, other):
'''
"""
The difference between two _DimSpec.

Argument:
@@ -131,13 +133,13 @@ class DimSpec:

Output:
5
'''
"""
difference = self.difference_dict[(str(self), str(other))]
return difference


class ShardingSpec:
'''
"""
Sharding spec describes how to shard a tensor with dim_size dimensions. The sharding sequence looks like
[R, R, S0, S1], which means

@@ -145,23 +147,27 @@ class ShardingSpec:
dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
and the value of the key describe which logical axis will be sharded in that dimension.
sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
'''
"""

def __init__(self,
dim_size: int,
dim_partition_dict: Dict[int, List[int]] = None,
sharding_sequence: List[DimSpec] = None):
def __init__(
self, dim_size: int, dim_partition_dict: Dict[int, List[int]] = None, sharding_sequence: List[DimSpec] = None
):
self.dims = dim_size
self.dim_partition_dict = dim_partition_dict
self.sharding_sequence = sharding_sequence
if self.sharding_sequence is None:
assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=self.dims,
dim_partition_dict=self.dim_partition_dict)
assert (
self.dim_partition_dict is not None
), f"dim_partition_dict should not be None, if sharding_sequence is NoneType object."
self.dim_partition_dict = merge_same_dim_mesh_list(
dim_size=self.dims, dim_partition_dict=self.dim_partition_dict
)
self.sharding_sequence = self.convert_dict_to_shard_sequence()

elif self.dim_partition_dict is None:
assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
assert (
self.sharding_sequence is not None
), f"sharding_sequence should not be None, if dim_partition_dict is NoneType object."
self.dim_partition_dict = self.convert_shard_sequence_to_dict()

self._sanity_check()
@@ -169,31 +175,32 @@ class ShardingSpec:
def _sanity_check(self):
if len(self.sharding_sequence) > self.dims:
raise ShardingOutOfIndexError(
f'sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}.')
f"sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}."
)

if list(self.dim_partition_dict.keys()) and max(list(self.dim_partition_dict.keys())) >= self.dims:
raise ShardingOutOfIndexError(
f'the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}.'
f"the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}."
)

def __repr__(self):
res_list = ["ShardingSpec:"]
res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
return ' '.join(res_list)
return " ".join(res_list)

def convert_dict_to_shard_sequence(self):
'''
"""
Convert dim_partition_dict into list of DimSpec, and assign it to sharding_sequence.
'''
"""
sharding_sequence = [DimSpec([])] * self.dims
for dim, shard_list in self.dim_partition_dict.items():
sharding_sequence[dim] = DimSpec(shard_list)
return sharding_sequence

def convert_shard_sequence_to_dict(self):
'''
"""
Convert sharding_sequence into dim_partition_dict.
'''
"""
new_dim_partition_dict = {}
for index, dim_spec in enumerate(self.sharding_sequence):
if not dim_spec.is_replica:
@@ -203,7 +210,7 @@ class ShardingSpec:
return new_dim_partition_dict

def spec_diff(self, other):
'''
"""
This function is a naive version of difference computation. It just simply accumulates difference every dimension between the
pair of sharding sequence.

@@ -228,9 +235,10 @@ class ShardingSpec:

Return:
difference(int): Difference between two ShardingSpec.
'''
"""
assert len(self.sharding_sequence) == len(
other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.'
other.sharding_sequence
), f"Cannot compare difference for two sharding specs with different length."
difference = 0
for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence):
difference += orig_dim_spec.dim_diff(other_dim_spec)
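ShardingSpec and spec_diff above are the vocabulary the layout converter's greedy search is built on. A hedged usage sketch follows; the import path is assumed from the hunk context, and it relies on the per-dimension difference table being built when DimSpec is constructed.

from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec  # path assumed

# a 3-D tensor: dim 0 sharded over logical mesh axis 0 vs. fully replicated
sharded = ShardingSpec(dim_size=3, dim_partition_dict={0: [0]})
replicated = ShardingSpec(dim_size=3, dim_partition_dict={})

print(sharded)     # shard_sequence: S0,R,R
print(replicated)  # shard_sequence: R,R,R

# S0 -> R on dim 0 is an all-gather, so the accumulated difference should be
# ALLGATHER_COST (20); the other two dimensions contribute 0.
print(sharded.spec_diff(replicated))
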

@@ -7,7 +7,7 @@ from colossalai.tensor.d_tensor.layout import Layout


def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = False) -> Dict[str, float]:
'''
"""
This method is used to compute the communication cost for a given layout and comm_spec.

For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
@@ -18,7 +18,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals
comm_spec: the comm_spec to instruct the communication operation.
forward_only: if it is True, we will just count the forward communication cost.
If it is False, we will count both forward and backward communication cost.
'''
"""
comm_size = reduce(operator.mul, layout.get_sharded_shape_per_device(), 1)
device_mesh = layout.device_mesh
comm_pattern = comm_spec.comm_pattern