mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-04 02:57:20 +00:00
[NFC] polish colossalai/nn/layer/parallel_3d/_operation.py code style (#1258)
Co-authored-by: Research <research@soccf-snr3-017.comp.nus.edu.sg>
This commit is contained in:
parent
9738fb0f78
commit
f660152c73
@ -326,10 +326,8 @@ def split_batch_3d(input_: Tensor,
|
|||||||
|
|
||||||
if input_.size(dim) <= 1:
|
if input_.size(dim) <= 1:
|
||||||
return input_
|
return input_
|
||||||
output = torch.chunk(input_, weight_world_size,
|
output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
|
||||||
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
|
output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
|
||||||
output = torch.chunk(output, input_world_size,
|
|
||||||
dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user