mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[refactor] move process group from _DistSpec to ColoTensor. (#1203)
This commit is contained in:
@@ -5,7 +5,7 @@ from copy import copy
|
||||
|
||||
from colossalai.tensor.colo_tensor import ColoTensor
|
||||
from colossalai.tensor.const import TensorType
|
||||
from colossalai.tensor import TensorSpec, distspec
|
||||
from colossalai.tensor import ColoTensorSpec
|
||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
|
||||
def __new__(cls,
|
||||
data: Optional[torch.Tensor] = None,
|
||||
requires_grad: bool = True,
|
||||
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
|
||||
spec: ColoTensorSpec = None) -> 'ColoParameter':
|
||||
if data is None:
|
||||
data = torch.empty(0)
|
||||
return torch.Tensor._make_subclass(cls, data, requires_grad)
|
||||
@@ -36,11 +36,9 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
|
||||
def __init__(self,
|
||||
data: Optional[torch.Tensor] = None,
|
||||
requires_grad: bool = True,
|
||||
spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
|
||||
self._tensor_spec = copy(spec)
|
||||
spec: ColoTensorSpec = None) -> None:
|
||||
ColoTensor.__init__(self, data, spec)
|
||||
self._type = TensorType.MODEL
|
||||
self._graph_node = None
|
||||
|
||||
# a list contains modules sharing this ColoParameter with others.
|
||||
self._shared_param_modules = []
|
||||
|
||||
@@ -51,7 +49,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
|
||||
@staticmethod
|
||||
def from_torch_tensor(tensor: torch.Tensor,
|
||||
requires_grad: bool = True,
|
||||
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
|
||||
spec: ColoTensorSpec = None) -> 'ColoParameter':
|
||||
tensor = tensor.as_subclass(ColoParameter)
|
||||
tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
|
||||
return tensor
|
||||
@@ -82,7 +80,9 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
|
||||
else:
|
||||
with torch._C.DisableTorchFunction():
|
||||
data = self.data.clone()
|
||||
tensor = ColoParameter(data, self.requires_grad, spec=copy(self.tensor_spec))
|
||||
tensor = ColoParameter(data,
|
||||
self.requires_grad,
|
||||
spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec))
|
||||
memo[id(self)] = tensor
|
||||
return tensor
|
||||
|
||||
|
Reference in New Issue
Block a user