Mirror of https://github.com/hpcaitech/ColossalAI.git
[gemini] collect cpu-gpu moving volume in each iteration (#813)
commit 3ddbd1bce1 (parent 61c20b44bc)
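This commit adds per-iteration accounting of CPU-GPU data movement to Gemini's stateful tensor manager. Every `evict_tensors` implementation now returns the number of bytes it moved off the GPU instead of `None`; `StatefulTensorMgr` accumulates that figure, plus the bytes it moves onto the GPU for COMPUTE-state tensors, in a new `_cpu_gpu_move_volume` counter exposed through a `cpu_gpu_move_volume` property; and `ZeroHook.post_iter` logs the total on rank 0 before resetting the counter for the next iteration.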
@@ -28,6 +28,8 @@ class StatefulTensorMgr(object):
         self._compute_list: List[StatefulTensor] = []
         self._compute_idx: int = -1
 
+        self._cpu_gpu_move_volume = 0
+
     def register_stateful_param(self, param: ShardedParamV2) -> None:
         assert isinstance(param, ShardedParamV2)
         for t in param.get_payload_tensors():
@@ -56,20 +58,26 @@ class StatefulTensorMgr(object):
                 cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
             else:
                 raise RuntimeError
-        self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
+        self._cpu_gpu_move_volume += self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
                                                     cuda_demand=cuda_demand,
                                                     warmup=self._warmup,
                                                     compute_list=self._compute_list,
                                                     compute_idx=self._compute_idx)
         # move COMPUTE tensors to CUDA
         for t in move_to_cuda_tensor_list:
             colo_model_data_tensor_move_inline(t, get_current_device())
+            self._cpu_gpu_move_volume += t.payload.numel() * t.payload.element_size()
+
+    @property
+    def cpu_gpu_move_volume(self):
+        return self._cpu_gpu_move_volume
 
     def reset(self):
         """This function must be called when each iteration finishes
         """
         self._warmup = False
         self._compute_idx = -1
+        self._cpu_gpu_move_volume = 0
 
     def _trans_state(self, trans_state_func, stateful_tensor, state):
         trans_state_func(state)
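The accounting above sizes each moved tensor with `numel() * element_size()`. Below is a minimal, self-contained sketch of that counter pattern using plain torch only; the real `StatefulTensorMgr` additionally routes the moves themselves through `colo_model_data_tensor_move_inline`, and `payload` stands in for `StatefulTensor.payload`.

```python
import torch

def tensor_nbytes(payload: torch.Tensor) -> int:
    # numel() * element_size() is the tensor's storage size in bytes;
    # e.g. a (1024, 1024) fp16 tensor counts as 1024 * 1024 * 2 bytes = 2 MiB.
    return payload.numel() * payload.element_size()

class MoveVolumeCounter:
    """Accumulates CPU<->GPU traffic for one iteration, mirroring the new
    StatefulTensorMgr._cpu_gpu_move_volume field (a sketch, not the source)."""

    def __init__(self) -> None:
        self._volume = 0

    def on_move(self, payload: torch.Tensor) -> None:
        # Called once per tensor moved in either direction.
        self._volume += tensor_nbytes(payload)

    @property
    def volume(self) -> int:
        return self._volume

    def reset(self) -> None:
        # Mirrors StatefulTensorMgr.reset(), called when each iteration finishes.
        self._volume = 0
```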
@@ -27,9 +27,12 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy):
     def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector)
 
-    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
+    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
+        volume = 0
         for t in hold_cuda_tensor_list:
             colo_model_data_tensor_move_inline(t, self.device)
+            volume += t.payload.numel() * t.payload.element_size()
+        return volume
 
 
 class CUDATensorPlacementPolicy(TensorPlacementPolicy):
@@ -38,8 +41,8 @@ class CUDATensorPlacementPolicy(TensorPlacementPolicy):
         assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
         super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)
 
-    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
-        pass
+    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
+        return 0
 
 
 class AutoTensorPlacementPolicy(TensorPlacementPolicy):
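With these two hunks the policies agree on a common contract: `CPUTensorPlacementPolicy` moves every HOLD-state tensor to the CPU and reports the full volume, while `CUDATensorPlacementPolicy` keeps everything resident on the GPU, so it evicts nothing and now returns 0 rather than falling through with `pass` (which returned `None` and would break the `+=` accumulation in `StatefulTensorMgr`).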
@@ -57,7 +60,24 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
                       warmup: bool = True,
                       compute_list: List[StatefulTensor] = [],
                       compute_idx: int = 0,
-                      **kwargs) -> None:
+                      **kwargs) -> int:
+        """
+        Evict tensors from the CUDA device.
+
+        Args:
+            hold_cuda_tensor_list (List[StatefulTensor]): the list of tensors in a HOLD-like state
+            cuda_demand (int, optional): the volume of data needed on the CUDA device. Defaults to 0.
+            warmup (bool, optional): a flag indicating whether we are still in the warmup phase. Defaults to True.
+            compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
+            compute_idx (int, optional): the index of the latest computed tensor in compute_list. Defaults to 0.
+
+        Raises:
+            RuntimeError: raised when not enough CUDA memory can be freed
+
+        Returns:
+            int: the volume of memory that is evicted
+        """
+        volume = 0
         cuda_capacity = colo_device_memory_capacity(get_current_device())
         used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage
         if warmup:
@@ -87,11 +107,14 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
                     break
                 freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
                 colo_model_data_tensor_move_inline(t, torch.device('cpu'))
+                volume += t.payload.numel() * t.payload.element_size()
             if freed_cuda_model_data < to_free_cuda_model_data:
                 raise RuntimeError(
                     f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
                 )
+
+        return volume
 
 
 class TensorPlacementPolicyFactory:
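Taken together, the two hunks above make the auto policy report how much it actually moved. A simplified sketch of that control flow follows, with the `colo_*` helpers replaced by plain torch operations; note that the real code measures freed memory with `colo_tensor_mem_usage(t)[0]` while the reported volume uses the payload's byte size, whereas this sketch conflates the two.

```python
from typing import List

import torch

def evict_until_freed(hold_cuda_tensors: List[torch.Tensor], to_free: int) -> int:
    """Sketch of AutoTensorPlacementPolicy.evict_tensors' accounting: move
    tensors to the CPU until `to_free` bytes are released, and return the
    volume that was actually moved."""
    freed = 0
    volume = 0
    for t in hold_cuda_tensors:
        if freed >= to_free:
            break
        nbytes = t.numel() * t.element_size()
        t.data = t.data.cpu()  # stand-in for colo_model_data_tensor_move_inline
        freed += nbytes
        volume += nbytes
    if freed < to_free:
        raise RuntimeError(f"Adjust layout failed! Need {to_free}, freed {freed}")
    return volume
```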
@@ -2,6 +2,7 @@ from typing import Optional
 
 import torch
 import torch.distributed as dist
+from colossalai.logging import get_dist_logger
 from colossalai.registry import OPHOOKS
 
 from colossalai.utils import get_current_device
@@ -27,6 +28,7 @@ class ZeroHook(BaseOpHook):
                  stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
                  process_group: Optional[dist.ProcessGroup] = None):
         super().__init__()
+        self.logger = get_dist_logger("ZeROHook")
         self.shard_strategy = shard_strategy
         self.process_group = process_group
 
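`get_dist_logger` returns ColossalAI's distributed logger; creating it once in the constructor lets `post_iter` pass `ranks=[0]` below, so the per-iteration message is printed by rank 0 only rather than once per process.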
@@ -112,4 +114,6 @@ class ZeroHook(BaseOpHook):
 
     def post_iter(self):
         if self._stateful_tensor_mgr:
+            self.logger.info(
+                f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB", ranks=[0])
             self._stateful_tensor_mgr.reset()
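With the hook in place, rank 0 logs a line of the form `CPU-GPU data moving this iteration X GB` at the end of every iteration; the counter divides bytes by 1e9, so the figure is in decimal gigabytes. If you drive the manager directly, the same figure can be read before the reset. A hypothetical loop is sketched below; `mgr`, `loader`, and `run_iteration` are placeholders, not ColossalAI APIs.

```python
for step, batch in enumerate(loader):
    run_iteration(batch)  # forward/backward/optimizer step under the ZeRO hook
    gb = mgr.cpu_gpu_move_volume / 1e9  # property added in this commit
    print(f"step {step}: {gb:.3f} GB moved between CPU and GPU")
    mgr.reset()  # "must be called when each iteration finishes"
```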