mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[test] fixed tests failed due to dtensor change (#4082)
* [test] fixed tests failed due to dtensor change
* polish code
@@ -16,69 +16,66 @@ def _all_gather(tensor, comm_spec):
     '''
     Implement all gather operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            tensor_list = [
-                torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
-                for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
-            ]
-            # without this contiguous operation, the all gather may get some unexpected results.
-            tensor = tensor.contiguous()
-            dist.all_gather(tensor_list, tensor, group=process_group)
-            output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+
+    tensor_list = [
+        torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
+        for _ in range(comm_spec.device_mesh.shape[comm_spec.logical_process_axis])
+    ]
+    # without this contiguous operation, the all gather may get some unexpected results.
+    tensor = tensor.contiguous()
+    dist.all_gather(tensor_list, tensor, group=process_group)
+    output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
+    return output
 
 
 def _split(tensor, comm_spec):
     '''
     Implement shard operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, _ in process_groups_list:
-        if dist.get_rank() in rank_list:
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            start = length * rank_list.index(dist.get_rank())
-            output = torch.narrow(tensor, dim, start, length).contiguous()
-            return output
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
+    start = length * dist.get_rank(process_group)
+    output = torch.narrow(tensor, dim, start, length).contiguous()
+    return output
 
 
 def _all_to_all(tensor, comm_spec):
     '''
     Implement all to all operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            new_shape = list(tensor.shape)
-            new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
-            new_shape = torch.Size(new_shape)
-            output_tensor_list = [
-                torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
-            ]
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            input_tensor_list = [
-                torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
-            ]
-            group = process_group
-            dist.all_to_all(output_tensor_list, input_tensor_list, group)
-            output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+    world_size = dist.get_world_size(process_group)
+
+    new_shape = list(tensor.shape)
+    new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
+    new_shape = torch.Size(new_shape)
+    output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // world_size
+    input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
+    group = process_group
+    dist.all_to_all(output_tensor_list, input_tensor_list, group)
+    output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
+    return output
 
 
 def _all_reduce(tensor, comm_spec, async_op=False):
     '''
     Implement all reduce operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            if not tensor.is_contiguous():
-                tensor = tensor.contiguous()
-            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
-            return tensor
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+
+    if not tensor.is_contiguous():
+        tensor = tensor.contiguous()
+    dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
+    return tensor
 
 
 def _mix_gather(tensor, comm_spec):
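For readers following the dtensor change: the removed code looked up comm_spec.device_mesh.process_groups_dict[axis] and scanned (rank_list, process_group) pairs for the current rank, while the added code reads one group per mesh axis from get_process_group_for_all_axes(). The sketch below only illustrates that shape difference; the mock dict layouts are assumptions for illustration, not ColossalAI's actual DeviceMesh internals.

import torch.distributed as dist

# Old layout (assumed): axis -> [(rank_list, ProcessGroup), ...]; the caller had
# to find the pair containing the current rank before issuing the collective.
def old_lookup(process_groups_dict, axis):
    for rank_list, process_group in process_groups_dict[axis]:
        if dist.get_rank() in rank_list:
            return process_group

# New layout (assumed): get_process_group_for_all_axes() already maps each mesh
# axis to the group the calling rank belongs to, so it is a plain dict access.
def new_lookup(device_mesh, axis):
    return device_mesh.get_process_group_for_all_axes()[axis]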
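Independent of which process-group API is used, the tensor layout arithmetic in these collectives is unchanged: slices are cut along shard_dim with torch.narrow / torch.chunk and reassembled along gather_dim with torch.cat. Here is a minimal single-process sketch of that arithmetic (shapes only, no communication; world_size, shard_dim and gather_dim are illustrative values, not taken from the PR):

import torch

# Illustrative values standing in for one mesh axis of a comm spec.
world_size, shard_dim, gather_dim = 4, 0, 1

X = torch.arange(64, dtype=torch.float32).reshape(8, 8)   # the "global" tensor

# _all_gather: each rank holds one slice of X along gather_dim; concatenating
# the gathered slices along gather_dim rebuilds X.
gathered_in = list(torch.chunk(X, world_size, dim=gather_dim))
assert torch.equal(torch.cat(gathered_in, dim=gather_dim), X)

# _split: each rank keeps one contiguous slice along shard_dim via torch.narrow.
length = X.shape[shard_dim] // world_size
shards = [torch.narrow(X, shard_dim, length * r, length).contiguous() for r in range(world_size)]
assert torch.equal(torch.cat(shards, dim=shard_dim), X)

# _all_to_all: a tensor sharded along gather_dim ends up sharded along shard_dim.
# Each rank chunks its local slice along shard_dim, keeps the r-th chunk from
# every peer, and concatenates the received chunks along gather_dim.
inputs = list(torch.chunk(X, world_size, dim=gather_dim))
outputs = []
for r in range(world_size):
    received = [torch.chunk(inputs[s], world_size, dim=shard_dim)[r] for s in range(world_size)]
    outputs.append(torch.cat(received, dim=gather_dim))
assert torch.equal(torch.cat(outputs, dim=shard_dim), X)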
@@ -128,7 +125,7 @@ def _mix_gather(tensor, comm_spec):
     process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
     tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
     '''
-    total_slices = comm_spec.device_mesh.mesh_shape[0]
+    total_slices = comm_spec.device_mesh.shape[0]
     tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)]
     leading_group_dim = comm_spec.logical_process_axes[0]
     assert len(comm_spec.device_mesh.process_groups_dict) == 1
@@ -149,7 +146,7 @@ def _mix_gather(tensor, comm_spec):
     if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
         output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous()
     else:
-        mesh_shape = comm_spec.device_meshes.mesh_shape
+        mesh_shape = comm_spec.device_meshes.shape
         cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
         tmp_tensor_shape = list(tensor.shape)
         tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0]
@@ -181,9 +178,9 @@ def _mix_split(tensor, comm_spec):
     # [4, 5, 6, 7]]
     # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
     '''
-    mesh_shape = comm_spec.device_meshes.mesh_shape
+    mesh_shape = comm_spec.device_meshes.shape
     dim = comm_spec.gather_dim
-    total_slices = comm_spec.device_mesh.mesh_shape[0]
+    total_slices = comm_spec.device_mesh.shape[0]
 
     # Get global rank
     rank = dist.get_rank()
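The rank orderings quoted in the _mix_split docstring above ({0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}) can be reproduced with a few lines of plain Python. This is only a sketch of the ordering rule for a hypothetical 2 x 4 mesh, assuming the axis-0 groups run down the columns as in the docstring's example; it is not the library routine itself.

import numpy as np

# Hypothetical 2 x 4 logical mesh, matching the docstring:
# [[0, 1, 2, 3],
#  [4, 5, 6, 7]]
mesh = np.arange(8).reshape(2, 4)

def gather_order(mesh, leading_axis):
    """Ranks in the order their slices are concatenated when `leading_axis` leads."""
    if leading_axis == 0:
        # Walk the columns first: [0, 4], [1, 5], [2, 6], [3, 7]
        return [int(r) for r in mesh.T.reshape(-1)]
    # Walk the rows first: [0, 1, 2, 3], [4, 5, 6, 7]
    return [int(r) for r in mesh.reshape(-1)]

assert gather_order(mesh, 0) == [0, 4, 1, 5, 2, 6, 3, 7]
assert gather_order(mesh, 1) == [0, 1, 2, 3, 4, 5, 6, 7]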
@@ -414,7 +411,7 @@ class CommSpec:
         self.forward_only = forward_only
         if isinstance(self.logical_process_axis, list):
             if not mix_gather:
-                self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
+                self.device_mesh = self.sharding_spec.device_mesh.flatten()
                 self.logical_process_axis = 0
             else:
                 self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes
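The hunk above swaps the flatten_device_mesh property for a flatten() call. Conceptually, when a CommSpec targets several logical axes at once (and mix_gather is off), the mesh is collapsed so that a single axis addresses every device, which is why logical_process_axis is reset to 0. A conceptual illustration only, using numpy rather than ColossalAI's DeviceMesh API:

import numpy as np

# Hypothetical 2 x 4 mesh of global ranks.
physical_mesh = np.arange(8).reshape(2, 4)

# Flattening leaves one logical axis that spans all 8 devices; subsequent
# per-axis lookups then use logical_process_axis = 0 on the flattened mesh.
flattened = physical_mesh.reshape(-1)
assert flattened.shape == (8,)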