mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-18 17:31:53 +00:00
embedding op use gather_out (#1143)
This commit is contained in:
parent
e61dc31b05
commit
ccf3c58c89
@ -30,6 +30,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
|
||||
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
||||
if weight.spec.parallel_action.gather_out:
|
||||
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
|
||||
return output
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user