[shardformer] integrated linear 1D with dtensor (#3996)

* [shardformer] integrated linear 1D with dtensor

* polish code
This commit is contained in:
Frank Lee
2023-06-15 18:03:38 +08:00
parent d3bc530849
commit 015af592f8
9 changed files with 707 additions and 408 deletions

View File

@@ -0,0 +1,44 @@
from typing import Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.device.device_mesh import DeviceMesh
from .d_tensor import DTensor
from .sharding_spec import ShardingSpec
def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
"""
Shard the first dim of the given tensor
"""
# if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
if group_or_device_mesh is None:
group_or_device_mesh = dist.GroupMember.WORLD
if isinstance(group_or_device_mesh, ProcessGroup):
device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
else:
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
return DTensor(tensor, device_mesh, sharding_spec)
def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
"""
Shard the first dim of the given tensor
"""
# if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
if group_or_device_mesh is None:
group_or_device_mesh = dist.GroupMember.WORLD
if isinstance(group_or_device_mesh, ProcessGroup):
device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
else:
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
return DTensor(tensor, device_mesh, sharding_spec)

View File

@@ -34,7 +34,7 @@ class Layout:
def get_sharded_shape_per_device(self):
sharded_shape = list(self.entire_shape)
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
assert sharded_shape[
dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
@@ -45,14 +45,15 @@ class Layout:
sharding_spec = self.sharding_spec
# make sure all axes in logical device mesh only be used once
dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
for dim, shard_list in sharding_spec.dim_partition_dict.items():
for element in shard_list:
if element in dim_check_list:
dim_check_list.remove(element)
else:
raise DuplicatedShardingDimensionError(
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
if self.device_mesh.logical_mesh_id is not None:
dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
for dim, shard_list in sharding_spec.dim_partition_dict.items():
for element in shard_list:
if element in dim_check_list:
dim_check_list.remove(element)
else:
raise DuplicatedShardingDimensionError(
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in sharding_spec.dim_partition_dict.items():
@@ -60,7 +61,7 @@ class Layout:
num_devices = 1
for element in shard_list:
num_devices *= self.device_mesh.mesh_shape[element]
num_devices *= self.device_mesh.shape[element]
if tensor_dim_size % num_devices != 0:
raise ShardingNotDivisibleError(

View File

@@ -304,7 +304,7 @@ class LayoutConverter(metaclass=SingletonMeta):
process_groups_dict = source_layout.device_mesh.process_groups_dict
# legal sharding dims means the mesh_id is still available to use.
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
for dim, shard_list in source_spec.dim_partition_dict.items():
for element in shard_list:
legal_sharding_dims.remove(element)