mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[Tensor] init ColoParameter (#914)
This commit is contained in:
28
colossalai/tensor/colo_parameter.py
Normal file
28
colossalai/tensor/colo_parameter.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from .colo_tensor import ColoTensor
|
||||
from .const import TensorType
|
||||
import torch
|
||||
|
||||
|
||||
class ColoParameter(ColoTensor):
    r"""A ColoTensor that represents a module parameter.

    Instances are tagged with ``TensorType.MODEL`` so the framework can
    tell parameter tensors apart from other (non-parameter) ColoTensors.
    """

    def __new__(cls, *args, **kwargs):
        # Tag the tensor as a model parameter at allocation time, so the
        # type marker exists even before __init__ runs.
        instance = super().__new__(cls)
        instance._type = TensorType.MODEL
        return instance

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Re-assert the tag after ColoTensor.__init__, in case the base
        # initializer overwrites _type with a non-parameter default.
        self._type = TensorType.MODEL

    @staticmethod
    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoParameter':
        """Build a ColoParameter mirroring *tensor*'s metadata.

        The new parameter copies the shape, dtype, requires_grad flag,
        pin-memory status, and device of *tensor*. When *save_payload* is
        False the payload is replaced with an empty tensor, keeping only
        the metadata.
        """
        return ColoParameter(*tensor.size(),
                             dtype=tensor.dtype,
                             requires_grad=tensor.requires_grad,
                             pin_memory=tensor.is_pinned(),
                             device=tensor.device,
                             torch_tensor=tensor if save_payload else torch.empty(0))
|
Reference in New Issue
Block a user