mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[tensor] refactor colo-tensor (#992)
* refactor colo-tensor and update linear op * polish code * polish code * update ops and unit tests * update unit tests * polish code * rename dist_spec module * polish code * polish code * remove unneeded import * fix pipelinable
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from .colo_tensor import ColoTensor
|
||||
from .const import TensorType
|
||||
import torch
|
||||
from colossalai.tensor import TensorSpec, distspec
|
||||
from copy import copy
|
||||
|
||||
|
||||
class ColoParameter(ColoTensor):
|
||||
@@ -8,21 +10,26 @@ class ColoParameter(ColoTensor):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kargs):
|
||||
super().__init__(*args, **kargs)
|
||||
self._type = TensorType.MODEL
|
||||
def __new__(cls,
|
||||
data: torch.Tensor,
|
||||
requires_grad: bool = True,
|
||||
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
|
||||
if data is None:
|
||||
data = torch.empty(0)
|
||||
return torch.Tensor._make_subclass(cls, data, requires_grad)
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
t = super(ColoParameter, cls).__new__(cls)
|
||||
t._type = TensorType.MODEL
|
||||
return t
|
||||
def __init__(self,
|
||||
data: torch.Tensor,
|
||||
requires_grad: bool = True,
|
||||
spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
|
||||
self._spec = copy(spec)
|
||||
self._type = TensorType.MODEL
|
||||
self._graph_node = None
|
||||
|
||||
@staticmethod
|
||||
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoParameter':
|
||||
colo_p = ColoParameter(*tensor.size(),
|
||||
dtype=tensor.dtype,
|
||||
requires_grad=tensor.requires_grad,
|
||||
pin_memory=tensor.is_pinned(),
|
||||
device=tensor.device,
|
||||
torch_tensor=tensor if save_payload else torch.empty(0))
|
||||
return colo_p
|
||||
def from_torch_tensor(tensor: torch.Tensor,
|
||||
requires_grad: bool = True,
|
||||
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
|
||||
tensor = tensor.as_subclass(ColoParameter)
|
||||
tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
|
||||
return tensor
|
||||
|
Reference in New Issue
Block a user