mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-22 23:32:37 +00:00
* refactor colo-tensor and update linear op * polish code * polish code * update ops and unit tests * update unit tests * polish code * rename dist_spec module * polish code * polish code * remove unneeded import * fix pipelinable
13 lines
390 B
Python
13 lines
390 B
Python
import torch
|
|
from typing import Union, Optional
|
|
from colossalai.tensor import ColoTensor
|
|
|
|
# Type alias: a value that is either a ColoTensor or a plain torch.Tensor.
GeneralTensor = Union[ColoTensor, torch.Tensor]
|
|
# Type alias: a scalar that is either an int or a float.
Number = Union[int, float]
|
|
|
|
|
|
def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
    """Return ``tensor`` as a ``ColoTensor``, passing ``None`` through unchanged.

    Args:
        tensor: a ``ColoTensor``, a ``torch.Tensor``, or ``None``.

    Returns:
        ``None`` when ``tensor`` is ``None``; the very same object when it is
        already a ``ColoTensor``; otherwise the result of
        ``ColoTensor.from_torch_tensor(tensor)``.
    """
    # Nothing to wrap: the input is absent or already of the target type.
    if tensor is None or isinstance(tensor, ColoTensor):
        return tensor
    return ColoTensor.from_torch_tensor(tensor)
|