[hotfix] the bug of numel() in ColoTensor (#845)

This commit is contained in:
Jiarui Fang
2022-04-24 12:32:10 +08:00
committed by GitHub
parent c1e8d2001e
commit ea0a2ed25f
2 changed files with 21 additions and 6 deletions

View File

@@ -1,6 +1,8 @@
from numpy import product
import torch
from .op_wrapper import _COLOSSAL_OPS
from typing import Tuple
import numpy
from .op_wrapper import _COLOSSAL_OPS
class ColoTensor(object):
@@ -31,7 +33,7 @@ class ColoTensor(object):
self._torch_tensor = torch_tensor
def numel(self):
return sum(self._size)
return product(self._size)
@staticmethod
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
@@ -44,9 +46,17 @@ class ColoTensor(object):
return colo_t
def del_torch_tensor(self, save_shape=False) -> None:
if save_shape:
"""
delete the payload of the torch tensor.
Args:
save_shape (bool, optional): if saving the shape of the torch_tensor.
If saving the shape, the size of self._torch_tensor is inconsist with the self._size.
Defaults to False.
"""
if not save_shape:
self._size = (0,)
self._torch_tensor = torch.empty((0,))
self._torch_tensor = torch.empty((0,), device=self._device, dtype=self._dtype)
def torch_tensor(self) -> torch.Tensor:
if self._torch_tensor.numel() == 0: