[tensor] redistribute among different process groups (#1247)

* make it faster

* [tensor] rename convert_to_dist -> redistribute

* [tensor] ShardSpec and ReplicaSpec

* [tensor] redistribute among diff pgs

* polish code
Jiarui Fang
2022-07-12 10:24:05 +08:00
committed by GitHub
parent 9bcd2fd4af
commit 1aad903c15
8 changed files with 48 additions and 17 deletions


@@ -13,7 +13,6 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
# Output:P
partial_output = torch.mm(mat1, mat2)
# Reduce(Output)
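The hunk above is the 1D row-parallel pattern: mat1 is redistributed into a shard along its last dimension so it lines up with a mat2 sharded along dim 0, each rank computes a partial product, and a reduce recovers the full output. A minimal single-process sketch of that arithmetic, using only plain torch and a hypothetical world size P (no ColossalAI types assumed):

import torch

# Emulate Input:S[-1] x Weight:S[0] -> partial output -> reduce on one process.
P = 4                          # hypothetical tensor-parallel world size
mat1 = torch.randn(8, 16)      # full input
mat2 = torch.randn(16, 32)     # full weight

# Shard mat1 along its last dim and mat2 along dim 0, as the 1Drow op does.
mat1_shards = mat1.chunk(P, dim=-1)
mat2_shards = mat2.chunk(P, dim=0)

# Each "rank" computes a partial output; the reduce is a sum over partials.
partials = [torch.mm(a, b) for a, b in zip(mat1_shards, mat2_shards)]
reduced = torch.stack(partials).sum(dim=0)

assert torch.allclose(reduced, torch.mm(mat1, mat2), atol=1e-5)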


@@ -14,7 +14,6 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse: bool = False) -> ColoTensor:
# embedding_1Dcol splits the weight (lookup table) into (num_embeddings, embedding_dim/P)
# Gather the split lookup table
input_tensor = input_tensor.redistribute(ReplicaSpec())
output_parallel = F.embedding(input_tensor,
@@ -47,7 +46,6 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# Find index in this shard and mask those not here
# Reduce all
pg = weight.get_process_group()
input_tensor = input_tensor.redistribute(ReplicaSpec())
# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
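colo_embedding_1Dcol keeps the input replicated, lets each rank look up its slice of the table (split along embedding_dim), and gathers the slices along the last dimension; colo_embedding_1Drow instead splits along num_embeddings, masks out-of-shard indices, and all-reduces. A single-process sketch of the column case, assuming a hypothetical world size P and plain torch.nn.functional:

import torch
import torch.nn.functional as F

P = 4                                       # hypothetical tensor-parallel world size
weight = torch.randn(100, 64)               # full lookup table (num_embeddings, embedding_dim)
input_ids = torch.randint(0, 100, (2, 5))   # replicated input on every "rank"

# Split the table along embedding_dim, as embedding_1Dcol does.
weight_shards = weight.chunk(P, dim=-1)

# Each "rank" looks up its slice; the gather is a concat along the last dim.
gathered = torch.cat([F.embedding(input_ids, w) for w in weight_shards], dim=-1)

assert torch.allclose(gathered, F.embedding(input_ids, weight))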


@@ -32,9 +32,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# All-Gather(Output)
# Input:B
compute_spec = weight.compute_spec
input_tensor = input_tensor.redistribute(ReplicaSpec())
input_parallel = reduce_grad(input_tensor, weight.get_process_group())
output_parallel = F.linear(input_parallel, weight, bias)
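colo_linear_1Dcol is the column-parallel counterpart: the input stays replicated (Input:B), the weight and bias are sharded along the output dimension, and the per-rank outputs are all-gathered along the last dimension. A single-process sketch of that equivalence with plain torch (world size P is hypothetical):

import torch
import torch.nn.functional as F

P = 4                          # hypothetical tensor-parallel world size
x = torch.randn(8, 16)         # replicated input (Input:B)
weight = torch.randn(32, 16)   # full weight (out_features, in_features)
bias = torch.randn(32)

# Shard weight and bias along the output dimension, as the 1Dcol op does.
w_shards = weight.chunk(P, dim=0)
b_shards = bias.chunk(P, dim=0)

# Each "rank" produces a slice of the output; all-gather is a concat along the last dim.
gathered = torch.cat([F.linear(x, w, b) for w, b in zip(w_shards, b_shards)], dim=-1)

assert torch.allclose(gathered, F.linear(x, weight, bias), atol=1e-5)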


@@ -186,7 +186,28 @@ class ColoTensor(torch.Tensor):
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
self.dist_spec = dist_spec
def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
"""redistribute
Redistribute the tensor among processes. The rule is like this:
1. If the pg is None, then redistributed tensor payload among TP process group. Keep the
DP process group still as replicated.
2. If the pg is not not None and not equal to the cureent process group.
First, convert the tensor as replicated among TP process group.
Second, reset the process group.
Third, conver the tensor (new replicated both among tp and dp process group) to the new dist_spec.
Args:
dist_spec (_DistSpec): the new dist spec.
pg (Optional[ProcessGroup], optional): the new process group . Defaults to None.
Returns:
ColoTensor: a redistributed colotensor
"""
if pg is not None and pg != self.get_process_group():
print('here _redistribute')
# if the pg is not equal, convert the current tensor to replicated
self._redistribute(ReplicaSpec())
self.process_group = pg
ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
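A hedged usage sketch of the new signature, assuming the names below are exported from colossalai.tensor as they appear in this diff and that a distributed context is already initialized; the helper name and new_pg argument are illustrative, not part of the commit:

from colossalai.tensor import ColoTensor, ProcessGroup, ReplicaSpec, ShardSpec

def reshard_example(t: ColoTensor, new_pg: ProcessGroup) -> ColoTensor:
    # Illustrative helper, not part of the commit.
    # Rule 1: no pg given -> redistribute only inside the current TP process group,
    # leaving the DP process group replicated.
    t = t.redistribute(ShardSpec([-1], [t.get_tp_world_size()]))

    # Rule 2: a different pg -> the tensor is internally converted back to
    # replicated, its process group is reset to new_pg, and the requested
    # dist spec (here: replicated) is applied under the new group.
    return t.redistribute(ReplicaSpec(), pg=new_pg)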
@@ -202,7 +223,6 @@ class ColoTensor(torch.Tensor):
"""
return self.redistribute(ReplicaSpec())
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
tensor = tensor.as_subclass(ColoTensor)


@@ -117,13 +117,13 @@ class ProcessGroup:
if not isinstance(obj, ProcessGroup):
return False
if self._rank != obj._rank:
assert False
return False
if self._rank_list != obj._rank_list:
assert False
return False
if self._tp_rank_list != obj._tp_rank_list:
assert False
return False
if self._dp_rank_list != obj._dp_rank_list:
assert False
return False
if self._tp_degree != obj._tp_degree:
return False
if self._dp_degree != obj._dp_degree:
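The __eq__ change above is what lets redistribute() test pg != self.get_process_group() without tripping an assertion: any mismatched field now simply makes the comparison False. A minimal standalone sketch of that field-wise pattern (plain Python, not the real ProcessGroup):

class MiniGroup:
    """Illustrative stand-in for ProcessGroup's equality semantics (not the real class)."""

    def __init__(self, rank, rank_list, tp_degree, dp_degree):
        self._rank = rank
        self._rank_list = rank_list
        self._tp_degree = tp_degree
        self._dp_degree = dp_degree

    def __eq__(self, obj):
        # Any mismatch yields False instead of raising, so callers can branch on it.
        if not isinstance(obj, MiniGroup):
            return False
        if self._rank != obj._rank:
            return False
        if self._rank_list != obj._rank_list:
            return False
        if self._tp_degree != obj._tp_degree:
            return False
        if self._dp_degree != obj._dp_degree:
            return False
        return True

# Comparing two differently shaped groups now simply evaluates to False.
assert MiniGroup(0, [0, 1], 2, 1) != MiniGroup(0, [0, 1, 2, 3], 4, 1)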