[Tensor] make a simple net works with 1D row TP (#879)

This commit is contained in:
Jiarui Fang
2022-04-26 18:11:47 +08:00
committed by GitHub
parent c4d903e64a
commit 7f76517a85
2 changed files with 36 additions and 5 deletions

View File

@@ -157,5 +157,21 @@ class ColoTensor(object):
def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
## TODO(fjr) we reduce redundency of the following code
def __add__(self, o) -> "ColoTensor":
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
def __truediv__(self, o) -> "ColoTensor":
return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
def view(self, *args: int) -> "ColoTensor":
return ColoTensor.init_from_torch_tensor(self.torch_tensor().view(*args))
def permute(self, *args) -> "ColoTensor":
return ColoTensor.init_from_torch_tensor(self.torch_tensor().permute(*args))
def transpose(self, *args) -> "ColoTensor":
return ColoTensor.init_from_torch_tensor(self.torch_tensor().transpose(*args))
def contiguous(self):
return ColoTensor.init_from_torch_tensor(self.torch_tensor().contiguous())