Files
ColossalAI/colossalai/nn/layer/colossalai_layer/_utils.py
アマデウス 9ee197d0e9 moved env variables to global variables; (#215)
added branch context;
added vocab parallel layers;
moved split_batch from load_batch to tensor parallel embedding layers;
updated gpt model;
updated unit test cases;
fixed few collective communicator bugs
2022-02-15 11:31:13 +08:00

20 lines
751 B
Python

from typing import Union

from torch import Tensor

from ..parallel_2d._operation import split_tensor_2d
from ..parallel_2p5d._operation import split_tensor_2p5d
from ..parallel_3d._operation import split_batch_3d
from ..utils import get_tensor_parallel_mode
# Dispatch table: tensor parallel mode name -> function applied to a batch for
# that mode. Modes that do not appear here are returned unchanged by
# partition_batch.
_parallel_split_batch = {
    '2d': split_tensor_2d,
    '2.5d': split_tensor_2p5d,
    '3d': split_batch_3d,
}
def partition_batch(input_) -> Union[Tensor, dict]:
    """Apply the current tensor parallel mode's batch split function to ``input_``.

    The mode comes from ``get_tensor_parallel_mode()``. If that mode has an
    entry in ``_parallel_split_batch`` ('2d', '2.5d' or '3d'), the registered
    function is applied:

    * if ``input_`` is a ``dict``, it is applied to every value and a new dict
      with the same keys is returned;
    * otherwise it is applied to ``input_`` directly.

    For any mode without an entry, ``input_`` is returned as-is (same object).

    Args:
        input_: A tensor, or a dict of values to be split individually
            (presumably tensors — they are passed straight to the split
            function; confirm with callers).

    Returns:
        The split result, a new dict of split values, or the untouched
        ``input_`` when the mode has no registered split function.
    """
    # Single lookup; None means the mode has no registered split function.
    split_fn = _parallel_split_batch.get(get_tensor_parallel_mode())
    if split_fn is None:
        return input_
    if isinstance(input_, dict):
        # Split each entry independently, preserving the keys.
        return {k: split_fn(v) for k, v in input_.items()}
    return split_fn(input_)