mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[tensor] redirect .data.__get__ to a tensor instance (#1239)
@@ -2,13 +2,24 @@ from .op_wrapper import _COLOSSAL_OPS
 from .const import TensorType
 from copy import copy
 import torch
-from torch.overrides import get_default_nowrap_functions
+from functools import lru_cache

 from colossalai.tensor import ColoTensorSpec
 from colossalai.tensor import distspec, ProcessGroup
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-from typing import Optional
+from typing import Optional, Set, Callable
+
+
+@lru_cache(None)
+def _get_my_nowrap_functions() -> Set[Callable]:
+    Tensor = torch.Tensor
+    return {
+        Tensor._base.__get__,
+        Tensor.grad.__get__,
+        Tensor._grad.__get__,
+        Tensor.data.__get__,    # make .data returns torch.Tensor rather than ColoTensor
+    }


 def _convert_output(output, pg: ProcessGroup):
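
The new cached helper feeds ColoTensor's __torch_function__ dispatch (next hunk): results of intercepted calls are normally converted back into ColoTensor, but functions in this "nowrap" set return their raw result. Adding Tensor.data.__get__ to the set is what redirects .data to a plain torch.Tensor. Below is a minimal, self-contained sketch of that pattern, not the ColossalAI implementation; MyTensor and _nowrap are invented names.

import torch
from functools import lru_cache
from typing import Callable, Set


@lru_cache(None)
def _nowrap() -> Set[Callable]:
    # getters whose results should escape as plain torch.Tensor
    return {torch.Tensor.data.__get__, torch.Tensor.grad.__get__}


class MyTensor(torch.Tensor):

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            if func in _nowrap():
                return ret                          # hand back the raw result
            if isinstance(ret, torch.Tensor):
                ret = ret.as_subclass(cls)          # everything else is re-wrapped
            return ret


t = torch.randn(2, 2).as_subclass(MyTensor)
print(type(t + t))     # expected: MyTensor -- ordinary ops are re-wrapped
print(type(t.data))    # expected: torch.Tensor -- .data escapes, as in the change above
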
@@ -154,7 +165,7 @@ class ColoTensor(torch.Tensor):

         with torch._C.DisableTorchFunction():
             ret = func(*args, **kwargs)
-            if func in get_default_nowrap_functions():
+            if func in _get_my_nowrap_functions():
                 return ret
             else:
                 pg = _scan_for_pg_from_args(args, kwargs)
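
_get_my_nowrap_functions() is a superset of the get_default_nowrap_functions() set it replaces here: PyTorch's default set already leaves the _base, grad and _grad getters unwrapped, but (at least in the PyTorch versions this code targets) not data. A quick check of that assumption:

import torch
from torch.overrides import get_default_nowrap_functions

defaults = get_default_nowrap_functions()
print(torch.Tensor.grad.__get__ in defaults)    # expected True: already unwrapped by default
print(torch.Tensor.data.__get__ in defaults)    # expected False: the extra entry added above
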
@@ -170,8 +181,9 @@ class ColoTensor(torch.Tensor):
         Args:
             dist_spec (_DistSpec): the target dist. spec.
         """
+        assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
         with DistSpecManager.no_grad():
-            self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
+            self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
         self.dist_spec = dist_spec

     def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
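
In _convert_to_dist_spec the layout conversion now operates on the plain payload self.data rather than on the ColoTensor wrapper self, which matches .data no longer being re-wrapped, and it is guarded by an assertion that the tensor carries no grad_fn. The toy below (an invented ShardedTensor, no distributed setup) shows the underlying idiom: assign a new payload to .data so the Python object survives while its storage is swapped.

import torch


class ShardedTensor(torch.Tensor):
    """Toy stand-in for ColoTensor; resharding here just keeps one chunk."""

    def _reshard_(self, chunks: int) -> None:
        assert self.grad_fn is None, "refuse to swap the payload of a non-leaf tensor"
        payload = self.data                              # plain payload, not the wrapper
        self.data = payload.chunk(chunks)[0].clone()     # swap storage in place


t = torch.arange(8.0).as_subclass(ShardedTensor)
t._reshard_(2)
print(t.shape)    # torch.Size([4]): same Python object, new underlying payload
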
@@ -182,8 +194,7 @@ class ColoTensor(torch.Tensor):
         """to_replicate_
         an inline member function, converting dist spec of the tensor to REPLICATE
         """
-        self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, distspec.replicate(), self.process_group)
-        self.dist_spec = distspec.replicate()
+        self._convert_to_dist_spec(dist_spec=distspec.replicate())

     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
@@ -223,12 +234,8 @@ class ColoTensor(torch.Tensor):
         """
         if self.is_replicate():
             return super().view(*args)
-        # TODO(jiaruifang) check why this not work
-        # self.data = self.to_replicate()
-        self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, distspec.replicate(),
-                                                      self.process_group)
-        self.dist_spec = distspec.replicate()
-        return super().view(*args)
+        replicated_t = self.convert_to_dist_spec(dist_spec=distspec.replicate())
+        return replicated_t.view(*args)

     def size_global(self, args: Optional[int] = None):
         """override the torch buildin size()