mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[Tensor] make a simple net work with 1D row TP (#879)
This commit is contained in:
@@ -157,5 +157,21 @@ class ColoTensor(object):
|
||||
def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
    """Run autograd backward on the wrapped torch tensor.

    Args:
        gradient: optional gradient w.r.t. the tensor, forwarded unchanged.
        retain_graph: if True, keep the autograd graph after the pass.
    """
    payload = self._torch_tensor
    payload.backward(gradient=gradient, retain_graph=retain_graph)
|
||||
|
||||
## TODO(fjr): reduce redundancy of the following operator wrappers
|
||||
def __add__(self, o) -> "ColoTensor":
    """Elementwise addition, returning a new ColoTensor.

    Accepts another ColoTensor or anything torch's ``+`` supports (a plain
    tensor or a scalar). Previously a non-ColoTensor right-hand side raised
    AttributeError on ``o.torch_tensor()``; unwrapping conditionally keeps
    ColoTensor operands behaving exactly as before while also supporting
    raw operands, consistent with ``__truediv__``.
    """
    rhs = o.torch_tensor() if isinstance(o, ColoTensor) else o
    return ColoTensor.init_from_torch_tensor(self.torch_tensor() + rhs)
|
||||
|
||||
def __truediv__(self, o) -> "ColoTensor":
    """Elementwise true division, returning a new ColoTensor.

    Accepts another ColoTensor or anything torch's ``/`` supports (a plain
    tensor or a scalar). Previously a ColoTensor divisor was passed to the
    torch op unwrapped and failed; unwrapping conditionally keeps scalar
    and tensor divisors behaving exactly as before while also supporting
    ColoTensor operands, consistent with ``__add__``.
    """
    rhs = o.torch_tensor() if isinstance(o, ColoTensor) else o
    return ColoTensor.init_from_torch_tensor(self.torch_tensor() / rhs)
|
||||
|
||||
def view(self, *args: int) -> "ColoTensor":
    """Return a new ColoTensor wrapping a view of the payload with the given shape."""
    reshaped = self.torch_tensor().view(*args)
    return ColoTensor.init_from_torch_tensor(reshaped)
|
||||
|
||||
def permute(self, *args) -> "ColoTensor":
    """Return a new ColoTensor with the payload's dimensions reordered per *args."""
    permuted = self.torch_tensor().permute(*args)
    return ColoTensor.init_from_torch_tensor(permuted)
|
||||
|
||||
def transpose(self, *args) -> "ColoTensor":
    """Return a new ColoTensor with two of the payload's dimensions swapped."""
    swapped = self.torch_tensor().transpose(*args)
    return ColoTensor.init_from_torch_tensor(swapped)
|
||||
|
||||
def contiguous(self):
    """Return a new ColoTensor wrapping a contiguous-memory copy of the payload."""
    dense = self.torch_tensor().contiguous()
    return ColoTensor.init_from_torch_tensor(dense)
|
||||
|
Reference in New Issue
Block a user