mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 22:10:37 +00:00
reorgnize colotensor directory (#1062)
* reorgnize colotensor directory * polish code
This commit is contained in:
12
colossalai/nn/_ops/_utils.py
Normal file
12
colossalai/nn/_ops/_utils.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import torch
|
||||
from typing import Union, Optional
|
||||
from colossalai.tensor import ColoTensor
|
||||
|
||||
GeneralTensor = Union[ColoTensor, torch.Tensor]
|
||||
Number = Union[int, float]
|
||||
|
||||
|
||||
def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
|
||||
if tensor is not None and not isinstance(tensor, ColoTensor):
|
||||
tensor = ColoTensor.from_torch_tensor(tensor)
|
||||
return tensor
|
Reference in New Issue
Block a user