[tensor] ZeRO uses ColoTensor as the base class. (#828)

* [refactor] moving InsertPostInitMethodToModuleSubClasses to utils.

* [tensor] ZeRO uses ColoTensor as the base class.

* polish
This commit is contained in:
Jiarui Fang
2022-04-22 12:00:48 +08:00
committed by GitHub
parent 8e6fdb4f29
commit 294a6060d0
4 changed files with 36 additions and 26 deletions

View File

@@ -15,12 +15,12 @@ class ColoTensor(object):
return super(ColoTensor, cls).__new__(cls)
def __init__(
self,
*size: Tuple[int],
dtype=None,
requires_grad=False,
pin_memory=False,
torch_tensor=None,
self,
*size: Tuple[int],
dtype=None,
requires_grad=False,
pin_memory=False,
torch_tensor=torch.empty(0),
):
self._size = size
self._dtype = dtype
@@ -37,8 +37,13 @@ class ColoTensor(object):
torch_tensor=tensor)
return colo_t
# Drop the wrapped payload: reset the recorded size to (0,) and replace the
# stored tensor with a zero-element tensor of that size. Returns None.
# NOTE(review): torch.empty is called with only the size, so the placeholder
# is created with torch's defaults rather than self._dtype -- confirm this
# is intended.
def del_torch_tensor(self) -> None:
# The recorded size is overwritten here, so the previous size is not kept.
self._size = (0,)
# Zero-element placeholder built from the size assigned on the line above.
self._torch_tensor = torch.empty(self._size)
def torch_tensor(self) -> torch.Tensor:
if self._torch_tensor == None:
if self._torch_tensor == None or self._torch_tensor.numel() == 0:
print(self._size, type(self._size))
self._torch_tensor = torch.empty(*self._size,
dtype=self._dtype,
requires_grad=self._requires_grad,