Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 05:20:33 +00:00.
[refactor] move process group from _DistSpec to ColoTensor. (#1203)
This commit is contained in:
@@ -6,15 +6,15 @@ import torch.distributed as dist
|
||||
from colossalai.global_variables import tensor_parallel_env as env
|
||||
|
||||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.tensor import ProcessGroup, ColoTensorSpec
|
||||
|
||||
GeneralTensor = Union[ColoTensor, torch.Tensor]
|
||||
Number = Union[int, float]
|
||||
|
||||
|
||||
def convert_to_colo_tensor(tensor: Optional[GeneralTensor], pg: ProcessGroup) -> Optional[ColoTensor]:
    """Wrap a plain ``torch.Tensor`` into a ``ColoTensor`` bound to *pg*.

    Post-refactor form: the process group now lives on the ``ColoTensor``
    (via ``ColoTensorSpec``) rather than on ``_DistSpec``, so callers must
    supply the group explicitly.

    Args:
        tensor: A ``torch.Tensor``, an existing ``ColoTensor``, or ``None``.
        pg: Process group recorded in the ``ColoTensorSpec`` of a newly
            wrapped tensor.

    Returns:
        ``None`` if *tensor* is ``None``; the input unchanged if it is
        already a ``ColoTensor``; otherwise a new ``ColoTensor`` created by
        ``ColoTensor.from_torch_tensor`` with a spec carrying *pg*.
    """
    if tensor is not None and not isinstance(tensor, ColoTensor):
        # Only plain torch tensors are wrapped; an existing ColoTensor keeps
        # whatever spec it already carries.
        tensor = ColoTensor.from_torch_tensor(tensor, ColoTensorSpec(pg))
    return tensor
|
Reference in New Issue
Block a user