Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-14 05:33:23 +00:00)
[gemini] add GeminiMemoryManger (#832)
* refactor StatefulTensor, tensor utilities
* add unit test for GeminiMemoryManager
@@ -35,4 +35,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
     return zero_model, zero_optimizer
 
 
-__all__ = ['convert_to_zerov2', 'ShardedModelV2', 'ShardedOptimizerV2']
+__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2']
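This hunk only fixes a typo in the export list (`convert_to_zerov2` → `convert_to_zero_v2`). A minimal sketch, with a hypothetical module name, of why a stale name in `__all__` matters:

```python
# hypothetical module "zero_api.py", only to illustrate the __all__ typo above
def convert_to_zero_v2(model, optimizer, model_config, optimizer_config):
    ...

__all__ = ['convert_to_zerov2']    # typo: this name is not defined in the module

# a wildcard import then fails at import time:
#   from zero_api import *
#   AttributeError: module 'zero_api' has no attribute 'convert_to_zerov2'
```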
@@ -184,11 +184,12 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if param.grad is not None:
                 param.grad = param.grad.to(target_device)
 
-            param.colo_attr = ShardedParamV2(param, set_data_none=False)
+            param.colo_attr = ShardedParamV2(param, set_data_none=True)
 
             if self.shard_param:
                 self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-                param.data = param.colo_attr.data_payload    # set param.data to payload
+
+            param.data = param.colo_attr.data_payload    # set param.data to payload
 
             # mark whether the param is replicated
             param.colo_attr.is_replicated = self.is_replicated
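With `set_data_none=True`, the freshly created `ShardedParamV2` drops `param.data` immediately, and the context re-points it to the managed payload after (optional) sharding. A minimal sketch of that ownership hand-off, using a toy stand-in class rather than the real `ShardedParamV2`:

```python
import torch


class ToyShardedParam:
    """Toy stand-in for ShardedParamV2: it owns the parameter's storage."""

    def __init__(self, param: torch.nn.Parameter, set_data_none: bool = True):
        self._payload = param.data                 # take over the original storage
        if set_data_none:
            # leave only an empty placeholder on the nn.Parameter itself
            param.data = torch.empty(0, dtype=param.dtype, device=param.device)

    @property
    def data_payload(self) -> torch.Tensor:
        return self._payload


param = torch.nn.Parameter(torch.randn(4, 4))
colo_attr = ToyShardedParam(param, set_data_none=True)
# ... a shard strategy could now replace the payload with this rank's shard ...
param.data = colo_attr.data_payload                # set param.data to payload
print(param.shape)                                 # torch.Size([4, 4]) again
```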
@@ -31,9 +31,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
         for i in range(world_size):
             if i == rank:
                 buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
-                # Release payload here, to decrease peak memory usage
-                for t in tensor_list:
-                    t.reset_payload(None)
             else:
                 buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
         dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
@@ -44,6 +41,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
         for i, t in enumerate(tensor_list):
             gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
             gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
-            t.reset_payload(gathered_payload)
+            t.payload_reset(gathered_payload)
             t.is_sharded = False
             offset += tensor_numels[i]
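`BucketTensorShardStrategy` gathers every shard of a bucket with a single `all_gather` over flattened buffers and then slices the full tensors back out; the change above renames `reset_payload` to `payload_reset` and drops the early payload release. A rough, self-contained sketch of the bucketing pattern, assuming `torch.distributed` is already initialized (names are illustrative, not the real class):

```python
import torch
import torch.distributed as dist


def bucket_gather(local_shards, shard_numels, origin_shapes, group=None):
    """Gather a whole bucket of sharded tensors with one collective.

    local_shards[i] is this rank's (padded) shard of tensor i, with shard_numels[i]
    elements; origin_shapes[i] is the full shape to restore.
    """
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)

    # one flat send buffer per rank instead of one all_gather per tensor
    local_buffer = torch.cat([s.reshape(-1) for s in local_shards])
    buffer_list = [torch.empty_like(local_buffer) for _ in range(world_size)]
    buffer_list[rank] = local_buffer
    dist.all_gather(buffer_list, local_buffer, group=group)

    gathered, offset = [], 0
    for numel, shape in zip(shard_numels, origin_shapes):
        pieces = [buf[offset:offset + numel] for buf in buffer_list]   # shard i from every rank
        full = torch.cat(pieces)[:shape.numel()].view(shape)           # drop padding, restore shape
        gathered.append(full)
        offset += numel
    return gathered
```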
@@ -3,10 +3,10 @@ from typing import List, Optional
 import torch
 import torch.distributed as dist
 from colossalai.utils import get_current_device
-from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.shard_utils.commons import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline
 
 
 class TensorShardStrategy(BaseShardStrategy):
@@ -36,7 +36,7 @@ class TensorShardStrategy(BaseShardStrategy):
         assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
             f" but current cuda device is {get_current_device()}"
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
-        t.reset_payload(sharded_payload)
+        t.payload_reset(sharded_payload)
         t.is_sharded = True
 
     def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
@@ -53,6 +53,6 @@ class TensorShardStrategy(BaseShardStrategy):
 
         dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
         gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
-        t.reset_payload(gathered_payload)
+        t.payload_reset(gathered_payload)
         colo_model_data_tensor_move_inline(t, target_device)
         t.is_sharded = False
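`_shard_tensor` keeps only this rank's chunk of the flattened payload via `get_shard`, and `_gather_tensor` reverses it with an `all_gather` plus `narrow`/`view`. A small sketch of the even-chunking contract (my guess at what `get_shard` does, not the actual helper):

```python
import torch


def get_shard_sketch(tensor: torch.Tensor, rank: int, world_size: int):
    """Pad the flattened tensor so it splits evenly, then keep one chunk per rank."""
    flat = tensor.reshape(-1)
    shard_numel = (flat.numel() + world_size - 1) // world_size
    pad = shard_numel * world_size - flat.numel()
    if pad > 0:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    shard = flat[rank * shard_numel:(rank + 1) * shard_numel].clone()
    return shard, pad


t = torch.arange(10.0).view(2, 5)
shard0, _ = get_shard_sketch(t, rank=0, world_size=4)   # 3 elements; the last rank carries padding
# gathering reverses this: all_gather the shards, torch.narrow(buffer, 0, 0, t.numel()),
# then .view(t.shape) — the same sequence used in _gather_tensor above
```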
@@ -3,7 +3,7 @@ from typing import Any, Callable, List, Tuple
 import torch
 import torch.nn.functional as F
 from typing import Union
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
+from colossalai.gemini.stateful_tensor import StatefulTensor
 
 
 def get_gradient_predivide_factor(world_size: int) -> float:
@@ -17,11 +17,11 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.sharded_param.tensor_utils import colo_model_data_move_to_cpu
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
-from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
+from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
+from colossalai.gemini.stateful_tensor import TensorState
 from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
 from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
 
@@ -358,8 +358,11 @@ class ShardedModelV2(nn.Module):
                 assert param.colo_attr.saved_grad.is_null(
                 ), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
 
-                param.colo_attr.reset_grad_payload(grad.data)
-                param.colo_attr.reset_data_payload(grad.data)    # release the memory of param
+                param.colo_attr.grad_payload_reset(grad.data)
+                # release the memory of param
+                # we set a false None for parameter's payload
+                # so we can get paramter's device and dtype later in optimizer
+                param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype))
 
             if param.colo_attr.is_replicated:
                 param.colo_attr.sharded_data_tensor.is_sharded = True
@@ -368,7 +371,7 @@ class ShardedModelV2(nn.Module):
             fp32_grad = cast_tensor_to_fp32(grad)
 
             if param.colo_attr.saved_grad.is_null():
-                param.colo_attr.reset_grad_payload(fp32_grad)
+                param.colo_attr.grad_payload_reset(fp32_grad)
             else:
                 param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))
 
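When `reuse_fp16_shard` is enabled, the fp16 gradient takes over the parameter shard's storage; instead of handing `grad.data` to `reset_data_payload` as before, the parameter now keeps a 0-element placeholder so its dtype and device stay queryable in the optimizer. A tiny sketch of that bookkeeping with plain tensors (not the real `colo_attr` objects):

```python
import torch

grad = torch.randn(1024).half()                  # the reduced fp16 gradient shard

# grad_payload_reset(grad.data): the gradient buffer becomes the saved-grad payload
grad_payload = grad.data

# data_payload_reset(torch.empty(0, ...)): a "false None" placeholder that still
# carries the parameter's device and dtype for later use in the optimizer
data_payload = torch.empty(0, device=grad.device, dtype=grad.dtype)

assert data_payload.numel() == 0                 # the fp16 param memory is released
assert data_payload.dtype == torch.float16       # but dtype/device metadata survive
```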
@@ -12,15 +12,15 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.gemini.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
-                                                        colo_tensor_mem_usage)
+from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
+                                            colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
-from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
+from colossalai.gemini.stateful_tensor import (StatefulTensor, TensorState)
 from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
 
 
@@ -253,7 +253,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             for p in group['params']:
                 # p.colo_attr.sharded_data_tensor stores grad now
                 # we have to recover fp16 param
-                reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
+                reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0)
                 if recover_data and reuse_fp16_shard:
                     self._copy_master_param_to_param_fp16(p)
                 else:
@@ -332,12 +332,23 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
 
     def _copy_master_param_to_param_fp16(self, p):
         # flush gradient
-        p.colo_attr.saved_grad.set_null()
+        if p.colo_attr.sharded_data_tensor.payload_size == 0:
+            # here reuse_fp16_shard is True
+            # in order to use copy below, we should give sharded data tensor a payload
+            p.colo_attr.sharded_data_tensor.payload_relay(p.colo_attr.saved_grad)
+        else:
+            p.colo_attr.saved_grad.set_null()
+
+        p.data = self.master_params[p].payload
+
+        # we need to allocate new memory for keep_not_shard paramters
+        # in order to use copy, otherwise, the sizes of tensor is not compatible
+        if p.colo_attr.data_payload.numel() != p.data.numel():
+            p.colo_attr.data_payload_reset(
+                torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
 
         # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-        p.data = self.master_params[p].payload
-        p.colo_attr.reset_data_payload(
-            colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
+        p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
         p.colo_attr.set_data_none()
 
         if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
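On the optimizer side, the placeholder makes reuse detection trivial: instead of comparing `data_ptr()`s of the gradient and data tensors, a zero-sized payload now means the fp16 shard was reused, and `payload_relay`/`payload_copy` restore it from the gradient buffer and the fp32 master copy. A toy version of the new check, using plain tensors rather than the real `colo_attr` objects:

```python
import torch


def is_reused_fp16_shard(data_payload: torch.Tensor) -> bool:
    """Toy analogue of `p.colo_attr.sharded_data_tensor.payload_size == 0`."""
    return data_payload.numel() == 0


placeholder = torch.empty(0, dtype=torch.float16)     # what data_payload_reset left behind
intact_shard = torch.randn(8).half()                  # a shard whose memory was not reused

assert is_reused_fp16_shard(placeholder)
assert not is_reused_fp16_shard(intact_shard)

# recovery then copies the fp32 master weights back in place:
master_fp32 = torch.randn(8)
intact_shard.copy_(master_fp32.half())                # analogue of payload_copy(p.half().detach())
```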
@@ -1,11 +1,5 @@
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move, colo_model_data_tensor_move_inline,
-                                                        colo_model_data_move_to_cpu, colo_model_tensor_clone,
-                                                        colo_tensor_mem_usage)
-from colossalai.zero.sharded_param.tensorful_state import TensorState, StatefulTensor
 
 __all__ = [
-    'ShardedTensor', 'ShardedParamV2', 'colo_model_data_tensor_move', 'colo_model_data_tensor_move_inline',
-    'colo_model_data_move_to_cpu', 'colo_model_tensor_clone', 'colo_tensor_mem_usage', 'TensorState', 'StatefulTensor'
-]
+    'ShardedTensor', 'ShardedParamV2']
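After this change the `colossalai.zero.sharded_param` package only re-exports the two sharded wrappers; downstream code has to pick up the stateful-tensor types and tensor utilities from their new `colossalai.gemini` homes. A short migration sketch of the import paths (assuming ColossalAI at this commit):

```python
# before this commit
# from colossalai.zero.sharded_param import StatefulTensor, TensorState, colo_tensor_mem_usage

# after this commit
from colossalai.zero.sharded_param import ShardedTensor, ShardedParamV2
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_utils import colo_tensor_mem_usage
```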
@@ -1,8 +1,8 @@
 import torch
-from colossalai.zero.sharded_param import ShardedTensor
 from typing import Optional, Tuple
-from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage
-from .tensorful_state import StatefulTensor, TensorState
+from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.gemini.tensor_utils import colo_tensor_mem_usage
+from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
 from typing import List
 
 EMPTY_TENSOR_DICT = {}
@@ -50,6 +50,7 @@ class ShardedParamV2(object):
 
     @property
     def data_payload(self):
+        assert not self.sharded_data_tensor.is_null()
         return self.sharded_data_tensor.payload
 
     @property
@@ -61,15 +62,15 @@ class ShardedParamV2(object):
     def param_is_sharded(self):
         return self.sharded_data_tensor.is_sharded
 
-    def reset_data_payload(self, tensor: torch.Tensor):
+    def data_payload_reset(self, tensor: torch.Tensor):
         assert type(tensor) is torch.Tensor
         assert tensor.requires_grad is False
-        self.sharded_data_tensor.reset_payload(tensor)
+        self.sharded_data_tensor.payload_reset(tensor)
 
-    def reset_grad_payload(self, tensor: torch.Tensor):
+    def grad_payload_reset(self, tensor: torch.Tensor):
         assert type(tensor) is torch.Tensor
         assert tensor.requires_grad is False
-        self.saved_grad.reset_payload(tensor)
+        self.saved_grad.payload_reset(tensor)
 
     def get_memory_usage(self) -> Tuple[int, int]:
         """
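The `reset_*_payload` methods on `ShardedParamV2` are renamed to `*_payload_reset`, matching the `payload_reset` method of the relocated `StatefulTensor`. A short usage sketch of the renamed API (it assumes ColossalAI at this commit is importable; the tensors are purely illustrative):

```python
import torch
from colossalai.zero.sharded_param import ShardedParamV2

param = torch.nn.Parameter(torch.randn(16))
colo_attr = ShardedParamV2(param, set_data_none=True)

new_payload = torch.zeros(16)                 # plain tensor, requires_grad is False
colo_attr.data_payload_reset(new_payload)     # was: reset_data_payload
colo_attr.grad_payload_reset(new_payload)     # was: reset_grad_payload
print(colo_attr.data_payload.shape)           # torch.Size([16])
```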
@@ -1,6 +1,5 @@
 import torch
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
-from typing import Optional
+from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
 
 
 class ShardedTensor(StatefulTensor):
@@ -1,117 +0,0 @@
-import torch
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
-from typing import Union, Tuple
-
-
-def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
-    if issubclass(type(tensor), StatefulTensor):
-        t = tensor.payload
-    elif isinstance(tensor, torch.Tensor):
-        t = tensor
-    else:
-        return 0, 0
-
-    cuda_use, cpu_use = 0, 0
-
-    mem_use = t.storage().size() * t.element_size()
-    if t.device.type == 'cuda':
-        cuda_use += mem_use
-    elif t.device.type == 'cpu':
-        cpu_use += mem_use
-
-    return cuda_use, cpu_use
-
-
-def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
-                                                                                          torch.Tensor]) -> None:
-    """
-    A colossal API for model data tensor move.
-    The src and target tensors could be resident on both CPU and GPU.
-
-    NOTE() The source tensor payload will be removed after this function.
-
-    The function will record the communication volume between CPU and GPU.
-    Args:
-        t_src (Union[StatefulTensor, torch.Tensor]): source tensor
-        tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
-    """
-    if issubclass(type(src_t), StatefulTensor):
-        src_t_payload = src_t.payload
-    else:
-        src_t_payload = src_t.data
-    src_dev = src_t_payload.device
-    if issubclass(type(tgt_t), StatefulTensor):
-        tgt_t_payload = tgt_t.payload
-    else:
-        tgt_t_payload = tgt_t.data
-
-    tgt_t_payload.copy_(src_t_payload)
-
-    # remove payload of src_t
-    if issubclass(type(src_t), StatefulTensor):
-        src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype))
-    else:
-        src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
-
-
-def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
-                                                                                                     int]) -> None:
-    """
-    move a tensor to the target_device
-    Args:
-        t (Union[StatefulTensor, torch.Tensor]): the tensor be moved
-        target_device: a traget device, if type is int, it the index of cuda card.
-    """
-    if isinstance(t, torch.Tensor):
-        t_payload = t
-    elif issubclass(type(t), StatefulTensor):
-        t_payload = t.payload
-    else:
-        raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
-
-    if not isinstance(target_device, torch.device):
-        target_device = torch.device(f'cuda:{target_device}')
-
-    # deal with torch.device('cpu') and torch.device('cpu:0)
-    if t_payload.device.type == target_device.type:
-        return
-    t_payload.data = t_payload.data.to(target_device)
-
-
-def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
-    """colo_model_data_move_to_cpu
-
-    move a model data tensor from gpu to cpu
-
-    Args:
-        t (Union[StatefulTensor, torch.Tensor]): _description_
-    """
-
-    if issubclass(type(t), StatefulTensor):
-        t_payload = t.payload
-    elif isinstance(t, torch.Tensor):
-        t_payload = t
-    else:
-        raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
-
-    if t_payload.device.type == 'cpu':
-        return
-
-    # TODO() optimize the tensor moving with non-blocking
-    t_payload.data = t_payload.data.cpu()
-
-
-def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
-    """
-    Clone a model data tensor
-
-    Args:
-        t (Union[StatefulTensor, torch.Tensor]): a model data tensor
-        target_device (torch.device): the target device
-    Returns:
-        torch.Tensor: a cloned torch tensor
-    """
-    t_payload = t.payload if issubclass(type(t), StatefulTensor) else t
-
-    ret = t_payload.to(target_device)
-    return ret
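The whole `tensor_utils` module is removed from `colossalai.zero.sharded_param`; the same helpers are imported from `colossalai.gemini.tensor_utils` throughout the rest of this diff. A short usage sketch against the new location (assumes ColossalAI at this commit; the tensor is illustrative):

```python
import torch
from colossalai.gemini.tensor_utils import colo_tensor_mem_usage, colo_model_data_move_to_cpu

t = torch.randn(1024, 1024)                     # 4 MiB of fp32 on the CPU
cuda_bytes, cpu_bytes = colo_tensor_mem_usage(t)
print(cuda_bytes, cpu_bytes)                    # expected: 0, 4194304

colo_model_data_move_to_cpu(t)                  # no-op here: the payload is already on CPU
```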
@@ -1,80 +0,0 @@
-from enum import Enum
-from typing import Optional
-import torch
-
-
-class TensorState(Enum):
-    FREE = 0
-    HOLD = 1
-    HOLD_AFTER_FWD = 2
-    HOLD_AFTER_BWD = 3
-    COMPUTE = 4
-
-
-class StatefulTensor(object):
-    """A Structure stores a Torch Tensor and labeled states.
-    Inspired from the paper:
-    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
-
-    https://arxiv.org/abs/2108.05818
-    """
-
-    def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
-        self._state = state
-        self._payload = tensor
-        if self._state == TensorState.FREE:
-            assert self._payload is None, f"payload has to None if state is {self._state}"
-
-    def data_ptr(self):
-        if self._payload is None:
-            return None
-        return self._payload.data_ptr()
-
-    @property
-    def state(self) -> TensorState:
-        return self._state
-
-    def set_null(self) -> None:
-        self._state = TensorState.FREE
-        self._payload = None
-
-    def is_null(self) -> bool:
-        if self._state == TensorState.FREE:
-            assert self._payload is None
-            return True
-        return False
-
-    def trans_state(self, state: TensorState) -> None:
-        self._state = state
-        if state == TensorState.FREE:
-            self._payload = None
-
-    @property
-    def payload(self) -> Optional[torch.Tensor]:
-        return self._payload
-
-    def copy_payload(self, tensor) -> None:
-        self._payload.view(-1).copy_(tensor.view(-1))
-
-    def reset_payload(self, tensor) -> None:
-        del self._payload
-        self._payload = tensor
-        self.trans_state(TensorState.HOLD)
-
-    @property
-    def device(self) -> torch.device:
-        return self._payload.device
-
-    @property
-    def dtype(self) -> torch.dtype:
-        return self._payload.dtype
-
-    @property
-    def shape(self):
-        return self._payload.shape
-
-    def to(self, device: torch.device):
-        raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
-
-    def to_(self, device: torch.device):
-        raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
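`StatefulTensor` moves to `colossalai.gemini.stateful_tensor`, and the rest of the diff shows its payload methods being renamed (`reset_payload` → `payload_reset`, `copy_payload` → `payload_copy`, plus the `payload_size`/`payload_relay` calls used above). A small state-machine sketch against the relocated class (assumes ColossalAI at this commit; the behavior of the renamed methods is inferred from the call sites in this diff):

```python
import torch
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState

st = StatefulTensor(torch.zeros(4), TensorState.HOLD)
print(st.payload, st.device, st.dtype)

st.payload_reset(torch.ones(4))              # was reset_payload(): install a new payload
st.payload_copy(torch.full((4,), 2.0))       # was copy_payload(): in-place copy into it
print(st.payload)                            # tensor([2., 2., 2., 2.])

st.set_null()                                # drop the payload, state becomes FREE
assert st.is_null() and st.payload is None
```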
@@ -8,12 +8,11 @@ from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.sharded_param.tensorful_state import TensorState
 from colossalai.engine.ophooks import BaseOpHook
-
 from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
 from colossalai.gemini.memory_tracer import MemStatsCollector
 from typing import Any
+from colossalai.gemini.stateful_tensor import TensorState
 
 
 @OPHOOKS.register_module