Mirror of https://github.com/hpcaitech/ColossalAI.git
[tensor] redistribute among different process groups (#1247)
* make it faster
* [tensor] rename convert_to_dist -> redistribute
* [tensor] ShardSpec and ReplicaSpec
* [tensor] redistribute among diff pgs
* polish code
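For orientation, a minimal usage sketch of the renamed API. It only uses names that appear in the hunks below (redistribute, ShardSpec, ReplicaSpec, get_tp_world_size); the import path is an assumption and x/w stand for ColoTensors created elsewhere:

# Hedged sketch; names are taken from the hunks in this commit, not a verified API.
from colossalai.tensor import ShardSpec, ReplicaSpec   # assumed import path

# convert_to_dist(...) is now spelled redistribute(...)
x = x.redistribute(ReplicaSpec())                              # replicate among the TP group
w = w.redistribute(ShardSpec([-1], [w.get_tp_world_size()]))   # shard the last dim across TP ranks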
@@ -13,7 +13,6 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
    mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
    # Output:P
    partial_output = torch.mm(mat1, mat2)
    # Reduce(Output)
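The # Output:P and # Reduce(Output) comments describe the usual row-parallel matmul: with mat1 sharded along its last dimension and mat2 along its first, each rank computes a partial product and an all-reduce sums them. A plain torch.distributed sketch of that pattern (an illustration, not ColossalAI's implementation):

import torch
import torch.distributed as dist

def addmm_1d_row_sketch(input_tensor, mat1_shard, mat2_shard, pg=None):
    # mat1_shard: this rank's columns of mat1; mat2_shard: the matching rows of mat2
    partial_output = torch.mm(mat1_shard, mat2_shard)                 # Output: P (partial sum)
    dist.all_reduce(partial_output, op=dist.ReduceOp.SUM, group=pg)   # Reduce(Output)
    return partial_output + input_tensor                              # add the replicated input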
@@ -14,7 +14,6 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                         sparse: bool = False) -> ColoTensor:
    # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
    # Gather splitted lookup table
    input_tensor = input_tensor.redistribute(ReplicaSpec())
    output_parallel = F.embedding(input_tensor,
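Per the comments, the lookup table is split column-wise to (num_embeddings, embedding_dim/P) and the replicated input is looked up locally, so the partial embeddings have to be gathered along the last dimension afterwards. A self-contained torch.distributed sketch of that shape of computation (illustration only):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def embedding_1d_col_sketch(input_ids, weight_shard, pg=None):
    # weight_shard: the (num_embeddings, embedding_dim / P) columns owned by this rank
    output_parallel = F.embedding(input_ids, weight_shard)
    gathered = [torch.empty_like(output_parallel) for _ in range(dist.get_world_size(group=pg))]
    dist.all_gather(gathered, output_parallel.contiguous(), group=pg)   # rebuild full embedding_dim
    return torch.cat(gathered, dim=-1)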
@@ -47,7 +46,6 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
    # Find index in this shard and mask those not here
    # Reduce all
    pg = weight.get_process_group()
    input_tensor = input_tensor.redistribute(ReplicaSpec())
    # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
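The # Find index in this shard and mask those not here / # Reduce all comments describe the row-split (vocab-parallel) variant: each rank owns a contiguous slice of rows, masks indices that fall outside it, zeroes the corresponding outputs, and all-reduces the per-rank contributions. A torch.distributed sketch of that pattern (illustration only, not the ColossalAI kernel):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def embedding_1d_row_sketch(input_ids, weight_shard, vocab_start, vocab_end, pg=None):
    # weight_shard holds rows [vocab_start, vocab_end) of the full lookup table
    mask = (input_ids < vocab_start) | (input_ids >= vocab_end)   # indices not in this shard
    local_ids = input_ids - vocab_start
    local_ids[mask] = 0                                           # any in-range row; zeroed below
    output = F.embedding(local_ids, weight_shard)
    output = output.masked_fill(mask.unsqueeze(-1), 0.0)          # drop rows this shard does not own
    dist.all_reduce(output, group=pg)                             # Reduce all
    return output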
@@ -32,9 +32,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
    # All-Gather(Output)
    # Input:B
    compute_spec = weight.compute_spec
    input_tensor = input_tensor.redistribute(ReplicaSpec())
    input_parallel = reduce_grad(input_tensor, weight.get_process_group())
    output_parallel = F.linear(input_parallel, weight, bias)
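Here the # Input:B / # All-Gather(Output) comments describe column-parallel linear: the input is replicated, the weight is split along its output features, each rank computes its slice of the output, and an all-gather can rebuild the full feature dimension. reduce_grad in the hunk presumably handles the backward-pass gradient reduction and is omitted below. A torch.distributed sketch of the forward pattern (illustration only):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def linear_1d_col_sketch(input_replicated, weight_shard, bias_shard=None, pg=None,
                         gather_output=False):
    # weight_shard: the (out_features / P, in_features) rows owned by this rank
    output_parallel = F.linear(input_replicated, weight_shard, bias_shard)
    if not gather_output:
        return output_parallel                                          # keep column-sharded output
    gathered = [torch.empty_like(output_parallel) for _ in range(dist.get_world_size(group=pg))]
    dist.all_gather(gathered, output_parallel.contiguous(), group=pg)   # All-Gather(Output)
    return torch.cat(gathered, dim=-1)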
@@ -186,7 +186,28 @@ class ColoTensor(torch.Tensor):
         self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
         self.dist_spec = dist_spec

-    def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
+    def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
+        """redistribute
+
+        Redistribute the tensor among processes. The rule is as follows:
+        1. If pg is None, redistribute the tensor payload among the TP process group,
+           keeping the DP process group replicated.
+        2. If pg is not None and differs from the current process group:
+           first, convert the tensor to replicated among the TP process group;
+           second, reset the process group to pg;
+           third, convert the tensor (now replicated among both the TP and DP process
+           groups) to the new dist_spec.
+
+        Args:
+            dist_spec (_DistSpec): the new dist spec.
+            pg (Optional[ProcessGroup], optional): the new process group. Defaults to None.
+
+        Returns:
+            ColoTensor: a redistributed ColoTensor
+        """
+        if pg is not None and pg != self.get_process_group():
+            print('here _redistribute')
+            # if the pg is not equal, convert the current tensor to replicated
+            self._redistribute(ReplicaSpec())
+            self.process_group = pg
+        ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
+        return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
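A usage sketch of the new pg argument following the two rules in the docstring. Only names from this hunk are used; the import path is assumed, t is an existing ColoTensor, and new_pg is a ProcessGroup with a tensor-parallel degree of 2 built elsewhere:

# Hedged sketch, not a verified end-to-end example.
from colossalai.tensor import ShardSpec, ReplicaSpec   # assumed import path

# rule 1: pg=None only changes the layout inside the current TP process group
t = t.redistribute(ReplicaSpec())

# rule 2: a different pg first re-replicates the payload, then swaps the process
# group, then converts to the new dist spec (here: shard dim 0 across 2 TP ranks)
t = t.redistribute(ShardSpec([0], [2]), pg=new_pg)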
@@ -202,7 +223,6 @@ class ColoTensor(torch.Tensor):
        """
        return self.redistribute(ReplicaSpec())

    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
        tensor = tensor.as_subclass(ColoTensor)
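For completeness, a short sketch of how from_torch_tensor and the replication helper above fit together; ColoTensorSpec(pg, dist_attr=...) mirrors the call in the previous hunk, pg is an existing ProcessGroup, and the import path is assumed:

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec   # assumed import path

spec = ColoTensorSpec(pg, dist_attr=ReplicaSpec())
t = ColoTensor.from_torch_tensor(torch.randn(4, 4), spec)   # wraps the tensor via as_subclass
t = t.redistribute(ReplicaSpec())                           # the call the method above delegates to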
@@ -117,13 +117,13 @@ class ProcessGroup:
         if not isinstance(obj, ProcessGroup):
             return False
         if self._rank != obj._rank:
-            assert False
+            return False
         if self._rank_list != obj._rank_list:
-            assert False
+            return False
         if self._tp_rank_list != obj._tp_rank_list:
-            assert False
+            return False
         if self._dp_rank_list != obj._dp_rank_list:
-            assert False
+            return False
         if self._tp_degree != obj._tp_degree:
             return False
         if self._dp_degree != obj._dp_degree:
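Given the unchanged line counts in the hunk header, the change appears to swap assert False for return False, so comparing mismatched groups reports inequality instead of crashing, which the new redistribute path relies on (pg != self.get_process_group()). A simplified, self-contained illustration of that field-by-field pattern (not the real class):

class _PG:
    """Toy stand-in for ProcessGroup, comparing the same fields as the hunk above."""

    def __init__(self, rank, rank_list, tp_rank_list, dp_rank_list, tp_degree, dp_degree):
        self._rank, self._rank_list = rank, rank_list
        self._tp_rank_list, self._dp_rank_list = tp_rank_list, dp_rank_list
        self._tp_degree, self._dp_degree = tp_degree, dp_degree

    def __eq__(self, obj):
        if not isinstance(obj, _PG):
            return False
        if self._rank != obj._rank:
            return False          # previously: assert False
        if self._rank_list != obj._rank_list:
            return False
        if self._tp_rank_list != obj._tp_rank_list:
            return False
        if self._dp_rank_list != obj._dp_rank_list:
            return False
        if self._tp_degree != obj._tp_degree:
            return False
        if self._dp_degree != obj._dp_degree:
            return False
        return True

# mismatched groups now compare unequal instead of raising
a = _PG(0, [0, 1], [0, 1], [0], tp_degree=2, dp_degree=1)
b = _PG(0, [0, 1, 2, 3], [0, 1], [0, 2], tp_degree=2, dp_degree=2)
assert not (a == b)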