Mirror of https://github.com/hpcaitech/ColossalAI.git
[dtensor] updated api and doc (#3845)
@@ -11,28 +11,32 @@ from .sharding_spec import ShardingSpec
 class Layout:
-    """Layout of a tensor.
+    """
+    Layout of a tensor refers to the tensor placement on the device mesh and how the tensor is sharded over the devices.
 
-    Attributes:
-        device_mesh: the device mesh to store the tensor distributed.
-        device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
-        sharding_spec: the sharding specification to describe how the tensor is sharded.
-        entire_shape: the entire shape of the global tensor.
+    Args:
+        device_mesh (`DeviceMesh`): the device mesh to store the tensor distributed.
+        sharding_spec (`ShardingSpec`): the sharding specification to describe how the tensor is sharded.
+        global_shape (`torch.Size`): the entire shape of the global tensor.
     """
 
-    def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
-                 entire_shape: torch.Size):
+    def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size):
         self.device_mesh = device_mesh
-        self.device_type = device_type
         self.sharding_spec = sharding_spec
-        self.entire_shape = entire_shape
+        self.global_shape = global_shape
         self._sanity_check()
 
     def __hash__(self) -> int:
         return hash(f'{self.sharding_spec}')
 
-    def get_sharded_shape_per_device(self):
-        sharded_shape = list(self.entire_shape)
+    def get_sharded_shape_per_device(self) -> torch.Size:
+        """
+        Compute the shape of the sharded tensor on each device.
+
+        Returns:
+            `torch.Size`: the shape of the sharded tensor on each device.
+        """
+        sharded_shape = list(self.global_shape)
         for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
             mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
             shard_partitions = reduce(operator.mul, mesh_list, 1)
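
The hunk above renames `entire_shape` to `global_shape` and documents `get_sharded_shape_per_device`, which divides each sharded tensor dimension by the product of the device-mesh sizes it is sharded over. Below is a minimal, self-contained sketch of that arithmetic only: `_FakeMesh`, `_FakeSpec`, and `sharded_shape_per_device` are hypothetical stand-ins, not the real `DeviceMesh`/`ShardingSpec` API, and the final integer division is inferred from the method's documented purpose because the hunk is cut off after the `reduce` call.

import operator
from functools import reduce

import torch


class _FakeMesh:
    def __init__(self, mesh_shape):
        self.mesh_shape = mesh_shape          # e.g. (2, 4): a 2 x 4 device grid


class _FakeSpec:
    def __init__(self, dim_partition_dict):
        # tensor dim -> list of mesh dims it is sharded over
        self.dim_partition_dict = dim_partition_dict


def sharded_shape_per_device(global_shape, mesh, spec):
    # Mirror of the shape math in the diff: each sharded dim is divided by the
    # product of the mesh sizes along the mesh dims it is sharded over.
    sharded_shape = list(global_shape)
    for dim, shard_list in spec.dim_partition_dict.items():
        mesh_list = [mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
        shard_partitions = reduce(operator.mul, mesh_list, 1)
        sharded_shape[dim] //= shard_partitions    # inferred continuation, not shown in the hunk
    return torch.Size(sharded_shape)


# Shard dim 0 over both mesh axes (2 * 4 = 8 shards), keep dim 1 replicated:
# a (64, 32) global tensor becomes (8, 32) on each device.
print(sharded_shape_per_device(torch.Size((64, 32)), _FakeMesh((2, 4)), _FakeSpec({0: [0, 1]})))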
@@ -56,7 +60,7 @@ class Layout:
 
         # make sure that the sharding for a dimension is divisible by the number of devices
         for dim, shard_list in sharding_spec.dim_partition_dict.items():
-            tensor_dim_size = self.entire_shape[dim]
+            tensor_dim_size = self.global_shape[dim]
            num_devices = 1
 
            for element in shard_list:
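
The sanity-check hunk above enforces that any dimension listed in `dim_partition_dict` is evenly divisible by the number of devices it is split across. A hedged sketch of that rule follows; `check_divisibility` and the example shapes are illustrative only, and the multiplication inside the loop is an assumption, since the hunk is truncated after `for element in shard_list:`.

def check_divisibility(global_shape, mesh_shape, dim_partition_dict):
    # For each sharded tensor dim, the total device count along its shard mesh
    # dims must divide the dim size evenly.
    for dim, shard_list in dim_partition_dict.items():
        tensor_dim_size = global_shape[dim]
        num_devices = 1
        for mesh_dim in shard_list:
            num_devices *= mesh_shape[mesh_dim]    # assumed accumulation step
        if tensor_dim_size % num_devices != 0:
            raise ValueError(
                f"dim {dim} of size {tensor_dim_size} cannot be split evenly "
                f"across {num_devices} devices")


check_divisibility((64, 32), (2, 4), {0: [0, 1]})      # ok: 64 % 8 == 0
# check_divisibility((30, 32), (2, 4), {0: [0, 1]})    # would raise: 30 % 8 != 0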