Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[shardformer] integrated linear 1D with dtensor (#3996)
* [shardformer] integrated linear 1D with dtensor
* polish code
colossalai/tensor/d_tensor/api.py (new file, 44 additions)
@@ -0,0 +1,44 @@
+from typing import Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from colossalai.device.device_mesh import DeviceMesh
+
+from .d_tensor import DTensor
+from .sharding_spec import ShardingSpec
+
+
+def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
+    """
+    Shard the first dim of the given tensor
+    """
+    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
+    if group_or_device_mesh is None:
+        group_or_device_mesh = dist.GroupMember.WORLD
+
+    if isinstance(group_or_device_mesh, ProcessGroup):
+        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
+    else:
+        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
+        device_mesh = group_or_device_mesh
+    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
+    return DTensor(tensor, device_mesh, sharding_spec)
+
+
+def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
+    """
+    Shard the last dim of the given tensor
+    """
+    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
+    if group_or_device_mesh is None:
+        group_or_device_mesh = dist.GroupMember.WORLD
+
+    if isinstance(group_or_device_mesh, ProcessGroup):
+        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
+    else:
+        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for column-wise sharding.'
+        device_mesh = group_or_device_mesh
+    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
+    return DTensor(tensor, device_mesh, sharding_spec)
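A minimal usage sketch of the two helpers added above, assuming torch.distributed has already been initialized (e.g. with two ranks); the weight shape is hypothetical:

import torch
import torch.distributed as dist

from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise

# Assumes dist.init_process_group(...) has already been called, so the default
# WORLD process group is picked up when no group or DeviceMesh is passed in.
weight = torch.randn(8, 4)

row_sharded = shard_rowwise(weight)    # splits dim 0 across the 1D device mesh
col_sharded = shard_colwise(weight)    # splits the last dim across the 1D device mesh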
@@ -34,7 +34,7 @@ class Layout:
     def get_sharded_shape_per_device(self):
         sharded_shape = list(self.entire_shape)
         for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
-            mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
+            mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
             shard_partitions = reduce(operator.mul, mesh_list, 1)
             assert sharded_shape[
                 dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
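To make the per-device shape computation concrete, a minimal standalone sketch with illustrative values; it assumes each sharded dimension is divided by the product of the mesh axes it is sharded over:

import operator
from functools import reduce

entire_shape = [8, 6]                  # hypothetical global tensor shape
device_mesh_shape = (2, 2)             # hypothetical 2x2 device mesh
dim_partition_dict = {0: [0]}          # tensor dim 0 sharded along mesh axis 0

sharded_shape = list(entire_shape)
for dim, shard_list in dim_partition_dict.items():
    mesh_list = [device_mesh_shape[mesh_dim] for mesh_dim in shard_list]
    shard_partitions = reduce(operator.mul, mesh_list, 1)
    assert sharded_shape[dim] % shard_partitions == 0
    sharded_shape[dim] //= shard_partitions

print(sharded_shape)                   # [4, 6] per device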
@@ -45,14 +45,15 @@ class Layout:
         sharding_spec = self.sharding_spec
 
         # make sure all axes in logical device mesh only be used once
-        dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
-        for dim, shard_list in sharding_spec.dim_partition_dict.items():
-            for element in shard_list:
-                if element in dim_check_list:
-                    dim_check_list.remove(element)
-                else:
-                    raise DuplicatedShardingDimensionError(
-                        f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
+        if self.device_mesh.logical_mesh_id is not None:
+            dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
+            for dim, shard_list in sharding_spec.dim_partition_dict.items():
+                for element in shard_list:
+                    if element in dim_check_list:
+                        dim_check_list.remove(element)
+                    else:
+                        raise DuplicatedShardingDimensionError(
+                            f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
 
         # make sure that the sharding for a dimension is divisible by the number of devices
         for dim, shard_list in sharding_spec.dim_partition_dict.items():
@@ -60,7 +61,7 @@ class Layout:
             num_devices = 1
 
             for element in shard_list:
-                num_devices *= self.device_mesh.mesh_shape[element]
+                num_devices *= self.device_mesh.shape[element]
 
             if tensor_dim_size % num_devices != 0:
                 raise ShardingNotDivisibleError(
@@ -304,7 +304,7 @@ class LayoutConverter(metaclass=SingletonMeta):
         process_groups_dict = source_layout.device_mesh.process_groups_dict
 
         # legal sharding dims means the mesh_id is still available to use.
-        legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
+        legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
         for dim, shard_list in source_spec.dim_partition_dict.items():
             for element in shard_list:
                 legal_sharding_dims.remove(element)
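As a rough illustration of the bookkeeping above (hypothetical values, independent of LayoutConverter): every axis of the device mesh starts out available, and axes already used by the source sharding spec are removed.

device_mesh_shape = (2, 2)             # hypothetical 2x2 device mesh
dim_partition_dict = {0: [0]}          # tensor dim 0 already sharded on mesh axis 0

legal_sharding_dims = [i for i in range(len(device_mesh_shape))]
for dim, shard_list in dim_partition_dict.items():
    for element in shard_list:
        legal_sharding_dims.remove(element)

print(legal_sharding_dims)             # [1]: only mesh axis 1 is still free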