From e5ea3fdeefc674ae16021855752fa767053fc0d9 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Sun, 24 Apr 2022 13:08:48 +0800
Subject: [PATCH] [gemini] add GeminiMemoryManager (#832)

* refactor StatefulTensor, tensor utilities
* add unit test for GeminiMemoryManager
---
 colossalai/gemini/gemini_context.py           |  45 ++++
 colossalai/gemini/stateful_tensor.py          | 204 ++++++++++++++++++
 colossalai/gemini/stateful_tensor_mgr.py      |   8 +-
 colossalai/gemini/tensor_placement_policy.py  |   4 +-
 .../sharded_param => gemini}/tensor_utils.py  |  75 +++----
 colossalai/nn/layer/moe/utils.py              |   2 +-
 colossalai/zero/__init__.py                   |   2 +-
 colossalai/zero/init_ctx/init_context.py      |   5 +-
 .../bucket_tensor_shard_strategy.py           |   5 +-
 .../zero/shard_utils/tensor_shard_strategy.py |   6 +-
 colossalai/zero/sharded_model/_utils.py       |   2 +-
 .../zero/sharded_model/sharded_model_v2.py    |  13 +-
 .../zero/sharded_optim/sharded_optim_v2.py    |  27 ++-
 colossalai/zero/sharded_param/__init__.py     |   8 +-
 .../zero/sharded_param/sharded_param.py       |  15 +-
 .../zero/sharded_param/sharded_tensor.py      |   3 +-
 .../zero/sharded_param/tensorful_state.py     |  80 -------
 colossalai/zero/utils/zero_hook.py            |   3 +-
 tests/test_gemini/test_gemini_manager.py      |  73 +++++++
 tests/test_gemini/test_stateful_tensor_mgr.py |   3 +-
 tests/test_utils/test_commons.py              |   2 +-
 tests/test_zero/test_shard_param.py           |   2 +-
 tests/test_zero/test_tensor_utils.py          |   7 +-
 23 files changed, 414 insertions(+), 180 deletions(-)
 create mode 100644 colossalai/gemini/gemini_context.py
 create mode 100644 colossalai/gemini/stateful_tensor.py
 rename colossalai/{zero/sharded_param => gemini}/tensor_utils.py (64%)
 delete mode 100644 colossalai/zero/sharded_param/tensorful_state.py
 create mode 100644 tests/test_gemini/test_gemini_manager.py

diff --git a/colossalai/gemini/gemini_context.py b/colossalai/gemini/gemini_context.py
new file mode 100644
index 000000000..aeade031f
--- /dev/null
+++ b/colossalai/gemini/gemini_context.py
@@ -0,0 +1,45 @@
+from enum import EnumMeta
+
+
+class GeminiMemoryManager(object):
+
+    def __init__(self, states_cls: EnumMeta):
+        super().__init__()
+        self.states_cls = states_cls
+        self._cnter = 0    # the counter of instances
+
+        self.total_mem = dict()
+        self.state_mem = dict()
+        self.state_mem['cpu'] = dict()
+        self.state_mem['cuda'] = dict()
+
+        self.reset()
+
+    @property
+    def total_number(self):
+        return self._cnter
+
+    def reset(self):
+        self._cnter = 0    # the counter of instances
+
+        self.total_mem['cpu'] = 0    # memory occupation of instances in cpu
+        self.total_mem['cuda'] = 0    # memory occupation of instances in cuda
+
+        # memory conditions for all states
+        for state in self.states_cls:
+            self.state_mem['cpu'][state] = 0
+            self.state_mem['cuda'][state] = 0
+
+    def register_new_instance(self):
+        self._cnter += 1
+
+    def print_info(self):
+        print(
+            f"Total number: {self.total_number}",
+            f"Total CPU memory occupation: {self.total_mem['cpu']}",
+            f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", sep='\n')
+
+        for state in self.states_cls:
+            print(
+                f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
+                f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", sep='\n')
diff --git a/colossalai/gemini/stateful_tensor.py b/colossalai/gemini/stateful_tensor.py
new file mode 100644
index 000000000..d6ab29cbe
--- /dev/null
+++ b/colossalai/gemini/stateful_tensor.py
@@ -0,0 +1,204 @@
+from enum import Enum
+from typing import Optional
+import torch
+from typing import Union
+
+from colossalai.gemini.gemini_context import GeminiMemoryManager
+
+
+def sizeof_tensor(tensor: torch.Tensor):
+    return tensor.numel() * tensor.element_size()
+
+
+class TensorState(Enum):
+    FREE = 0
+    HOLD = 1
+    HOLD_AFTER_FWD = 2
+    HOLD_AFTER_BWD = 3
+    COMPUTE = 4
+
+
+class StatefulTensor(object):
+    """A structure that stores a torch tensor together with a labeled state.
+    Inspired by the paper:
+    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
+
+    https://arxiv.org/abs/2108.05818
+    """
+    # Global Stateful Tensor Manager
+    GST_MGR = GeminiMemoryManager(TensorState)
+
+    def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
+        self._state = state
+        self._payload = None
+        self._payload_size = 0    # byte size of current payload
+
+        StatefulTensor.GST_MGR.register_new_instance()
+
+        if self._state == TensorState.FREE:
+            # when the state is free, payload should be None
+            assert maybe_tensor is None, f"payload has to be None if state is {self._state}"
+        else:
+            # otherwise, payload should not be None
+            assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
+            self._payload = maybe_tensor
+            self._payload_size = sizeof_tensor(maybe_tensor)
+            self.__trans_state_update(TensorState.FREE, state)
+
+    def data_ptr(self):
+        if self._payload is None:
+            return 0    # if a tensor has no storage, 0 should be returned
+        return self._payload.data_ptr()
+
+    def set_null(self) -> None:
+        # note that a free stateful tensor does not need to become null again
+        if self.state != TensorState.FREE:
+            self.__trans_state_update(self.state, TensorState.FREE)
+            self.__release()
+
+    def is_null(self) -> bool:
+        if self.state == TensorState.FREE:
+            # check sanity here
+            assert self.payload is None
+            return True
+        return False
+
+    def trans_state(self, state: TensorState) -> None:
+        if self.state == TensorState.FREE:
+            # a free stateful tensor can't change its state
+            assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
+            return
+
+        self.__trans_state_update(self.state, state)
+
+        if state == TensorState.FREE:
+            self.__release()
+        else:
+            self._state = state
+
+    def move_to(self, device: Union[torch.device, int]):
+        assert self.state is not TensorState.FREE, "Can't move a free stateful tensor"
+
+        if not isinstance(device, torch.device):
+            to_device = torch.device('cuda', device)
+        else:
+            to_device = device
+
+        from_device_type = self.device.type
+        if from_device_type == to_device.type:
+            # from device == to device
+            return
+
+        # update manager's information
+        self.__trans_device_update(from_device_type, to_device.type)
+        self.payload.data = self.payload.data.to(to_device)
+
+    def payload_copy(self, tensor) -> None:
+        self._payload.view(-1).copy_(tensor.view(-1))
+
+    def payload_reset(self, tensor) -> None:
+
+        assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"
+
+        if self.payload is not None:
+            # release the old payload
+            self.__trans_state_update(self.state, TensorState.FREE)
+        else:
+            # otherwise, set the state to HOLD for the new payload
+            self._state = TensorState.HOLD
+        del self._payload
+
+        self._payload = tensor
+        self._payload_size = sizeof_tensor(tensor)
+        # record the new payload
+        self.__trans_state_update(TensorState.FREE, self.state)
+
+    def payload_relay(self, rhs):
+        # relay the payload of rhs to the current stateful tensor
+        # null relay is not supported right now
+        assert not rhs.is_null()
+
+        # now this function only supports a stateful tensor that has a zero-length payload
+        # because it doesn't require updating the memory manager
+        # you can extend this function by yourself
+        assert self.payload_size == 0
+
+        self._payload = rhs.payload
+        self._payload_size = rhs.payload_size
+        self._state = TensorState.HOLD
+        self.__trans_state_update(rhs.state, TensorState.HOLD)
+
+        rhs.__release()
+
+    @property
+    def payload(self) -> Optional[torch.Tensor]:
+        return self._payload
+
+    @property
+    def payload_size(self) -> int:
+        return self._payload_size
+
+    @property
+    def state(self) -> TensorState:
+        return self._state
+
+    @property
+    def device(self) -> torch.device:
+        return self._payload.device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self._payload.dtype
+
+    @property
+    def shape(self):
+        return self._payload.shape
+
+    def to(self, device: torch.device):
+        raise RuntimeError("Use move_to(...) instead of calling .to() on StatefulTensor")
+
+    def to_(self, device: torch.device):
+        raise RuntimeError("Use move_to(...) instead of calling .to_() on StatefulTensor")
+
+    def __release(self):
+        # release the current payload
+        # shouldn't be visible to users
+        self._state = TensorState.FREE
+        self._payload = None
+        self._payload_size = 0
+
+    def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
+        """Update the global manager when changing the state of a tensor
+        """
+        manager = StatefulTensor.GST_MGR
+        size = self.payload_size
+        device_type = self.device.type
+
+        if from_state != TensorState.FREE:
+            manager.state_mem[device_type][from_state] -= size
+        else:
+            # when from_state is FREE, the tensor is new to the manager
+            # we should add its memory
+            manager.total_mem[device_type] += size
+
+        if to_state != TensorState.FREE:
+            manager.state_mem[device_type][to_state] += size
+        else:
+            # when to_state is FREE, the tensor will be deleted soon
+            # we should subtract its memory
+            manager.total_mem[device_type] -= size
+
+    def __trans_device_update(self, from_type: str, to_type: str):
+        """Update the global manager when changing the device of a tensor
+        """
+        manager = StatefulTensor.GST_MGR
+        size = self.payload_size
+        state = self.state
+
+        # update the aggregated information
+        manager.total_mem[from_type] -= size
+        manager.total_mem[to_type] += size
+
+        # update the information of each state
+        manager.state_mem[from_type][state] -= size
+        manager.state_mem[to_type][state] += size
diff --git a/colossalai/gemini/stateful_tensor_mgr.py b/colossalai/gemini/stateful_tensor_mgr.py
index 29a6fa064..db33047db 100644
--- a/colossalai/gemini/stateful_tensor_mgr.py
+++ b/colossalai/gemini/stateful_tensor_mgr.py
@@ -2,9 +2,8 @@ import functools
 import torch
 import types
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
-from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
 from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
 from typing import List
 from colossalai.logging import get_dist_logger
@@ -30,7 +29,8 @@ class StatefulTensorMgr(object):
 
         self._cpu_gpu_move_volume = 0
 
-    def register_stateful_param(self, param: ShardedParamV2) -> None:
+    def register_stateful_param(self, param) -> None:
+        from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
         assert isinstance(param, ShardedParamV2)
         for t in param.get_payload_tensors():
             assert isinstance(t, StatefulTensor)
diff --git a/colossalai/gemini/tensor_placement_policy.py b/colossalai/gemini/tensor_placement_policy.py
index f0b06ea2f..b35417c7b 100644
--- a/colossalai/gemini/tensor_placement_policy.py
+++ b/colossalai/gemini/tensor_placement_policy.py
@@ -4,8 +4,8 @@ import torch
 
 from colossalai.utils import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity
-from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
+from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from colossalai.gemini.stateful_tensor import StatefulTensor
 from colossalai.gemini.memory_tracer import MemStatsCollector
 from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Type
diff --git a/colossalai/zero/sharded_param/tensor_utils.py b/colossalai/gemini/tensor_utils.py
similarity index 64%
rename from colossalai/zero/sharded_param/tensor_utils.py
rename to colossalai/gemini/tensor_utils.py
index 4895becaf..f2d69046e 100644
--- a/colossalai/zero/sharded_param/tensor_utils.py
+++ b/colossalai/gemini/tensor_utils.py
@@ -1,10 +1,10 @@
 import torch
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
+from colossalai.gemini.stateful_tensor import StatefulTensor
 from typing import Union, Tuple
 
 
 def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
-    if issubclass(type(tensor), StatefulTensor):
+    if isinstance(tensor, StatefulTensor):
         t = tensor.payload
     elif isinstance(tensor, torch.Tensor):
         t = tensor
@@ -24,23 +24,24 @@ def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[
 
 def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
                                                                                           torch.Tensor]) -> None:
-    """ 
-    A colossal API for model data tensor move. 
+    """
+    A colossal API for model data tensor move.
     The src and target tensors could be resident on both CPU and GPU.
-    
+
     NOTE() The source tensor payload will be removed after this function.
-    
+
     The function will record the communication volume between CPU and GPU.
Args: - t_src (Union[StatefulTensor, torch.Tensor]): source tensor + src_t (Union[StatefulTensor, torch.Tensor]): source tensor tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor """ - if issubclass(type(src_t), StatefulTensor): + if isinstance(src_t, StatefulTensor): src_t_payload = src_t.payload else: src_t_payload = src_t.data src_dev = src_t_payload.device - if issubclass(type(tgt_t), StatefulTensor): + + if isinstance(tgt_t, StatefulTensor): tgt_t_payload = tgt_t.payload else: tgt_t_payload = tgt_t.data @@ -48,70 +49,56 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_ tgt_t_payload.copy_(src_t_payload) # remove payload of src_t - if issubclass(type(src_t), StatefulTensor): - src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)) + if isinstance(src_t, StatefulTensor): + src_t.set_null() else: - src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype) + src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype) def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device, int]) -> None: - """ + """ move a tensor to the target_device Args: t (Union[StatefulTensor, torch.Tensor]): the tensor be moved target_device: a traget device, if type is int, it the index of cuda card. """ - if isinstance(t, torch.Tensor): - t_payload = t - elif issubclass(type(t), StatefulTensor): - t_payload = t.payload - else: - raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}') - if not isinstance(target_device, torch.device): target_device = torch.device(f'cuda:{target_device}') - # deal with torch.device('cpu') and torch.device('cpu:0) - if t_payload.device.type == target_device.type: - return - t_payload.data = t_payload.data.to(target_device) + if isinstance(t, torch.Tensor): + t.data = t.data.to(target_device) + elif isinstance(t, StatefulTensor): + t.move_to(target_device) + else: + raise TypeError(f'colo_model_data_tensor_move_inline dose not accept type {type(t)}') def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None: - """colo_model_data_move_to_cpu - + """colo_model_data_move_to_cpu move a model data tensor from gpu to cpu - Args: t (Union[StatefulTensor, torch.Tensor]): _description_ """ - - if issubclass(type(t), StatefulTensor): - t_payload = t.payload - elif isinstance(t, torch.Tensor): - t_payload = t - else: - raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}') - - if t_payload.device.type == 'cpu': - return - # TODO() optimize the tensor moving with non-blocking - t_payload.data = t_payload.data.cpu() + if isinstance(t, torch.Tensor): + t.data = t.data.cpu() + elif isinstance(t, StatefulTensor): + t.move_to(torch.device('cpu')) + else: + raise TypeError(f'colo_model_data_move_to_cpu dose not accept type {type(t)}') def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor: """ Clone a model data tensor - Args: t (Union[StatefulTensor, torch.Tensor]): a model data tensor target_device (torch.device): the target device Returns: torch.Tensor: a cloned torch tensor """ - t_payload = t.payload if issubclass(type(t), StatefulTensor) else t - - ret = t_payload.to(target_device) - return ret + # TODO() rename this function + colo_model_data_tensor_move_inline(t, target_device) + t_payload = t.payload if isinstance(t, StatefulTensor) else t + return t_payload diff --git a/colossalai/nn/layer/moe/utils.py 
b/colossalai/nn/layer/moe/utils.py index fd985146a..936234741 100644 --- a/colossalai/nn/layer/moe/utils.py +++ b/colossalai/nn/layer/moe/utils.py @@ -8,7 +8,7 @@ from .experts import FFNExperts, TPExperts class ForceFP32Parameter(torch.nn.Parameter): def half(self, memory_format=None): - return self.data + return self.data.clone() class NormalNoiseGenerator: diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 1ea7c73e3..913a56801 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -35,4 +35,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model return zero_model, zero_optimizer -__all__ = ['convert_to_zerov2', 'ShardedModelV2', 'ShardedOptimizerV2'] +__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2'] diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py index b26e82919..22418f33d 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/init_ctx/init_context.py @@ -184,11 +184,12 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): if param.grad is not None: param.grad = param.grad.to(target_device) - param.colo_attr = ShardedParamV2(param, set_data_none=False) + param.colo_attr = ShardedParamV2(param, set_data_none=True) if self.shard_param: self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group) - param.data = param.colo_attr.data_payload # set param.data to payload + + param.data = param.colo_attr.data_payload # set param.data to payload # mark whether the param is replicated param.colo_attr.is_replicated = self.is_replicated diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py index 76fea3ff8..a7bd7cf53 100644 --- a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py +++ b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py @@ -31,9 +31,6 @@ class BucketTensorShardStrategy(TensorShardStrategy): for i in range(world_size): if i == rank: buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device())) - # Release payload here, to decrease peak memory usage - for t in tensor_list: - t.reset_payload(None) else: buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device())) dist.all_gather(buffer_list, buffer_list[rank], group=process_group) @@ -44,6 +41,6 @@ class BucketTensorShardStrategy(TensorShardStrategy): for i, t in enumerate(tensor_list): gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list] gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape) - t.reset_payload(gathered_payload) + t.payload_reset(gathered_payload) t.is_sharded = False offset += tensor_numels[i] diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py index 8857d7ae4..5bdd95400 100644 --- a/colossalai/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py @@ -3,10 +3,10 @@ from typing import List, Optional import torch import torch.distributed as dist from colossalai.utils import get_current_device -from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.shard_utils.commons import get_shard from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor +from 
colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline class TensorShardStrategy(BaseShardStrategy): @@ -36,7 +36,7 @@ class TensorShardStrategy(BaseShardStrategy): assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\ f" but current cuda device is {get_current_device()}" sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) - t.reset_payload(sharded_payload) + t.payload_reset(sharded_payload) t.is_sharded = True def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None): @@ -53,6 +53,6 @@ class TensorShardStrategy(BaseShardStrategy): dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False) gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape) - t.reset_payload(gathered_payload) + t.payload_reset(gathered_payload) colo_model_data_tensor_move_inline(t, target_device) t.is_sharded = False diff --git a/colossalai/zero/sharded_model/_utils.py b/colossalai/zero/sharded_model/_utils.py index eed0ff964..85a3ab73d 100644 --- a/colossalai/zero/sharded_model/_utils.py +++ b/colossalai/zero/sharded_model/_utils.py @@ -3,7 +3,7 @@ from typing import Any, Callable, List, Tuple import torch import torch.nn.functional as F from typing import Union -from colossalai.zero.sharded_param.tensorful_state import StatefulTensor +from colossalai.gemini.stateful_tensor import StatefulTensor def get_gradient_predivide_factor(world_size: int) -> float: diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index d4940b1fa..7dd1ec3e3 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -17,11 +17,11 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \ GLOBAL_MODEL_DATA_TRACER from colossalai.utils.memory import colo_device_memory_capacity from colossalai.zero.shard_utils import BaseShardStrategy -from colossalai.zero.sharded_param.tensor_utils import colo_model_data_move_to_cpu from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer -from colossalai.zero.sharded_param.tensorful_state import TensorState from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter +from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu +from colossalai.gemini.stateful_tensor import TensorState from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy @@ -358,8 +358,11 @@ class ShardedModelV2(nn.Module): assert param.colo_attr.saved_grad.is_null( ), 'Gradien accumulation is not supported when reuse_fp16_shard=True' - param.colo_attr.reset_grad_payload(grad.data) - param.colo_attr.reset_data_payload(grad.data) # release the memory of param + param.colo_attr.grad_payload_reset(grad.data) + # release the memory of param + # we set a false None for parameter's payload + # so we can get paramter's device and dtype later in optimizer + param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype)) if param.colo_attr.is_replicated: param.colo_attr.sharded_data_tensor.is_sharded = True @@ -368,7 +371,7 @@ class ShardedModelV2(nn.Module): fp32_grad = cast_tensor_to_fp32(grad) if param.colo_attr.saved_grad.is_null(): - param.colo_attr.reset_grad_payload(fp32_grad) + 
param.colo_attr.grad_payload_reset(fp32_grad) else: param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload)) diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py index 9f6ee7e03..2a4a69e41 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py @@ -12,15 +12,15 @@ from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.gemini.memory_tracer.model_data_memtracer import \ GLOBAL_MODEL_DATA_TRACER -from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone, - colo_tensor_mem_usage) +from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone, + colo_tensor_mem_usage) from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32 -from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState) from torch import Tensor from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter from torch.optim import Optimizer +from colossalai.gemini.stateful_tensor import (StatefulTensor, TensorState) from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy @@ -253,7 +253,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): for p in group['params']: # p.colo_attr.sharded_data_tensor stores grad now # we have to recover fp16 param - reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr() + reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0) if recover_data and reuse_fp16_shard: self._copy_master_param_to_param_fp16(p) else: @@ -332,12 +332,23 @@ class ShardedOptimizerV2(ColossalaiOptimizer): def _copy_master_param_to_param_fp16(self, p): # flush gradient - p.colo_attr.saved_grad.set_null() + if p.colo_attr.sharded_data_tensor.payload_size == 0: + # here reuse_fp16_shard is True + # in order to use copy below, we should give sharded data tensor a payload + p.colo_attr.sharded_data_tensor.payload_relay(p.colo_attr.saved_grad) + else: + p.colo_attr.saved_grad.set_null() + + p.data = self.master_params[p].payload + + # we need to allocate new memory for keep_not_shard paramters + # in order to use copy, otherwise, the sizes of tensor is not compatible + if p.colo_attr.data_payload.numel() != p.data.numel(): + p.colo_attr.data_payload_reset( + torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)) # TODO() optimize this line CPU (fp32) -> GPU (fp16) - p.data = self.master_params[p].payload - p.colo_attr.reset_data_payload( - colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device)) + p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach()) p.colo_attr.set_data_none() if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated: diff --git a/colossalai/zero/sharded_param/__init__.py b/colossalai/zero/sharded_param/__init__.py index f6f46db8e..98544c381 100644 --- a/colossalai/zero/sharded_param/__init__.py +++ b/colossalai/zero/sharded_param/__init__.py @@ -1,11 +1,5 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 -from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move, colo_model_data_tensor_move_inline, - 
colo_model_data_move_to_cpu, colo_model_tensor_clone, - colo_tensor_mem_usage) -from colossalai.zero.sharded_param.tensorful_state import TensorState, StatefulTensor __all__ = [ - 'ShardedTensor', 'ShardedParamV2', 'colo_model_data_tensor_move', 'colo_model_data_tensor_move_inline', - 'colo_model_data_move_to_cpu', 'colo_model_tensor_clone', 'colo_tensor_mem_usage', 'TensorState', 'StatefulTensor' -] + 'ShardedTensor', 'ShardedParamV2'] diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py index 72b88ec2f..db0f2d149 100644 --- a/colossalai/zero/sharded_param/sharded_param.py +++ b/colossalai/zero/sharded_param/sharded_param.py @@ -1,8 +1,8 @@ import torch -from colossalai.zero.sharded_param import ShardedTensor from typing import Optional, Tuple -from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage -from .tensorful_state import StatefulTensor, TensorState +from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor +from colossalai.gemini.tensor_utils import colo_tensor_mem_usage +from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState from typing import List EMPTY_TENSOR_DICT = {} @@ -50,6 +50,7 @@ class ShardedParamV2(object): @property def data_payload(self): + assert not self.sharded_data_tensor.is_null() return self.sharded_data_tensor.payload @property @@ -61,15 +62,15 @@ class ShardedParamV2(object): def param_is_sharded(self): return self.sharded_data_tensor.is_sharded - def reset_data_payload(self, tensor: torch.Tensor): + def data_payload_reset(self, tensor: torch.Tensor): assert type(tensor) is torch.Tensor assert tensor.requires_grad is False - self.sharded_data_tensor.reset_payload(tensor) + self.sharded_data_tensor.payload_reset(tensor) - def reset_grad_payload(self, tensor: torch.Tensor): + def grad_payload_reset(self, tensor: torch.Tensor): assert type(tensor) is torch.Tensor assert tensor.requires_grad is False - self.saved_grad.reset_payload(tensor) + self.saved_grad.payload_reset(tensor) def get_memory_usage(self) -> Tuple[int, int]: """ diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py index fde273320..77f4aec30 100644 --- a/colossalai/zero/sharded_param/sharded_tensor.py +++ b/colossalai/zero/sharded_param/sharded_tensor.py @@ -1,6 +1,5 @@ import torch -from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState -from typing import Optional +from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState class ShardedTensor(StatefulTensor): diff --git a/colossalai/zero/sharded_param/tensorful_state.py b/colossalai/zero/sharded_param/tensorful_state.py deleted file mode 100644 index a108963e5..000000000 --- a/colossalai/zero/sharded_param/tensorful_state.py +++ /dev/null @@ -1,80 +0,0 @@ -from enum import Enum -from typing import Optional -import torch - - -class TensorState(Enum): - FREE = 0 - HOLD = 1 - HOLD_AFTER_FWD = 2 - HOLD_AFTER_BWD = 3 - COMPUTE = 4 - - -class StatefulTensor(object): - """A Structure stores a Torch Tensor and labeled states. 
- Inspired from the paper: - PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management - - https://arxiv.org/abs/2108.05818 - """ - - def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None: - self._state = state - self._payload = tensor - if self._state == TensorState.FREE: - assert self._payload is None, f"payload has to None if state is {self._state}" - - def data_ptr(self): - if self._payload is None: - return None - return self._payload.data_ptr() - - @property - def state(self) -> TensorState: - return self._state - - def set_null(self) -> None: - self._state = TensorState.FREE - self._payload = None - - def is_null(self) -> bool: - if self._state == TensorState.FREE: - assert self._payload is None - return True - return False - - def trans_state(self, state: TensorState) -> None: - self._state = state - if state == TensorState.FREE: - self._payload = None - - @property - def payload(self) -> Optional[torch.Tensor]: - return self._payload - - def copy_payload(self, tensor) -> None: - self._payload.view(-1).copy_(tensor.view(-1)) - - def reset_payload(self, tensor) -> None: - del self._payload - self._payload = tensor - self.trans_state(TensorState.HOLD) - - @property - def device(self) -> torch.device: - return self._payload.device - - @property - def dtype(self) -> torch.dtype: - return self._payload.dtype - - @property - def shape(self): - return self._payload.shape - - def to(self, device: torch.device): - raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor") - - def to_(self, device: torch.device): - raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor") diff --git a/colossalai/zero/utils/zero_hook.py b/colossalai/zero/utils/zero_hook.py index b2fde9206..5aa9da158 100644 --- a/colossalai/zero/utils/zero_hook.py +++ b/colossalai/zero/utils/zero_hook.py @@ -8,12 +8,11 @@ from colossalai.registry import OPHOOKS from colossalai.utils import get_current_device from colossalai.zero.shard_utils import BaseShardStrategy -from colossalai.zero.sharded_param.tensorful_state import TensorState from colossalai.engine.ophooks import BaseOpHook from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr from colossalai.gemini.memory_tracer import MemStatsCollector -from typing import Any +from colossalai.gemini.stateful_tensor import TensorState @OPHOOKS.register_module diff --git a/tests/test_gemini/test_gemini_manager.py b/tests/test_gemini/test_gemini_manager.py new file mode 100644 index 000000000..0c138f101 --- /dev/null +++ b/tests/test_gemini/test_gemini_manager.py @@ -0,0 +1,73 @@ +import pytest +import torch + +from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor + + +@pytest.mark.dist +def test_gemini_manager(): + # reset the manager, in case that there exists memory information left + manager = StatefulTensor.GST_MGR + manager.reset() + + # occupation 8 + st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda')) + # occupation 60 + st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu')) + + # occupation 28 + t1 = torch.empty(7, device='cuda') + # occupation 12 + t2 = torch.empty(3, device='cpu') + st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD) + st4 = StatefulTensor(None, TensorState.FREE) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 60 + assert manager.total_mem['cuda'] == 36 + assert manager.state_mem['cpu'][TensorState.HOLD] == 60 + 
assert manager.state_mem['cuda'][TensorState.HOLD] == 8 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28 + + st4.payload_reset(t2) + st3.payload_reset(t2) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 84 + assert manager.total_mem['cuda'] == 8 + assert manager.state_mem['cpu'][TensorState.HOLD] == 72 + assert manager.state_mem['cuda'][TensorState.HOLD] == 8 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0 + + st1.move_to(torch.device('cpu')) + st2.move_to(torch.device('cpu')) + st3.move_to(torch.device('cuda', 0)) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 80 + assert manager.total_mem['cuda'] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD] == 80 + assert manager.state_mem['cuda'][TensorState.HOLD] == 0 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 + + st1.trans_state(TensorState.COMPUTE) + st2.trans_state(TensorState.COMPUTE) + st2.trans_state(TensorState.HOLD_AFTER_BWD) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 80 + assert manager.total_mem['cuda'] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD] == 12 + assert manager.state_mem['cuda'][TensorState.HOLD] == 0 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0 + assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8 + assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0 + + +if __name__ == '__main__': + test_gemini_manager() diff --git a/tests/test_gemini/test_stateful_tensor_mgr.py b/tests/test_gemini/test_stateful_tensor_mgr.py index d3a71da07..10412cafe 100644 --- a/tests/test_gemini/test_stateful_tensor_mgr.py +++ b/tests/test_gemini/test_stateful_tensor_mgr.py @@ -6,9 +6,8 @@ from colossalai.utils.cuda import get_current_device from colossalai.gemini.memory_tracer import MemStatsCollector from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER from colossalai.utils.memory import colo_set_process_memory_fraction -from colossalai.gemini import StatefulTensorMgr from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 -from colossalai.zero.sharded_param.tensorful_state import TensorState +from colossalai.gemini.stateful_tensor import TensorState from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use from torch.nn.parameter import Parameter diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py index a193d9d12..0ecb7446c 100644 --- a/tests/test_utils/test_commons.py +++ b/tests/test_utils/test_commons.py @@ -1,7 +1,7 @@ -from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use from colossalai.zero.sharded_param import ShardedTensor +from colossalai.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline import colossalai import torch diff --git a/tests/test_zero/test_shard_param.py b/tests/test_zero/test_shard_param.py index 4df5f3400..8db2b7e79 100644 --- a/tests/test_zero/test_shard_param.py +++ 
b/tests/test_zero/test_shard_param.py @@ -11,7 +11,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS from colossalai.zero.sharded_param import ShardedTensor from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from tests.test_zero.common import CONFIG, allclose -from colossalai.zero.sharded_param.tensorful_state import StatefulTensor +from colossalai.gemini.stateful_tensor import StatefulTensor @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) diff --git a/tests/test_zero/test_tensor_utils.py b/tests/test_zero/test_tensor_utils.py index 0b4201fe5..93f6c9878 100644 --- a/tests/test_zero/test_tensor_utils.py +++ b/tests/test_zero/test_tensor_utils.py @@ -2,9 +2,10 @@ import pytest import colossalai from colossalai.utils.cuda import get_current_device -from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move, - colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, - colo_model_tensor_clone) +from colossalai.gemini.tensor_utils import (colo_tensor_mem_usage, colo_model_data_tensor_move, + colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, + colo_model_tensor_clone) +from colossalai.gemini.stateful_tensor import StatefulTensor from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use
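
Usage sketch (illustrative only, not part of the patch): the snippet below condenses the new tests/test_gemini/test_gemini_manager.py to show how GeminiMemoryManager accounting is driven by StatefulTensor. It assumes a CUDA device is available; the tensor shape and byte counts are just examples.

import torch

from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState

manager = StatefulTensor.GST_MGR
manager.reset()

# a 2x2 fp16 tensor occupies 8 bytes and is registered under the default HOLD state
st = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cpu'))
assert manager.total_mem['cpu'] == 8
assert manager.state_mem['cpu'][TensorState.HOLD] == 8

# moving the payload shifts the bytes between the per-device totals
st.move_to(torch.device('cuda', 0))
assert manager.total_mem['cpu'] == 0 and manager.total_mem['cuda'] == 8

# a state transition moves the bytes between state buckets on the same device
st.trans_state(TensorState.COMPUTE)
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
assert manager.state_mem['cuda'][TensorState.COMPUTE] == 8

manager.print_info()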