mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-16 23:16:56 +00:00
[NFC] polish colossalai/nn/_ops/embedding.py code style (#1561)
This commit is contained in:
parent
08815f0e72
commit
0c4c9aa6e0
@ -113,8 +113,7 @@ def colo_embedding(input_tensor: GeneralTensor,
|
|||||||
|
|
||||||
if not weight.has_compute_spec(): # No Model Parallel Applied
|
if not weight.has_compute_spec(): # No Model Parallel Applied
|
||||||
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
|
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
|
||||||
return ColoTensor.from_torch_tensor(
|
return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor,
|
||||||
tensor=F.embedding(input_tensor,
|
|
||||||
weight,
|
weight,
|
||||||
padding_idx=padding_idx,
|
padding_idx=padding_idx,
|
||||||
max_norm=max_norm,
|
max_norm=max_norm,
|
||||||
|
Loading…
Reference in New Issue
Block a user