mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[Tensor] overriding paramters() for Module using ColoTensor (#889)
This commit is contained in:
@@ -165,7 +165,12 @@ class ColoTensor(object):
|
||||
self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
|
||||
|
||||
def __add__(self, o) -> "ColoTensor":
|
||||
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
|
||||
if isinstance(o, ColoTensor):
|
||||
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
|
||||
elif isinstance(o, torch.Tensor):
|
||||
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o)
|
||||
else:
|
||||
raise TypeError(f'{type(o)} is not supported in ColoTensor __add__')
|
||||
|
||||
def __truediv__(self, o) -> "ColoTensor":
|
||||
return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
|
||||
|
Reference in New Issue
Block a user