[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -21,8 +21,23 @@ from .layout import Layout
from .sharding_spec import ShardingSpec
__all__ = [
'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
'redistribute', 'get_layout', 'is_customized_distributed_tensor', 'distribute_tensor_with_customization',
'to_global_for_customized_distributed_tensor', 'customized_distributed_tensor_to_param', 'Layout', 'ShardingSpec'
"is_distributed_tensor",
"distribute_tensor",
"to_global",
"is_sharded",
"shard_rowwise",
"shard_colwise",
"sharded_tensor_to_param",
"compute_global_numel",
"get_sharding_spec",
"get_global_shape",
"get_device_mesh",
"redistribute",
"get_layout",
"is_customized_distributed_tensor",
"distribute_tensor_with_customization",
"to_global_for_customized_distributed_tensor",
"customized_distributed_tensor_to_param",
"Layout",
"ShardingSpec",
]

View File

@@ -44,7 +44,7 @@ def is_sharded(dtensor: torch.Tensor) -> bool:
Returns:
bool: True if the tensor is sharded, False otherwise.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return list(dtensor.shape) == list(dtensor.dist_layout.global_shape)
@@ -77,8 +77,10 @@ def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
return dtensor
def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
'''
def _construct_default_sharding_spec(
tensor: torch.Tensor,
) -> ShardingSpec:
"""
Construct the default sharding specification for the tensor.
Args:
@@ -86,14 +88,14 @@ def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
Returns:
A `ShardingSpec` object without any sharding specified.
'''
"""
return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
def _apply_layout(tensor, layout):
'''
"""
Apply the layout to the local tensor during initializing process.
'''
"""
# layout converter requires a source and target layout
# we construct the source layout for an unsharded tensor
# and use self.dist_layout as the target layout for the sharded tensor
@@ -115,7 +117,7 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
Returns:
torch.Tensor: The distributed tensor.
"""
assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)
# shard tensor
@@ -128,7 +130,7 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
'''
"""
Convert the layout of the tensor from source_spec to target_spec.
This will update the `local_tensor` and `dist_layout` in place.
@@ -136,13 +138,13 @@ def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec:
dtensor (torch.Tensor): the distributed tensor to be converted.
device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
target_layout (Layout): the target layout specification.
'''
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
"""
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
global_shape = get_global_shape(dtensor)
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
resharded_tensor = layout_converter.apply(tensor=dtensor,
source_layout=dtensor.dist_layout,
target_layout=target_layout)
resharded_tensor = layout_converter.apply(
tensor=dtensor, source_layout=dtensor.dist_layout, target_layout=target_layout
)
return resharded_tensor
@@ -157,7 +159,7 @@ def to_global(dtensor: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: the global tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
layout_converter = LayoutConverter()
global_sharding_spec = ShardingSpec(dtensor.dim(), {})
@@ -193,7 +195,7 @@ def shard_rowwise(
if isinstance(group_or_device_mesh, ProcessGroup):
device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
else:
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
@@ -222,7 +224,7 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
if isinstance(group_or_device_mesh, ProcessGroup):
device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
else:
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
@@ -230,7 +232,7 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
# make it distributed as well
@@ -241,7 +243,7 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
param.data = dtensor
# make it distributed as well
param.dist_layout = dtensor.dist_layout
@@ -258,7 +260,7 @@ def compute_global_numel(dtensor: torch.Tensor) -> int:
Returns:
int: The global number of elements in the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
numel = reduce(operator.mul, dtensor.dist_layout.global_shape)
return numel
@@ -274,7 +276,7 @@ def get_layout(dtensor: torch.Tensor) -> Layout:
Layout: The layout of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout
@@ -288,7 +290,7 @@ def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
Returns:
torch.Size: The global shape of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout.global_shape
@@ -302,7 +304,7 @@ def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
Returns:
DeviceMesh: The device mesh of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout.device_mesh
@@ -316,7 +318,7 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
Returns:
ShardingSpec: The sharding spec of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout.sharding_spec
@@ -335,7 +337,7 @@ def is_customized_distributed_tensor(tensor: torch.Tensor):
Returns:
bool: Whether the given tensor is a customized distributed tensor.
"""
return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn')
return hasattr(tensor, "shard_fn") and hasattr(tensor, "gather_fn")
def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
@@ -402,9 +404,9 @@ def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_
Returns:
torch.Tensor: The distributed tensor.
"""
assert callable(shard_fn), 'The shard_fn must be callable.'
assert callable(gather_fn), 'The gather_fn must be callable.'
assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
assert callable(shard_fn), "The shard_fn must be callable."
assert callable(gather_fn), "The gather_fn must be callable."
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
sharded_tensor = shard_fn(tensor)
@@ -428,7 +430,7 @@ def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.
Returns:
torch.Tensor: The global tensor.
"""
assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
return dtensor.gather_fn(dtensor)
@@ -436,7 +438,7 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad:
"""
Convert the given customized distributed tensor to a parameter.
"""
assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
@@ -451,7 +453,7 @@ def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param
"""
Convert the given customized distributed tensor to an existing parameter.
"""
assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
param.data = dtensor.data
param.shard_fn = dtensor.shard_fn
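A minimal usage sketch of the d_tensor helpers reformatted above, for orientation only (the import path and the 1D process group are assumptions; this is not part of the diff and needs an initialized torch.distributed environment, e.g. 2 ranks):

# Illustrative sketch, not part of the commit diff. Assumes torch.distributed is
# initialized (e.g. 2 ranks) and that these helpers are exposed by
# colossalai.tensor.d_tensor (import path assumed).
import torch
import torch.distributed as dist
from colossalai.tensor.d_tensor import (
    compute_global_numel,
    get_sharding_spec,
    is_distributed_tensor,
    shard_rowwise,
    to_global,
)

pg = dist.new_group(ranks=list(range(dist.get_world_size())))  # any 1D ProcessGroup works
tensor = torch.randn(8, 16)
dtensor = shard_rowwise(tensor, pg)        # shard dim 0 across the group
assert is_distributed_tensor(dtensor)
print(get_sharding_spec(dtensor))          # e.g. shard_sequence: S0,R
print(compute_global_numel(dtensor))       # 8 * 16 = 128
full = to_global(dtensor)                  # gather back to the global tensor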

View File

@@ -6,22 +6,22 @@ import torch.distributed as dist
from torch.distributed import ReduceOp
__all__ = [
'CollectiveCommPattern',
'CommSpec',
"CollectiveCommPattern",
"CommSpec",
]
class CollectiveCommPattern(Enum):
GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd"
ALL2ALL_FWD_ALL2ALL_BWD = "all2all_fwd_all2all_bwd"
SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd"
ALLREDUCE_FWD_IDENTITY_BWD = "all_reduce_fwd_identity_bwd"
IDENTITY_FWD_ALLREDUCE_BWD = "identity_fwd_all_reduce_bwd"
MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd"
class CommSpec:
'''
"""
Communication spec is used to record the communication action. It converts the communication spec
to a real action which will be used at runtime. It contains comm_pattern to determine the
communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
@@ -33,14 +33,16 @@ class CommSpec:
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
'''
"""
def __init__(self,
comm_pattern: CollectiveCommPattern,
process_group_dict: Dict,
gather_dim: int = None,
shard_dim: int = None,
logical_process_axis: int = None):
def __init__(
self,
comm_pattern: CollectiveCommPattern,
process_group_dict: Dict,
gather_dim: int = None,
shard_dim: int = None,
logical_process_axis: int = None,
):
self.comm_pattern = comm_pattern
self.gather_dim = gather_dim
self.shard_dim = shard_dim
@@ -71,16 +73,16 @@ class CommSpec:
res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
return ''.join(res_list)
return "".join(res_list)
def covert_spec_to_action(self, tensor):
'''
"""
Convert CommSpec into a runtime action, implementing the real collective communication on the target tensor.
The collective communication action is directed by the CommSpec.
Argument:
tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
'''
"""
if self.comm_pattern in pattern_to_func_dict:
tensor = pattern_to_func_dict[self.comm_pattern](tensor, self)
else:
@@ -89,9 +91,9 @@ class CommSpec:
def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
'''
"""
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
"""
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
@@ -103,9 +105,9 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
def _split(tensor: torch.Tensor, comm_spec: CommSpec):
'''
"""
Implement shard operation on device mesh based on information provided by comm_spec.
'''
"""
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
@@ -115,9 +117,9 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec):
def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
'''
"""
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
"""
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
new_shape = list(tensor.shape)
@@ -134,9 +136,9 @@ def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
'''
"""
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
"""
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
if not tensor.is_contiguous():
tensor = tensor.contiguous()
@@ -256,11 +258,13 @@ class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, comm_spec):
output = _all_to_all(input_, comm_spec)
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
process_group_dict=comm_spec.process_group_dict,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis)
comm_spec_for_backward = CommSpec(
comm_pattern=comm_spec.comm_pattern,
process_group_dict=comm_spec.process_group_dict,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis,
)
ctx.comm_spec = comm_spec_for_backward
return output
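For context on the CommSpec reformatting above, a hedged sketch of building and executing a spec (the module path and process-group wiring are assumptions; the constructor arguments follow the signature shown in this file):

# Illustrative sketch, not part of the commit diff. Assumes torch.distributed is
# initialized and that the module path below is correct.
import torch
import torch.distributed as dist
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec

pg = dist.new_group(ranks=list(range(dist.get_world_size())))  # 1D logical mesh axis 0
comm_spec = CommSpec(
    CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
    process_group_dict={0: pg},    # logical_process_axis -> ProcessGroup
    gather_dim=0,
    shard_dim=0,                   # reused for the backward split
    logical_process_axis=0,
)
local_shard = torch.randn(4, 16)   # this rank's shard; device must match the backend
gathered = comm_spec.covert_spec_to_action(local_shard)  # forward all-gather along dim 0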

View File

@@ -25,15 +25,16 @@ class Layout:
self._sanity_check()
def __hash__(self) -> int:
return hash(f'{self.sharding_spec}')
return hash(f"{self.sharding_spec}")
def get_sharded_shape_per_device(self):
sharded_shape = list(self.global_shape)
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
assert sharded_shape[
dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
assert (
sharded_shape[dim] % shard_partitions == 0
), f"Cannot shard dimension {dim} into {shard_partitions} partitions."
sharded_shape[dim] //= shard_partitions
return torch.Size(sharded_shape)
@@ -49,7 +50,8 @@ class Layout:
dim_check_list.remove(element)
else:
raise DuplicatedShardingDimensionError(
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}."
)
# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in sharding_spec.dim_partition_dict.items():
@@ -61,5 +63,5 @@ class Layout:
if tensor_dim_size % num_devices != 0:
raise ShardingNotDivisibleError(
f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.'
f"The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices."
)
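To make the divisibility check above concrete, a small hedged example of get_sharded_shape_per_device (import paths are assumptions; a 2-rank 1D mesh is assumed):

# Illustrative sketch, not part of the commit diff.
import torch
import torch.distributed as dist
from colossalai.device.device_mesh import DeviceMesh           # path assumed
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

pg = dist.new_group(ranks=list(range(dist.get_world_size())))  # assume 2 ranks
device_mesh = DeviceMesh.from_process_group(pg)                # 1D mesh of size 2
spec = ShardingSpec(dim_size=2, dim_partition_dict={0: [0]})   # shard dim 0 on mesh axis 0
layout = Layout(device_mesh=device_mesh, sharding_spec=spec, global_shape=torch.Size((8, 12)))
print(layout.get_sharded_shape_per_device())                   # torch.Size([4, 12]) on 2 devices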

View File

@@ -14,7 +14,7 @@ from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator,
from .sharding_spec import ShardingSpec
from .utils import get_comm_cost
__all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_options']
__all__ = ["LayoutConverter", "LayoutConverterOptions", "set_layout_converting_options"]
@dataclass
@@ -22,8 +22,8 @@ class LayoutConverterOptions:
"""
LayoutConverterOptions is a dataclass which specifies the preferences for layout converting.
"""
# TODO: layout converter option is not implemented yet
pass
def set_layout_converting_options(options: LayoutConverterOptions):
@@ -63,7 +63,7 @@ class LayoutConverter(metaclass=SingletonMeta):
self._forward_only = value
def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
'''
"""
Get all valid layouts from source_layout with single all-gather operation.
For the all-gather operation, we just care about the S dimension.
@@ -96,7 +96,7 @@ class LayoutConverter(metaclass=SingletonMeta):
Output:
[R, S1, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:0, shard_dim:0, logical_process_axis:0)
[S0, R, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1)
'''
"""
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
source_spec = source_layout.sharding_spec
@@ -125,16 +125,19 @@ class LayoutConverter(metaclass=SingletonMeta):
comm_pattern,
process_group_dict=process_group_dict,
gather_dim=gather_dim,
# shard_dim will be used during backward
# shard_dim will be used during backward
shard_dim=gather_dim,
logical_process_axis=logical_process_axis)
logical_process_axis=logical_process_axis,
)
# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
global_shape=source_layout.global_shape)
new_layout = Layout(
device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
global_shape=source_layout.global_shape,
)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
@@ -142,7 +145,7 @@ class LayoutConverter(metaclass=SingletonMeta):
return valid_spec_dict
def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
'''
"""
Get all valid layouts from source_layout with single all-to-all operation.
For the all-to-all operation, we just care about the pairs containing S dimension.
@@ -176,7 +179,7 @@ class LayoutConverter(metaclass=SingletonMeta):
[S01, R, R]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:0, logical_process_axis: 1)
[R, S1, S0]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:0, shard_dim:2, logical_process_axis: 0)
[S0, R, S1]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:2, logical_process_axis: 1)
'''
"""
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
@@ -224,11 +227,13 @@ class LayoutConverter(metaclass=SingletonMeta):
gather_dim = b_index
shard_dim = f_index
logical_process_axis = b_target_pair[1][-1]
comm_spec = CommSpec(comm_pattern,
process_group_dict=process_group_dict,
gather_dim=gather_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
comm_spec = CommSpec(
comm_pattern,
process_group_dict=process_group_dict,
gather_dim=gather_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis,
)
new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
@@ -246,9 +251,11 @@ class LayoutConverter(metaclass=SingletonMeta):
# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
global_shape=source_layout.global_shape)
new_layout = Layout(
device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
global_shape=source_layout.global_shape,
)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
@@ -256,7 +263,7 @@ class LayoutConverter(metaclass=SingletonMeta):
return valid_spec_dict
def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
'''
"""
Get all valid layouts from source_layout with single shard operation.
For the sharding operation, we just care about legal sharding dimensions.
@@ -291,7 +298,7 @@ class LayoutConverter(metaclass=SingletonMeta):
[S01, R, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:0, shard_dim:0, logical_process_axis:1)
[S0, S1, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1)
[S0, R, S1]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:2, shard_dim:2, logical_process_axis:1)
'''
"""
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
source_spec = source_layout.sharding_spec
@@ -326,26 +333,31 @@ class LayoutConverter(metaclass=SingletonMeta):
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
shard_dim = index
logical_process_axis = shard_list[-1]
comm_spec = CommSpec(comm_pattern,
process_group_dict=process_group_dict,
gather_dim=shard_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
comm_spec = CommSpec(
comm_pattern,
process_group_dict=process_group_dict,
gather_dim=shard_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis,
)
# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(dim_size=source_spec.dims,
dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
global_shape=source_layout.global_shape)
new_sharding_spec = ShardingSpec(
dim_size=source_spec.dims, dim_partition_dict=new_dim_partition_dict
)
new_layout = Layout(
device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
global_shape=source_layout.global_shape,
)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
return valid_spec_dict
def get_all_one_step_transform_spec(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
'''
"""
Get all valid layouts from source_layout with one step transform.
Note:
@@ -358,16 +370,17 @@ class LayoutConverter(metaclass=SingletonMeta):
Return:
valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with one step transform.
'''
"""
valid_spec_dict = {}
valid_spec_dict.update(self.all_gather_transform_layouts(source_layout))
valid_spec_dict.update(self.all_to_all_transform_layout(source_layout))
valid_spec_dict.update(self.shard_transform_layout(source_layout))
return valid_spec_dict
def layout_converting(self, source_layout: Layout,
target_layout: Layout) -> Tuple[List[Layout], List[CommSpec], float]:
'''
def layout_converting(
self, source_layout: Layout, target_layout: Layout
) -> Tuple[List[Layout], List[CommSpec], float]:
"""
This method will find a path to transform source_layout to target_layout with
a greedy algorithm.
The basic idea is:
@@ -419,7 +432,7 @@ class LayoutConverter(metaclass=SingletonMeta):
output:
[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]
'''
"""
source_spec = source_layout.sharding_spec
target_spec = target_layout.sharding_spec
MAX_TRANSFORM_STEPS = 20
@@ -470,11 +483,11 @@ class LayoutConverter(metaclass=SingletonMeta):
raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")
def get_total_comm_cost(self, source_layout: Layout, target_layout: Layout) -> Dict[str, float]:
'''
"""
Get the total communication cost of the layout converting process.
'''
"""
transform_path, comm_action_sequence = self.layout_converting(source_layout, target_layout)
total_cost = {'forward': 0.0, 'backward': 0.0, 'total': 0.0}
total_cost = {"forward": 0.0, "backward": 0.0, "total": 0.0}
for layout, comm_spec in zip(transform_path, comm_action_sequence):
cost_dict = get_comm_cost(layout, comm_spec, self.forward_only)
for key in total_cost:
@@ -482,7 +495,7 @@ class LayoutConverter(metaclass=SingletonMeta):
return total_cost
def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layout) -> torch.Tensor:
'''
"""
Apply target_layout to tensor with source layout, the transform path is generated by the
layout_converting method.
@@ -542,7 +555,7 @@ class LayoutConverter(metaclass=SingletonMeta):
[1.],
[3.],
[3.]])
'''
"""
_, comm_action_sequence = self.layout_converting(source_layout, target_layout)
for comm_spec in comm_action_sequence:
tensor = comm_spec.covert_spec_to_action(tensor)
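As a usage note for the LayoutConverter changes above, a hedged end-to-end sketch of planning and applying a layout conversion (import paths and the mesh setup are assumptions; the method names and signatures follow this file):

# Illustrative sketch, not part of the commit diff.
import torch
import torch.distributed as dist
from colossalai.device.device_mesh import DeviceMesh                     # path assumed
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter  # path assumed
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

pg = dist.new_group(ranks=list(range(dist.get_world_size())))            # assume 2 ranks
device_mesh = DeviceMesh.from_process_group(pg)
global_shape = torch.Size((8, 12))
src = Layout(device_mesh=device_mesh, sharding_spec=ShardingSpec(2, {0: [0]}), global_shape=global_shape)
dst = Layout(device_mesh=device_mesh, sharding_spec=ShardingSpec(2, {1: [0]}), global_shape=global_shape)

converter = LayoutConverter()                               # SingletonMeta: one shared instance
path, comm_actions = converter.layout_converting(src, dst)  # greedy transform plan
local = torch.randn(4, 12)                                  # this rank's shard under src
resharded = converter.apply(tensor=local, source_layout=src, target_layout=dst)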

View File

@@ -4,16 +4,16 @@ from typing import Dict, List
from ..utils import merge_same_dim_mesh_list
from .misc import ShardingOutOfIndexError
__all__ = ['DimSpec', 'ShardingException', 'ShardingSpec']
__all__ = ["DimSpec", "ShardingException", "ShardingSpec"]
ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
NAN = 'nan'
NAN = "nan"
class DimSpec:
'''
"""
Sharding spec for a single dimension of the sharded tensor. It describes the sharding dimension of the
logical device mesh and gives a method to compute the difference between two dim specs.
This class is used internally in ShardingSpec.
@@ -21,7 +21,7 @@ class DimSpec:
Argument:
shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
Otherwise, the element in shard_list means the data will be sharded in that dimension.
'''
"""
def __init__(self, shard_list):
self.is_replica = len(shard_list) == 0
@@ -33,41 +33,40 @@ class DimSpec:
def __repr__(self):
if self.is_replica:
return 'R'
target = 'S'
return "R"
target = "S"
for dim in self.shard_list:
target += str(dim)
return target
def _convert_str_to_shard_list(self, str_spec):
'''
"""
Convert str_spec into shard_list.
Argument:
str_spec(str): dim spec in str type.
'''
"""
if str_spec == 'R':
if str_spec == "R":
return []
if str_spec == 'S0':
if str_spec == "S0":
return [0]
if str_spec == 'S1':
if str_spec == "S1":
return [1]
if str_spec == 'S01':
if str_spec == "S01":
return [0, 1]
def build_difference_2d_dict(self):
'''
"""
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
'''
"""
source_spec_list = ['R', 'S0', 'S1', 'S01']
target_spec_list = ['R', 'S0', 'S1', 'S01']
source_spec_list = ["R", "S0", "S1", "S01"]
target_spec_list = ["R", "S0", "S1", "S01"]
difference_dict = {}
for source_spec in source_spec_list:
for target_spec in target_spec_list:
legal_sharding_dims = []
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
source_shard_list = self._convert_str_to_shard_list(source_spec)
target_shard_list = self._convert_str_to_shard_list(target_spec)
@@ -77,14 +76,17 @@ class DimSpec:
difference = 0
# all_gather(source) -> target
elif len(source_shard_list
) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list:
elif (
len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list
):
difference = ALLGATHER_COST
# shard(source) -> target
elif len(source_shard_list) == len(
target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[
-1] not in source_shard_list:
elif (
len(source_shard_list) == len(target_shard_list) - 1
and source_shard_list == target_shard_list[:-1]
and target_shard_list[-1] not in source_shard_list
):
difference = SHARD_COST
# S1 -> S0 or S0 -> S1
@@ -115,7 +117,7 @@ class DimSpec:
self.difference_dict = difference_dict
def dim_diff(self, other):
'''
"""
The difference between two _DimSpec.
Argument:
@@ -131,13 +133,13 @@ class DimSpec:
Output:
5
'''
"""
difference = self.difference_dict[(str(self), str(other))]
return difference
class ShardingSpec:
'''
"""
Sharding spec describes how to shard a tensor with dim_size dimensions. The sharding sequence looks like
[R, R, S0, S1], which means
@@ -145,23 +147,27 @@ class ShardingSpec:
dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
and the value of the key describe which logical axis will be sharded in that dimension.
sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
'''
"""
def __init__(self,
dim_size: int,
dim_partition_dict: Dict[int, List[int]] = None,
sharding_sequence: List[DimSpec] = None):
def __init__(
self, dim_size: int, dim_partition_dict: Dict[int, List[int]] = None, sharding_sequence: List[DimSpec] = None
):
self.dims = dim_size
self.dim_partition_dict = dim_partition_dict
self.sharding_sequence = sharding_sequence
if self.sharding_sequence is None:
assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=self.dims,
dim_partition_dict=self.dim_partition_dict)
assert (
self.dim_partition_dict is not None
), f"dim_partition_dict should not be None, if sharding_sequence is NoneType object."
self.dim_partition_dict = merge_same_dim_mesh_list(
dim_size=self.dims, dim_partition_dict=self.dim_partition_dict
)
self.sharding_sequence = self.convert_dict_to_shard_sequence()
elif self.dim_partition_dict is None:
assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
assert (
self.sharding_sequence is not None
), f"sharding_sequence should not be None, if dim_partition_dict is NoneType object."
self.dim_partition_dict = self.convert_shard_sequence_to_dict()
self._sanity_check()
@@ -169,31 +175,32 @@ class ShardingSpec:
def _sanity_check(self):
if len(self.sharding_sequence) > self.dims:
raise ShardingOutOfIndexError(
f'sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}.')
f"sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}."
)
if list(self.dim_partition_dict.keys()) and max(list(self.dim_partition_dict.keys())) >= self.dims:
raise ShardingOutOfIndexError(
f'the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}.'
f"the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}."
)
def __repr__(self):
res_list = ["ShardingSpec:"]
res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
return ' '.join(res_list)
return " ".join(res_list)
def convert_dict_to_shard_sequence(self):
'''
"""
Convert dim_partition_dict into list of DimSpec, and assign it to sharding_sequence.
'''
"""
sharding_sequence = [DimSpec([])] * self.dims
for dim, shard_list in self.dim_partition_dict.items():
sharding_sequence[dim] = DimSpec(shard_list)
return sharding_sequence
def convert_shard_sequence_to_dict(self):
'''
"""
Convert sharding_sequence into dim_partition_dict.
'''
"""
new_dim_partition_dict = {}
for index, dim_spec in enumerate(self.sharding_sequence):
if not dim_spec.is_replica:
@@ -203,7 +210,7 @@ class ShardingSpec:
return new_dim_partition_dict
def spec_diff(self, other):
'''
"""
This function is a naive version of difference computation. It simply accumulates the difference over every dimension between the
pair of sharding sequences.
@@ -228,9 +235,10 @@ class ShardingSpec:
Return:
difference(int): Difference between two ShardingSpec.
'''
"""
assert len(self.sharding_sequence) == len(
other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.'
other.sharding_sequence
), f"Cannot compare difference for two sharding specs with different length."
difference = 0
for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence):
difference += orig_dim_spec.dim_diff(other_dim_spec)
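To ground the ShardingSpec changes above, a hedged example of building specs and comparing them with spec_diff (import path assumed; the cost constants are the ones defined in this file):

# Illustrative sketch, not part of the commit diff.
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec  # path assumed

spec_a = ShardingSpec(dim_size=3, dim_partition_dict={0: [0], 1: [1]})  # [S0, S1, R]
spec_b = ShardingSpec(dim_size=3, dim_partition_dict={0: [0]})          # [S0, R, R]
print(spec_a)                   # ShardingSpec: shard_sequence: S0,S1,R
# dim 1 needs one all-gather (S1 -> R), so the accumulated DimSpec difference
# should equal ALLGATHER_COST (20) under the 2D cost table built above
print(spec_a.spec_diff(spec_b))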

View File

@@ -7,7 +7,7 @@ from colossalai.tensor.d_tensor.layout import Layout
def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = False) -> Dict[str, float]:
'''
"""
This method is used to compute the communication cost for a given layout and comm_spec.
For all_gather, all2all, and all_reduce operations, the formula provided in DeviceMesh with the alpha-beta model is used to
@@ -18,7 +18,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals
comm_spec: the comm_spec to instruct the communication operation.
forward_only: if it is True, we will just count the forward communication cost.
If it is False, we will count both forward and backward communication cost.
'''
"""
comm_size = reduce(operator.mul, layout.get_sharded_shape_per_device(), 1)
device_mesh = layout.device_mesh
comm_pattern = comm_spec.comm_pattern
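Finally, a short hedged note on how get_comm_cost is typically consumed; `layout` and `comm_spec` below are the hypothetical objects built in the earlier sketches, and the import path is assumed:

# Illustrative sketch, not part of the commit diff; `layout` and `comm_spec` are
# the hypothetical objects from the sketches above.
from colossalai.tensor.d_tensor.utils import get_comm_cost  # path assumed

cost = get_comm_cost(layout, comm_spec, forward_only=False)
print(cost["forward"], cost["backward"], cost["total"])     # alpha-beta model estimates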