[Tensor] overriding parameters() for Module using ColoTensor (#889)

This commit is contained in:
Jiarui Fang
2022-04-27 15:28:59 +08:00
committed by GitHub
parent daf59ff72e
commit 26c49639d8
3 changed files with 74 additions and 6 deletions

View File

@@ -165,7 +165,12 @@ class ColoTensor(object):
self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
def __add__(self, o) -> "ColoTensor":
    """Elementwise addition.

    Accepts either another ColoTensor (unwrapped via ``torch_tensor()``)
    or a plain ``torch.Tensor``; any other operand type is rejected.

    Raises:
        TypeError: if ``o`` is neither a ColoTensor nor a torch.Tensor.
    """
    # Dispatch on operand type; the stale unconditional return that
    # preceded this branch (leftover pre-change line in the diff) made
    # the typed dispatch unreachable, so it is removed here.
    if isinstance(o, ColoTensor):
        return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
    elif isinstance(o, torch.Tensor):
        return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o)
    else:
        raise TypeError(f'{type(o)} is not supported in ColoTensor __add__')
def __truediv__(self, o) -> "ColoTensor":
    """Elementwise true division.

    Consistent with ``__add__``: a ColoTensor divisor is unwrapped via
    ``torch_tensor()``; scalars and plain torch.Tensors divide directly,
    preserving the original behavior for those operands.
    """
    # Original passed `o` straight through, which would hand a raw
    # ColoTensor object to torch.Tensor.__truediv__ and fail.
    divisor = o.torch_tensor() if isinstance(o, ColoTensor) else o
    return ColoTensor.init_from_torch_tensor(self.torch_tensor() / divisor)