mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-01 17:52:05 +00:00
[NFC] polish colossalai/nn/layer/parallel_3d/_operation.py code style (#1258)
Co-authored-by: Research <research@soccf-snr3-017.comp.nus.edu.sg>
This commit is contained in:
parent
9738fb0f78
commit
f660152c73
@ -326,10 +326,8 @@ def split_batch_3d(input_: Tensor,
|
||||
|
||||
if input_.size(dim) <= 1:
|
||||
return input_
|
||||
output = torch.chunk(input_, weight_world_size,
|
||||
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
|
||||
output = torch.chunk(output, input_world_size,
|
||||
dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
|
||||
output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
|
||||
output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
|
||||
return output
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user