[test] fixed tests failed due to dtensor change (#4082)

* [test] fixed tests failed due to dtensor change

* polish code
This commit is contained in:
Frank Lee
2023-06-26 15:50:07 +08:00
parent 92f6791095
commit c4b1b65931
37 changed files with 233 additions and 289 deletions

View File

@@ -14,24 +14,21 @@ class Layout:
Attributes:
device_mesh: the device mesh to store the tensor distributed.
device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
sharding_spec: the sharding specification to describe how the tensor is sharded.
entire_shape: the entire shape of the global tensor.
global_shape: the entire shape of the global tensor.
"""
def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
entire_shape: torch.Size):
def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size):
self.device_mesh = device_mesh
self.device_type = device_type
self.sharding_spec = sharding_spec
self.entire_shape = entire_shape
self.global_shape = global_shape
self._sanity_check()
def __hash__(self) -> int:
return hash(f'{self.sharding_spec}')
def get_sharded_shape_per_device(self):
sharded_shape = list(self.entire_shape)
sharded_shape = list(self.global_shape)
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
@@ -56,7 +53,7 @@ class Layout:
# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in sharding_spec.dim_partition_dict.items():
tensor_dim_size = self.entire_shape[dim]
tensor_dim_size = self.global_shape[dim]
num_devices = 1
for element in shard_list: