mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[dtensor] updated api and doc (#3845)
This commit is contained in:
@@ -24,12 +24,12 @@ class CommSpec:
|
||||
'''
|
||||
Communication spec is used to record the communication action. It converts the communication spec
|
||||
to real action which will be used in runtime. It contains comm_pattern to determine the
|
||||
communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim
|
||||
communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
|
||||
to determine the buffer shape, and logical_process_axis
|
||||
|
||||
Argument:
|
||||
comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
|
||||
process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
|
||||
comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
|
||||
process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
|
||||
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
|
||||
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
|
||||
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
|
||||
@@ -37,7 +37,7 @@ class CommSpec:
|
||||
|
||||
def __init__(self,
|
||||
comm_pattern: CollectiveCommPattern,
|
||||
process_groups_dict: Dict,
|
||||
process_group_dict: Dict,
|
||||
gather_dim: int = None,
|
||||
shard_dim: int = None,
|
||||
logical_process_axis: int = None):
|
||||
@@ -45,7 +45,7 @@ class CommSpec:
|
||||
self.gather_dim = gather_dim
|
||||
self.shard_dim = shard_dim
|
||||
self.logical_process_axis = logical_process_axis
|
||||
self.process_groups_dict = process_groups_dict
|
||||
self.process_group_dict = process_group_dict
|
||||
|
||||
def __repr__(self):
|
||||
res_list = ["CommSpec:("]
|
||||
@@ -92,68 +92,56 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
|
||||
'''
|
||||
Implement all gather operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
tensor_list = [
|
||||
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
|
||||
]
|
||||
# without this contiguous operation, the all gather may get some unexpected results.
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
world_size = dist.get_world_size(process_group)
|
||||
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
|
||||
# without this contiguous operation, the all gather may get some unexpected results.
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _split(tensor: torch.Tensor, comm_spec: CommSpec):
|
||||
'''
|
||||
Implement shard operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, _ in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
start = length * rank_list.index(dist.get_rank())
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
|
||||
start = length * dist.get_rank(process_group)
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
|
||||
'''
|
||||
Implement all to all operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
|
||||
new_shape = torch.Size(new_shape)
|
||||
output_tensor_list = [
|
||||
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
|
||||
]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
input_tensor_list = [
|
||||
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
|
||||
]
|
||||
group = process_group
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group)
|
||||
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
world_size = dist.get_world_size(process_group)
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
|
||||
new_shape = torch.Size(new_shape)
|
||||
output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // world_size
|
||||
input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
|
||||
group = process_group
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group)
|
||||
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
|
||||
'''
|
||||
Implement all reduce operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||
return tensor
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||
return tensor
|
||||
|
||||
|
||||
class _ReduceGrad(torch.autograd.Function):
|
||||
@@ -269,7 +257,7 @@ class _AllToAll(torch.autograd.Function):
|
||||
def forward(ctx, input_, comm_spec):
|
||||
output = _all_to_all(input_, comm_spec)
|
||||
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
|
||||
process_groups_dict=comm_spec.process_groups_dict,
|
||||
process_group_dict=comm_spec.process_group_dict,
|
||||
gather_dim=comm_spec.shard_dim,
|
||||
shard_dim=comm_spec.gather_dim,
|
||||
logical_process_axis=comm_spec.logical_process_axis)
|
||||
|
Reference in New Issue
Block a user