[pipelinable]use ColoTensor to replace dummy tensor. (#853)

This commit is contained in:
YuliangLiu0306
2022-04-24 18:31:22 +08:00
committed by GitHub
parent bcc8655021
commit c6930d8ddf
2 changed files with 26 additions and 2 deletions

View File

@@ -53,6 +53,22 @@ class ColoTensor(object):
def size(self):
return self._size
@property
def shape(self):
return torch.Size(self._size)
def size(self, dim=None):
if dim is None:
return self.shape
return self._size[dim]
def dim(self):
return len(self._size)
def normal_(self, mean=0., std=1.):
torch_tensor = self.torch_tensor()
return torch_tensor.normal_(mean=mean, std=std)
def numel(self):
return product(self._size)