[refactor] move process group from _DistSpec to ColoTensor. (#1203)

This commit is contained in:
Jiarui Fang
2022-07-06 16:15:16 +08:00
committed by GitHub
parent 5da87ce35d
commit ae7d3f4927
34 changed files with 452 additions and 367 deletions

View File

@@ -6,15 +6,15 @@ import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn.layer.utils import divide
from colossalai.tensor import ProcessGroup
from colossalai.tensor import ProcessGroup, ColoTensorSpec
GeneralTensor = Union[ColoTensor, torch.Tensor]
Number = Union[int, float]
def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
def convert_to_colo_tensor(tensor: Optional[GeneralTensor], pg: ProcessGroup) -> Optional[ColoTensor]:
if tensor is not None and not isinstance(tensor, ColoTensor):
tensor = ColoTensor.from_torch_tensor(tensor)
tensor = ColoTensor.from_torch_tensor(tensor, ColoTensorSpec(pg))
return tensor