mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-31 16:40:41 +00:00
[hotfix] the bug of numel() in ColoTensor (#845)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from numpy import product
|
||||
import torch
|
||||
from .op_wrapper import _COLOSSAL_OPS
|
||||
from typing import Tuple
|
||||
import numpy
|
||||
from .op_wrapper import _COLOSSAL_OPS
|
||||
|
||||
|
||||
class ColoTensor(object):
|
||||
@@ -31,7 +33,7 @@ class ColoTensor(object):
|
||||
self._torch_tensor = torch_tensor
|
||||
|
||||
def numel(self):
|
||||
return sum(self._size)
|
||||
return product(self._size)
|
||||
|
||||
@staticmethod
|
||||
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
|
||||
@@ -44,9 +46,17 @@ class ColoTensor(object):
|
||||
return colo_t
|
||||
|
||||
def del_torch_tensor(self, save_shape=False) -> None:
|
||||
if save_shape:
|
||||
"""
|
||||
delete the payload of the torch tensor.
|
||||
|
||||
Args:
|
||||
save_shape (bool, optional): if saving the shape of the torch_tensor.
|
||||
If saving the shape, the size of self._torch_tensor is inconsistent with self._size.
|
||||
Defaults to False.
|
||||
"""
|
||||
if not save_shape:
|
||||
self._size = (0,)
|
||||
self._torch_tensor = torch.empty((0,))
|
||||
self._torch_tensor = torch.empty((0,), device=self._device, dtype=self._dtype)
|
||||
|
||||
def torch_tensor(self) -> torch.Tensor:
|
||||
if self._torch_tensor.numel() == 0:
|
||||
|
Reference in New Issue
Block a user