[tensor] revert local view back (#1178)

This commit is contained in:
Jiarui Fang
2022-06-27 18:38:34 +08:00
committed by GitHub
parent 0dd4e2bbfb
commit 1b657f9ce1
4 changed files with 10 additions and 20 deletions

View File

@@ -52,7 +52,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
num_embeddings_per_partition = weight.size_base(0)
num_embeddings_per_partition = weight.size_local(0)
vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
vocab_end_index = vocab_start_index + num_embeddings_per_partition