From 0c4c9aa6e09dd249cc1ec911e4d51eb6785194b0 Mon Sep 17 00:00:00 2001
From: BigOneLiXiaoMing <99854690+BigOneLiXiaoMing@users.noreply.github.com>
Date: Thu, 8 Sep 2022 16:02:44 +0800
Subject: [PATCH] [NFC] polish colossalai/nn/_ops/embedding.py code style (#1561)

---
 colossalai/nn/_ops/embedding.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py
index 2040d83c1..a045f305b 100644
--- a/colossalai/nn/_ops/embedding.py
+++ b/colossalai/nn/_ops/embedding.py
@@ -111,18 +111,17 @@ def colo_embedding(input_tensor: GeneralTensor,
     assert isinstance(weight, ColoTensor)
     input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
 
-    if not weight.has_compute_spec(): # No Model Parallel Applied
+    if not weight.has_compute_spec():    # No Model Parallel Applied
         assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
-        return ColoTensor.from_torch_tensor(
-            tensor=F.embedding(input_tensor,
-                               weight,
-                               padding_idx=padding_idx,
-                               max_norm=max_norm,
-                               norm_type=norm_type,
-                               scale_grad_by_freq=scale_grad_by_freq,
-                               sparse=sparse),
-            spec=ColoTensorSpec(weight.get_process_group()))
-    elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
+        return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor,
+                                                               weight,
+                                                               padding_idx=padding_idx,
+                                                               max_norm=max_norm,
+                                                               norm_type=norm_type,
+                                                               scale_grad_by_freq=scale_grad_by_freq,
+                                                               sparse=sparse),
+                                            spec=ColoTensorSpec(weight.get_process_group()))
+    elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
         if weight.is_shard_1drow():
             mode = 'row'
         elif weight.is_shard_1dcol():
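
Note for readers of the hunk: the change is purely cosmetic (an [NFC] reflow of the
ColoTensor.from_torch_tensor call and of the inline comments); the dispatch logic is
untouched. Below is a minimal standalone sketch of that dispatch shape in plain
PyTorch. WeightSpec and sketch_embedding are hypothetical stand-ins for illustration
only, not ColossalAI APIs; only F.embedding is the real op the replicated branch calls.

# Minimal sketch of the colo_embedding dispatch shape, in plain PyTorch.
# WeightSpec is a hypothetical stand-in for ColoTensor's compute-spec queries
# (has_compute_spec / has_compute_pattern / is_shard_1drow / is_shard_1dcol).
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F


@dataclass
class WeightSpec:
    compute_pattern: Optional[str] = None    # None => no model parallelism applied
    shard_dim: Optional[int] = None          # 0 => 1D-row shard, 1 => 1D-col shard


def sketch_embedding(input_ids: torch.Tensor, weight: torch.Tensor,
                     spec: WeightSpec) -> torch.Tensor:
    if spec.compute_pattern is None:    # No model parallelism: run the native op.
        return F.embedding(input_ids, weight)
    elif spec.compute_pattern == 'TP1D':    # 1D tensor parallelism: pick a shard mode.
        mode = 'row' if spec.shard_dim == 0 else 'col'
        # The real op would now dispatch to a row- or col-sharded embedding;
        # that path is outside this sketch.
        raise NotImplementedError(f"sharded embedding ({mode}) not sketched")
    raise NotImplementedError(f"unsupported compute pattern: {spec.compute_pattern}")


if __name__ == '__main__':
    ids = torch.tensor([[0, 2, 1]])
    w = torch.randn(4, 8)
    print(sketch_embedding(ids, w, WeightSpec()).shape)    # torch.Size([1, 3, 8])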