[Tensor] make ColoTensor more robust for getattr (#886)

* [Tensor] make ColoTensor more robust for getattr

* polish

* polish
This commit is contained in:
Jiarui Fang
2022-04-27 10:57:49 +08:00
committed by GitHub
parent 9bc5a77c31
commit 72cdc06875
4 changed files with 58 additions and 28 deletions

View File

@@ -152,18 +152,34 @@ class ColoTensor(object):
kwargs = {}
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
return cls._filter_outputs_with_colo(func(*args,**kwargs))
return cls._filter_outputs_with_colo(func(*args, **kwargs))
def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
def __add__(self, o) -> "ColoTensor":
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
def __truediv__(self, o) -> "ColoTensor":
return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
def __getattr__(self, name):
def replace_tensor_with_colo(func):
def execute_func(*args, **kwargs):
return self._filter_outputs_with_colo(func(*args,**kwargs))
# transform the ColoTensor args to torch Tensor.
args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
if kwargs is None:
kwargs = {}
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
return self._filter_outputs_with_colo(func(*args, **kwargs))
return execute_func
assert hasattr(self._torch_tensor, name), f"torch.Tensor has not attribute named as {name}. So is ColoTensor"
attr = getattr(self._torch_tensor, name)
if isinstance(attr, Callable):
return replace_tensor_with_colo(attr)
else:
@@ -171,9 +187,12 @@ class ColoTensor(object):
@classmethod
def _filter_outputs_with_colo(cls, outputs):
if outputs is None: # return None
if outputs is None: # return None
return None
elif type(outputs) is not tuple: # num of return val = 1
elif type(outputs) is not tuple: # num of return val = 1
return ColoTensor.init_from_torch_tensor(outputs) if type(outputs) is torch.Tensor else outputs
else: # num of return val > 1
return tuple([ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output for output in outputs])
else: # num of return val > 1
return tuple([
ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
for output in outputs
])