From 595bedf7678d1d157e7a8e8230e266ddf21a26d3 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 22 Apr 2022 12:12:35 +0800
Subject: [PATCH] revert zero tensors back (#829)

---
 .../zero/sharded_param/sharded_tensor.py  |  4 +--
 .../zero/sharded_param/tensorful_state.py | 35 ++++++++-----------
 2 files changed, 17 insertions(+), 22 deletions(-)

diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py
index e1f48b318..fde273320 100644
--- a/colossalai/zero/sharded_param/sharded_tensor.py
+++ b/colossalai/zero/sharded_param/sharded_tensor.py
@@ -20,8 +20,8 @@ class ShardedTensor(StatefulTensor):
 
     @property
     def dtype(self) -> torch.dtype:
-        assert self.torch_tensor().dtype == self._origin_dtype
-        return self.torch_tensor().dtype
+        assert self._payload.dtype == self._origin_dtype
+        return self._payload.dtype
 
     @property
     def origin_numel(self) -> int:
diff --git a/colossalai/zero/sharded_param/tensorful_state.py b/colossalai/zero/sharded_param/tensorful_state.py
index d62f85b0e..a108963e5 100644
--- a/colossalai/zero/sharded_param/tensorful_state.py
+++ b/colossalai/zero/sharded_param/tensorful_state.py
@@ -1,7 +1,6 @@
 from enum import Enum
 from typing import Optional
 import torch
-from colossalai.tensor import ColoTensor
 
 
 class TensorState(Enum):
@@ -12,7 +11,7 @@ class TensorState(Enum):
     COMPUTE = 4
 
 
-class StatefulTensor(ColoTensor):
+class StatefulTensor(object):
     """A Structure stores a Torch Tensor and labeled states.
     Inspired from the paper:
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
@@ -21,20 +20,15 @@ class StatefulTensor(ColoTensor):
     """
 
     def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
-        if tensor is not None:
-            super().__init__(tensor.size(), dtype=tensor.dtype, requires_grad=tensor.requires_grad, \
-                pin_memory=tensor.pin_memory, torch_tensor=tensor)
-        else:
-            super().__init__(0)
-
         self._state = state
+        self._payload = tensor
         if self._state == TensorState.FREE:
-            assert self.torch_tensor().numel() == 0, f"payload has to None if state is {self._state}"
+            assert self._payload is None, f"payload has to None if state is {self._state}"
 
     def data_ptr(self):
-        if self.torch_tensor().numel() == 0:
+        if self._payload is None:
             return None
-        return self.torch_tensor().data_ptr()
+        return self._payload.data_ptr()
 
     @property
     def state(self) -> TensorState:
@@ -42,41 +36,42 @@ class StatefulTensor(ColoTensor):
 
     def set_null(self) -> None:
         self._state = TensorState.FREE
-        self.del_torch_tensor()
+        self._payload = None
 
     def is_null(self) -> bool:
         if self._state == TensorState.FREE:
-            assert self.torch_tensor().numel() == 0
+            assert self._payload is None
             return True
         return False
 
     def trans_state(self, state: TensorState) -> None:
         self._state = state
         if state == TensorState.FREE:
-            self.del_torch_tensor()
+            self._payload = None
 
     @property
     def payload(self) -> Optional[torch.Tensor]:
-        return self.torch_tensor()
+        return self._payload
 
     def copy_payload(self, tensor) -> None:
-        self.torch_tensor.view(-1).copy_(tensor.view(-1))
+        self._payload.view(-1).copy_(tensor.view(-1))
 
     def reset_payload(self, tensor) -> None:
-        self._torch_tensor = tensor
+        del self._payload
+        self._payload = tensor
         self.trans_state(TensorState.HOLD)
 
     @property
     def device(self) -> torch.device:
-        return self.torch_tensor().device
+        return self._payload.device
 
     @property
     def dtype(self) -> torch.dtype:
-        return self.torch_tensor().dtype
+        return self._payload.dtype
 
     @property
     def shape(self):
-        return self.torch_tensor().shape
+        return self._payload.shape
 
     def to(self, device: torch.device):
         raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
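
Not part of the patch itself: a short usage sketch of the plain-object StatefulTensor restored above. The module path and method behaviour are taken from the hunks in tensorful_state.py; the tensors and the flow below are illustrative assumptions, not code from the repository.

    import torch

    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState

    # A payload-carrying tensor starts in HOLD by default (see __init__ above).
    t = StatefulTensor(torch.zeros(4))
    assert t.state == TensorState.HOLD
    assert t.payload.shape == torch.Size([4])   # payload is the raw torch.Tensor again

    # reset_payload() swaps the payload and moves the state back to HOLD.
    t.reset_payload(torch.ones(4))

    # Transitioning to FREE drops the payload, so the tensor reports as null.
    t.trans_state(TensorState.FREE)
    assert t.is_null()
    assert t.data_ptr() is None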