mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-01 22:39:42 +00:00
[tensor] refactor colo-tensor (#992)
* refactor colo-tensor and update linear op * polish code * polish code * update ops and unit tests * update unit tests * polish code * rename dist_spec module * polish code * polish code * remove unneeded import * fix pipelinable
This commit is contained in:
12
colossalai/tensor/_ops/_utils.py
Normal file
12
colossalai/tensor/_ops/_utils.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import torch
|
||||
from typing import Union, Optional
|
||||
from colossalai.tensor import ColoTensor
|
||||
|
||||
GeneralTensor = Union[ColoTensor, torch.Tensor]
|
||||
Number = Union[int, float]
|
||||
|
||||
|
||||
def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
|
||||
if tensor is not None and not isinstance(tensor, ColoTensor):
|
||||
tensor = ColoTensor.from_torch_tensor(tensor)
|
||||
return tensor
|
||||
Reference in New Issue
Block a user