mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 01:06:00 +00:00
[tensor] ZeRO use ColoTensor as the base class. (#828)
* [refactor] moving InsertPostInitMethodToModuleSubClasses to utils. * [tensor] ZeRO use ColoTensor as the base class. * polish
This commit is contained in:
@@ -15,12 +15,12 @@ class ColoTensor(object):
|
||||
return super(ColoTensor, cls).__new__(cls)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*size: Tuple[int],
|
||||
dtype=None,
|
||||
requires_grad=False,
|
||||
pin_memory=False,
|
||||
torch_tensor=None,
|
||||
self,
|
||||
*size: Tuple[int],
|
||||
dtype=None,
|
||||
requires_grad=False,
|
||||
pin_memory=False,
|
||||
torch_tensor=torch.empty(0),
|
||||
):
|
||||
self._size = size
|
||||
self._dtype = dtype
|
||||
@@ -37,8 +37,13 @@ class ColoTensor(object):
|
||||
torch_tensor=tensor)
|
||||
return colo_t
|
||||
|
||||
def del_torch_tensor(self) -> None:
|
||||
self._size = (0,)
|
||||
self._torch_tensor = torch.empty(self._size)
|
||||
|
||||
def torch_tensor(self) -> torch.Tensor:
|
||||
if self._torch_tensor == None:
|
||||
if self._torch_tensor == None or self._torch_tensor.numel() == 0:
|
||||
print(self._size, type(self._size))
|
||||
self._torch_tensor = torch.empty(*self._size,
|
||||
dtype=self._dtype,
|
||||
requires_grad=self._requires_grad,
|
||||
|
Reference in New Issue
Block a user