[tensor] redistribute among different process groups (#1247)

* make it faster

* [tensor] rename convert_to_dist -> redistribute

* [tensor] ShardSpec and ReplicaSpec

* [tensor] redistribute among diff pgs

* polish code
Author: Jiarui Fang
Date: 2022-07-12 10:24:05 +08:00
Committed by: GitHub
parent 9bcd2fd4af
commit 1aad903c15
8 changed files with 48 additions and 17 deletions
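
The commit message above describes moving from convert_to_dist to redistribute, with ShardSpec / ReplicaSpec describing the target layout and the target allowed to live on a different process group. As a rough, hedged illustration of what that redistribution means at the collective level, the following sketch uses only plain torch.distributed (no ColossalAI API; the 4-rank launch, tensor shapes, and group split are assumptions): a tensor sharded over the default group is all-gathered back to a replica and then re-sharded along another dim over a smaller sub-group.

import torch
import torch.distributed as dist

# Illustrative only: assumes a torchrun launch with --nproc_per_node=4.
dist.init_process_group(backend="gloo")
rank, world = dist.get_rank(), dist.get_world_size()

# Source layout: sharded along dim 0 over the default (world-sized) process group.
torch.manual_seed(0)                      # same full tensor on every rank
full = torch.randn(8, 4)
src_shard = torch.chunk(full, world, dim=0)[rank]

# Shard -> replica: all-gather the shards to rebuild the full tensor.
gathered = [torch.empty_like(src_shard) for _ in range(world)]
dist.all_gather(gathered, src_shard)
replica = torch.cat(gathered, dim=0)

# Replica -> shard on a *different* process group (here the first two ranks),
# this time splitting along the last dim.
new_ranks = [0, 1]
new_pg = dist.new_group(ranks=new_ranks)  # must be called by every rank
if rank in new_ranks:
    new_rank = dist.get_rank(group=new_pg)
    dst_shard = torch.chunk(replica, len(new_ranks), dim=-1)[new_rank]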


@@ -13,7 +13,6 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor
mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
# Output:P
partial_output = torch.mm(mat1, mat2)
# Reduce(Output)
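
The hunk above implements the 1D-row addmm pattern: mat1 is redistributed to a shard along its last dim, the row-sharded mat2 produces a partial product on each rank, and the partials are then reduced. A minimal torch.distributed sketch of that arithmetic (illustrative only; shapes and the gloo/torchrun setup are assumptions, and the inner dim must divide evenly by the world size):

import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank, world = dist.get_rank(), dist.get_world_size()

torch.manual_seed(0)            # identical tensors on every rank
inp  = torch.randn(4, 6)        # the addmm "input" term
mat1 = torch.randn(4, 8)        # activation, sharded along its last dim
mat2 = torch.randn(8, 6)        # weight, sharded along dim 0 (rows)

mat1_shard = torch.chunk(mat1, world, dim=-1)[rank]
mat2_shard = torch.chunk(mat2, world, dim=0)[rank]

# Each rank holds a partial product; summing the partials recovers mat1 @ mat2.
partial_output = torch.mm(mat1_shard, mat2_shard)
dist.all_reduce(partial_output, op=dist.ReduceOp.SUM)

output = inp + partial_output   # matches torch.addmm(inp, mat1, mat2)
assert torch.allclose(output, torch.addmm(inp, mat1, mat2), atol=1e-5)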


@@ -14,7 +14,6 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse: bool = False) -> ColoTensor:
# embedding_1Dcol splits the weight (lookup table) into (num_embeddings, embedding_dim/P)
# Gather the split lookup table
input_tensor = input_tensor.redistribute(ReplicaSpec())
output_parallel = F.embedding(input_tensor,
@@ -47,7 +46,6 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# Find index in this shard and mask those not here
# Reduce all
pg = weight.get_process_group()
input_tensor = input_tensor.redistribute(ReplicaSpec())
# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
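
Both embedding hunks start by redistributing the indices to a ReplicaSpec so every rank sees the full index tensor. In the 1Dcol case shown first, the lookup table is split along embedding_dim, each rank embeds into its own column slice, and the slices are gathered back together. A plain torch.distributed sketch of that 1Dcol path only (illustrative; shapes and launch are assumptions, embedding_dim must divide evenly by the world size, and the 1Drow mask-and-reduce path is not shown):

import torch
import torch.nn.functional as F
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank, world = dist.get_rank(), dist.get_world_size()

torch.manual_seed(0)
num_embeddings, embedding_dim = 10, 8
weight = torch.randn(num_embeddings, embedding_dim)       # full lookup table
weight_shard = torch.chunk(weight, world, dim=-1)[rank]   # (num_embeddings, embedding_dim / P)

idx = torch.tensor([[1, 3], [7, 2]])                      # replicated indices on every rank
output_parallel = F.embedding(idx, weight_shard)          # (..., embedding_dim / P)

# Gather the per-rank column slices to recover the full embedding output.
slices = [torch.empty_like(output_parallel) for _ in range(world)]
dist.all_gather(slices, output_parallel)
output = torch.cat(slices, dim=-1)
assert torch.allclose(output, F.embedding(idx, weight))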


@@ -32,9 +32,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]
# All-Gather(Output)
# Input:B
compute_spec = weight.compute_spec
input_tensor = input_tensor.redistribute(ReplicaSpec())
input_parallel = reduce_grad(input_tensor, weight.get_process_group())
output_parallel = F.linear(input_parallel, weight, bias)
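
In this 1Dcol linear hunk the input is first redistributed to a ReplicaSpec, passed through reduce_grad, and then multiplied by the output-dim-sharded weight. reduce_grad itself is not part of this diff; a common (Megatron-style) implementation is an identity in forward with an all-reduce of the input gradient in backward, so the sketch below is an assumption about its behavior rather than ColossalAI's actual code (shapes, the gloo/torchrun setup, and the final all-gather of the output are also illustrative):

import torch
import torch.nn.functional as F
import torch.distributed as dist

class _ReduceGrad(torch.autograd.Function):
    # Identity in forward; all-reduce the gradient across the TP group in backward.
    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad, None

def reduce_grad(x, group=None):
    return _ReduceGrad.apply(x, group)

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    rank, world = dist.get_rank(), dist.get_world_size()

    torch.manual_seed(0)
    x = torch.randn(4, 8)                                   # replicated input
    weight = torch.randn(8, 8)                              # full weight (out_features, in_features)
    weight_shard = torch.chunk(weight, world, dim=0)[rank]  # shard the output features (1Dcol)

    input_parallel = reduce_grad(x)                         # identity here; matters once grads flow
    output_parallel = F.linear(input_parallel, weight_shard)

    # All-gather along the last dim to recover the full (4, 8) output.
    parts = [torch.empty_like(output_parallel) for _ in range(world)]
    dist.all_gather(parts, output_parallel)
    output = torch.cat(parts, dim=-1)
    assert torch.allclose(output, F.linear(x, weight), atol=1e-5)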