[gemini] add GeminiMemoryManager (#832)

* refactor StatefulTensor, tensor utilities

* add unit test for GeminiMemoryManager
HELSON
2022-04-24 13:08:48 +08:00
committed by GitHub
parent 35ea6e1023
commit e5ea3fdeef
23 changed files with 414 additions and 180 deletions

View File

@@ -35,4 +35,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
return zero_model, zero_optimizer
__all__ = ['convert_to_zerov2', 'ShardedModelV2', 'ShardedOptimizerV2']
__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2']

View File

@@ -184,11 +184,12 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if param.grad is not None:
param.grad = param.grad.to(target_device)
param.colo_attr = ShardedParamV2(param, set_data_none=False)
param.colo_attr = ShardedParamV2(param, set_data_none=True)
if self.shard_param:
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
param.data = param.colo_attr.data_payload # set param.data to payload
param.data = param.colo_attr.data_payload # set param.data to payload
# mark whether the param is replicated
param.colo_attr.is_replicated = self.is_replicated

View File

@@ -31,9 +31,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
for i in range(world_size):
if i == rank:
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
# Release payload here, to decrease peak memory usage
for t in tensor_list:
t.reset_payload(None)
else:
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
@@ -44,6 +41,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
for i, t in enumerate(tensor_list):
gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
t.reset_payload(gathered_payload)
t.payload_reset(gathered_payload)
t.is_sharded = False
offset += tensor_numels[i]
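
As an aside, a minimal sketch of the bucketing idea above, using plain tensors and torch.cat as a stand-in for the flatten helper (no process group involved):

import torch

# flatten this rank's shards into one contiguous buffer so that a single
# all_gather can exchange them; torch.cat stands in for the flatten helper
tensor_list = [torch.randn(3), torch.randn(5)]
flat = torch.cat([t.reshape(-1) for t in tensor_list])
assert flat.numel() == sum(t.numel() for t in tensor_list)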

View File

@@ -3,10 +3,10 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.commons import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline
class TensorShardStrategy(BaseShardStrategy):
@@ -36,7 +36,7 @@ class TensorShardStrategy(BaseShardStrategy):
assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
f" but current cuda device is {get_current_device()}"
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
t.reset_payload(sharded_payload)
t.payload_reset(sharded_payload)
t.is_sharded = True
def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
@@ -53,6 +53,6 @@ class TensorShardStrategy(BaseShardStrategy):
dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
t.reset_payload(gathered_payload)
t.payload_reset(gathered_payload)
colo_model_data_tensor_move_inline(t, target_device)
t.is_sharded = False
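
For reference, a standalone sketch of the trim-and-reshape step performed after the all_gather above, with a plain tensor standing in for the gathered buffer:

import torch

# per-rank shards are padded to equal length and gathered into one flat buffer;
# the buffer is then trimmed to the original element count and reshaped
origin_shape, origin_numel = (3, 5), 15
buffer = torch.arange(16, dtype=torch.float32)    # e.g. 2 ranks x 8 padded elements
gathered_payload = torch.narrow(buffer, 0, 0, origin_numel).reshape(origin_shape)
assert gathered_payload.shape == origin_shape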

View File

@@ -3,7 +3,7 @@ from typing import Any, Callable, List, Tuple
import torch
import torch.nn.functional as F
from typing import Union
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from colossalai.gemini.stateful_tensor import StatefulTensor
def get_gradient_predivide_factor(world_size: int) -> float:

View File

@@ -17,11 +17,11 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_move_to_cpu
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param.tensorful_state import TensorState
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
@@ -358,8 +358,11 @@ class ShardedModelV2(nn.Module):
assert param.colo_attr.saved_grad.is_null(
), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
param.colo_attr.reset_grad_payload(grad.data)
param.colo_attr.reset_data_payload(grad.data) # release the memory of param
param.colo_attr.grad_payload_reset(grad.data)
# release the memory of param
# we set a fake, zero-element payload for the parameter instead of None
# so we can still query the parameter's device and dtype later in the optimizer
param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype))
if param.colo_attr.is_replicated:
param.colo_attr.sharded_data_tensor.is_sharded = True
@@ -368,7 +371,7 @@ class ShardedModelV2(nn.Module):
fp32_grad = cast_tensor_to_fp32(grad)
if param.colo_attr.saved_grad.is_null():
param.colo_attr.reset_grad_payload(fp32_grad)
param.colo_attr.grad_payload_reset(fp32_grad)
else:
param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))
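
The zero-element reset above can be shown in isolation; a minimal sketch, with release_but_keep_meta as a hypothetical helper name, not part of the library:

import torch

# hypothetical helper: free a tensor's storage while keeping a zero-element
# placeholder that still carries the original device and dtype
def release_but_keep_meta(t: torch.Tensor) -> torch.Tensor:
    return torch.empty(0, device=t.device, dtype=t.dtype)

grad = torch.ones(1024, dtype=torch.float16)
placeholder = release_but_keep_meta(grad)
assert placeholder.numel() == 0 and placeholder.dtype == torch.float16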

View File

@@ -12,15 +12,15 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.gemini.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.gemini.stateful_tensor import (StatefulTensor, TensorState)
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
@@ -253,7 +253,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
for p in group['params']:
# p.colo_attr.sharded_data_tensor stores grad now
# we have to recover fp16 param
reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0)
if recover_data and reuse_fp16_shard:
self._copy_master_param_to_param_fp16(p)
else:
@@ -332,12 +332,23 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def _copy_master_param_to_param_fp16(self, p):
# flush gradient
p.colo_attr.saved_grad.set_null()
if p.colo_attr.sharded_data_tensor.payload_size == 0:
# here reuse_fp16_shard is True
# in order to use the copy below, we have to give the sharded data tensor a payload
p.colo_attr.sharded_data_tensor.payload_relay(p.colo_attr.saved_grad)
else:
p.colo_attr.saved_grad.set_null()
p.data = self.master_params[p].payload
# we need to allocate new memory for keep_not_shard parameters
# in order to use copy, otherwise the tensor sizes are not compatible
if p.colo_attr.data_payload.numel() != p.data.numel():
p.colo_attr.data_payload_reset(
torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.data = self.master_params[p].payload
p.colo_attr.reset_data_payload(
colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
p.colo_attr.set_data_none()
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
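
A standalone sketch of the reallocate-then-copy step above, using plain tensors in place of the colo_attr payloads:

import torch

# the fp16 payload may have been released (numel == 0) or kept at a different
# size, so it is reallocated to match the master param before copy_ is used
master_fp32 = torch.randn(8, dtype=torch.float32)
fp16_payload = torch.empty(0, dtype=torch.float16)
if fp16_payload.numel() != master_fp32.numel():
    fp16_payload = torch.empty(master_fp32.shape, dtype=torch.float16)
fp16_payload.copy_(master_fp32.half())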

View File

@@ -1,11 +1,5 @@
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move, colo_model_data_tensor_move_inline,
colo_model_data_move_to_cpu, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_param.tensorful_state import TensorState, StatefulTensor
__all__ = [
'ShardedTensor', 'ShardedParamV2', 'colo_model_data_tensor_move', 'colo_model_data_tensor_move_inline',
'colo_model_data_move_to_cpu', 'colo_model_tensor_clone', 'colo_tensor_mem_usage', 'TensorState', 'StatefulTensor'
]
'ShardedTensor', 'ShardedParamV2']

View File

@@ -1,8 +1,8 @@
import torch
from colossalai.zero.sharded_param import ShardedTensor
from typing import Optional, Tuple
from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage
from .tensorful_state import StatefulTensor, TensorState
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.gemini.tensor_utils import colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from typing import List
EMPTY_TENSOR_DICT = {}
@@ -50,6 +50,7 @@ class ShardedParamV2(object):
@property
def data_payload(self):
assert not self.sharded_data_tensor.is_null()
return self.sharded_data_tensor.payload
@property
@@ -61,15 +62,15 @@ class ShardedParamV2(object):
def param_is_sharded(self):
return self.sharded_data_tensor.is_sharded
def reset_data_payload(self, tensor: torch.Tensor):
def data_payload_reset(self, tensor: torch.Tensor):
assert type(tensor) is torch.Tensor
assert tensor.requires_grad is False
self.sharded_data_tensor.reset_payload(tensor)
self.sharded_data_tensor.payload_reset(tensor)
def reset_grad_payload(self, tensor: torch.Tensor):
def grad_payload_reset(self, tensor: torch.Tensor):
assert type(tensor) is torch.Tensor
assert tensor.requires_grad is False
self.saved_grad.reset_payload(tensor)
self.saved_grad.payload_reset(tensor)
def get_memory_usage(self) -> Tuple[int, int]:
"""

View File

@@ -1,6 +1,5 @@
import torch
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from typing import Optional
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
class ShardedTensor(StatefulTensor):

View File

@@ -1,117 +0,0 @@
import torch
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from typing import Union, Tuple


def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
    if issubclass(type(tensor), StatefulTensor):
        t = tensor.payload
    elif isinstance(tensor, torch.Tensor):
        t = tensor
    else:
        return 0, 0

    cuda_use, cpu_use = 0, 0
    mem_use = t.storage().size() * t.element_size()
    if t.device.type == 'cuda':
        cuda_use += mem_use
    elif t.device.type == 'cpu':
        cpu_use += mem_use
    return cuda_use, cpu_use


def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor],
                                tgt_t: Union[StatefulTensor, torch.Tensor]) -> None:
    """
    A colossal API for model data tensor move.
    The src and target tensors could be resident on both CPU and GPU.
    NOTE() The source tensor payload will be removed after this function.
    The function will record the communication volume between CPU and GPU.
    Args:
        src_t (Union[StatefulTensor, torch.Tensor]): source tensor
        tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
    """
    if issubclass(type(src_t), StatefulTensor):
        src_t_payload = src_t.payload
    else:
        src_t_payload = src_t.data
    src_dev = src_t_payload.device
    if issubclass(type(tgt_t), StatefulTensor):
        tgt_t_payload = tgt_t.payload
    else:
        tgt_t_payload = tgt_t.data
    tgt_t_payload.copy_(src_t_payload)

    # remove payload of src_t
    if issubclass(type(src_t), StatefulTensor):
        src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype))
    else:
        src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)


def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor],
                                       target_device: Union[torch.device, int]) -> None:
    """
    Move a tensor to the target_device.
    Args:
        t (Union[StatefulTensor, torch.Tensor]): the tensor to be moved
        target_device: the target device; if it is an int, it is the index of a cuda device
    """
    if isinstance(t, torch.Tensor):
        t_payload = t
    elif issubclass(type(t), StatefulTensor):
        t_payload = t.payload
    else:
        raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}')

    if not isinstance(target_device, torch.device):
        target_device = torch.device(f'cuda:{target_device}')

    # deal with torch.device('cpu') and torch.device('cpu:0')
    if t_payload.device.type == target_device.type:
        return
    t_payload.data = t_payload.data.to(target_device)


def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
    """colo_model_data_move_to_cpu
    Move a model data tensor from gpu to cpu.
    Args:
        t (Union[StatefulTensor, torch.Tensor]): the model data tensor to be moved
    """
    if issubclass(type(t), StatefulTensor):
        t_payload = t.payload
    elif isinstance(t, torch.Tensor):
        t_payload = t
    else:
        raise TypeError(f'colo_model_data_move_to_cpu does not accept type {type(t)}')

    if t_payload.device.type == 'cpu':
        return

    # TODO() optimize the tensor moving with non-blocking
    t_payload.data = t_payload.data.cpu()


def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
    """
    Clone a model data tensor.
    Args:
        t (Union[StatefulTensor, torch.Tensor]): a model data tensor
        target_device (torch.device): the target device
    Returns:
        torch.Tensor: a cloned torch tensor
    """
    t_payload = t.payload if issubclass(type(t), StatefulTensor) else t
    ret = t_payload.to(target_device)
    return ret
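
For context, a minimal usage sketch of these helpers from their new colossalai.gemini.tensor_utils location, assuming colossalai is installed and a CUDA device is available:

import torch
from colossalai.gemini.tensor_utils import (colo_tensor_mem_usage, colo_model_data_tensor_move_inline,
                                            colo_model_tensor_clone)

t = torch.zeros(1024, 1024)                                # plain tensors are accepted, not only StatefulTensor
cuda_use, cpu_use = colo_tensor_mem_usage(t)               # all of the usage is on the CPU side at this point
colo_model_data_tensor_move_inline(t, 0)                   # move the payload in place to cuda:0
clone = colo_model_tensor_clone(t, torch.device('cpu'))    # clone back to CPU without touching t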

View File

@@ -1,80 +0,0 @@
from enum import Enum
from typing import Optional
import torch


class TensorState(Enum):
    FREE = 0
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4


class StatefulTensor(object):
    """A structure that stores a torch Tensor together with a labeled state.
    Inspired by the paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """

    def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        self._state = state
        self._payload = tensor
        if self._state == TensorState.FREE:
            assert self._payload is None, f"payload has to be None if state is {self._state}"

    def data_ptr(self):
        if self._payload is None:
            return None
        return self._payload.data_ptr()

    @property
    def state(self) -> TensorState:
        return self._state

    def set_null(self) -> None:
        self._state = TensorState.FREE
        self._payload = None

    def is_null(self) -> bool:
        if self._state == TensorState.FREE:
            assert self._payload is None
            return True
        return False

    def trans_state(self, state: TensorState) -> None:
        self._state = state
        if state == TensorState.FREE:
            self._payload = None

    @property
    def payload(self) -> Optional[torch.Tensor]:
        return self._payload

    def copy_payload(self, tensor) -> None:
        self._payload.view(-1).copy_(tensor.view(-1))

    def reset_payload(self, tensor) -> None:
        del self._payload
        self._payload = tensor
        self.trans_state(TensorState.HOLD)

    @property
    def device(self) -> torch.device:
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        return self._payload.dtype

    @property
    def shape(self):
        return self._payload.shape

    def to(self, device: torch.device):
        raise RuntimeError("Use colo_model_data_tensor_move instead of calling .to() on a ShardedTensor")

    def to_(self, device: torch.device):
        raise RuntimeError("Use colo_model_data_tensor_move instead of calling .to_() on a ShardedTensor")

View File

@@ -8,12 +8,11 @@ from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.engine.ophooks import BaseOpHook
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.memory_tracer import MemStatsCollector
from typing import Any
from colossalai.gemini.stateful_tensor import TensorState
@OPHOOKS.register_module