[zero] refactor model data tracing (#537)
https://github.com/hpcaitech/ColossalAI.git
@@ -3,6 +3,7 @@ from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.utils import get_current_device
 import torch
+from typing import Tuple


 class SamplingCounter:
@@ -40,6 +41,20 @@ class MemStatsCollector:
         self._start_flag = False

+    @property
+    def overall_cuda(self):
+        return self._overall_cuda
+
+    @property
+    def model_data_cuda(self):
+        return self._model_data_cuda
+
+    @property
+    def non_model_data_cuda(self):
+        """Non-model data stats
+        """
+        return [(v1 - v2) for v1, v2 in zip(self._overall_cuda, self._model_data_cuda)]
+
     def start_collection(self):
         self._start_flag = True
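The new non_model_data_cuda property derives non-model memory as the elementwise difference between the overall samples and the model-data samples. A minimal sketch with hypothetical numbers, assuming a no-arg constructor and two samples already recorded:

    collector = MemStatsCollector()                        # assumed constructor
    collector._model_data_cuda = [2_000_000, 3_000_000]    # model-data bytes per sample (hypothetical)
    collector._overall_cuda = [2_500_000, 4_100_000]       # overall cuda bytes per sample (hypothetical)
    assert collector.non_model_data_cuda == [500_000, 1_100_000]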
@@ -58,7 +73,7 @@ class MemStatsCollector:
         self._overall_cuda.append(colo_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
         self._sampling_cnter.advance()

-    def fetch_memstats(self) -> (int, int):
+    def fetch_memstats(self) -> Tuple[int, int]:
         """
         Returns the cuda usage of model data and the overall cuda usage.
         """
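The annotation change in fetch_memstats is a real fix: `-> (int, int)` evaluates to a tuple object rather than a type, and static checkers such as mypy reject it. For comparison (function names are illustrative):

    from typing import Tuple

    def bad() -> (int, int):        # accepted at runtime, but this is a tuple, not a type
        return 1, 2

    def good() -> Tuple[int, int]:  # the proper generic type for a pair of ints
        return 1, 2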
@@ -1,7 +1,8 @@
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 import torch
-from typing import Union
+from typing import Union, Tuple, Optional
+from colossalai.logging import DistributedLogger


 def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
@@ -12,60 +13,78 @@ def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
     return target.numel() * target.element_size()


+def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
+    """
+    Trace the model memory usage.
+    Args:
+        model (torch.nn.Module): a torch model
+
+    Returns:
+        Tuple[int, int]: cuda memory usage in Byte, cpu memory usage in Byte
+    """
+
+    def _get_tensor_mem_use(t: Optional[torch.Tensor]):
+        if t is None:
+            return 0, 0
+        assert isinstance(t, torch.Tensor)
+        _cpu_mem_usage, _cuda_mem_usage = 0, 0
+        if t.device.type == 'cpu':
+            _cpu_mem_usage += t.numel() * t.element_size()
+        elif t.device.type == 'cuda':
+            _cuda_mem_usage += t.numel() * t.element_size()
+        return _cuda_mem_usage, _cpu_mem_usage
+
+    cuda_mem_usage = 0
+    cpu_mem_usage = 0
+    for param in model.parameters():
+        if hasattr(param, 'col_attr'):
+            param_cuda, param_cpu = param.col_attr.get_memory_usage()
+            cuda_mem_usage += param_cuda
+            cpu_mem_usage += param_cpu
+        else:
+            t_cuda, t_cpu = _get_tensor_mem_use(param.data)
+            cuda_mem_usage += t_cuda
+            cpu_mem_usage += t_cpu
+            t_cuda, t_cpu = _get_tensor_mem_use(param.grad)
+            cuda_mem_usage += t_cuda
+            cpu_mem_usage += t_cpu
+
+    return cuda_mem_usage, cpu_mem_usage
+
+
 class ModelDataTracer(metaclass=SingletonMeta):
     """
     A tracer singleton to trace model data usage during runtime.
-    The tracer is designed to trace the memory layout change during model-data tensors allocation, releasing, and moving.
-    To achieve this goal, the developers have to call `ModelDataTracer` in the corresponding code explicitly.
-    NOTE() now the class only trace cuda memory usage
+    You have to register a model on the singleton first.
     """

     def __init__(self) -> None:
-        self._cuda_usage = 0
-        self._cpu_usage = 0
         self._start_flag = False
+        self._logger = DistributedLogger("ModelDataTracer")
+        self._model = None

     def start(self) -> None:
         self._start_flag = True

+    def _get_mem_usage(self) -> Tuple[int, int]:
+        """
+        Get the memory usage of the registered model.
+        Returns:
+            Tuple[int, int]: cuda, cpu mem usage
+        """
+        if self._model is None:
+            self._logger.warning("The Global ModelDataTracer is in use, but no model is registered on it.")
+            return 0, 0
+        return col_model_data_mem_usage(self._model)
+
     def close(self) -> None:
         self._start_flag = False

-    def add_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
-        if not self._start_flag:
-            return
-        t_payload = t.payload if isinstance(t, ShardedTensor) else t
-        mem_use = _col_tensor_mem_usage(t_payload)
-        if t_payload.device.type == 'cuda':
-            self._cuda_usage += mem_use
-        elif t_payload.device.type == 'cpu':
-            self._cpu_usage += mem_use
-        else:
-            raise TypeError
-
-    def delete_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
-        if not self._start_flag:
-            return
-        t_payload = t.payload if isinstance(t, ShardedTensor) else t
-        mem_use = _col_tensor_mem_usage(t_payload)
-        if t_payload.device.type == 'cuda':
-            self._cuda_usage -= mem_use
-        elif t_payload.device.type == 'cpu':
-            self._cpu_usage -= mem_use
-        else:
-            raise TypeError
-
-    def clear(self) -> None:
-        self._cuda_usage = 0
-        self._cpu_usage = 0
+    def register_model(self, model) -> None:
+        self._model = model

     @property
     def cpu_usage(self):
-        return self._cpu_usage
+        _, cpu_usage = self._get_mem_usage()
+        return cpu_usage

     @property
     def cuda_usage(self):
-        return self._cuda_usage
+        cuda_usage, _ = self._get_mem_usage()
+        return cuda_usage


 GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
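Taken together, the refactor trades incremental bookkeeping (add_tensor/delete_tensor on every allocation and move) for on-demand traversal of a registered model: cuda_usage and cpu_usage now walk model.parameters() each time they are read. A minimal usage sketch, assuming a hypothetical two-layer model:

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2)).cuda()
    GLOBAL_MODEL_DATA_TRACER.register_model(model)
    print(GLOBAL_MODEL_DATA_TRACER.cuda_usage)  # param (and grad) bytes resident on cuda
    print(GLOBAL_MODEL_DATA_TRACER.cpu_usage)   # param (and grad) bytes resident on cpu

This makes readings self-consistent at the cost of an O(#parameters) walk per query, and it is why the tracer calls disappear from the tensor-move helpers in the hunks below.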
@@ -1,5 +1,4 @@
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 import torch
@@ -14,7 +13,6 @@ def test_mem_collector():
     collector.sample_memstats()

     m_a = torch.randn(10).cuda()
-    GLOBAL_MODEL_DATA_TRACER.add_tensor(m_a)
     b = torch.randn(10).cuda()

     # sampling at time 1
@@ -35,8 +33,7 @@ def test_mem_collector():
     cuda_use, overall_use = collector.fetch_memstats()
     print(cuda_use, overall_use)

-    print(collector._model_data_cuda)
-    print(collector._overall_cuda)
+    print(collector.overall_cuda)


 if __name__ == '__main__':
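For context, the collection workflow the test exercises looks roughly like this (a sketch; the no-arg constructor and tensor sizes are assumptions):

    collector = MemStatsCollector()
    collector.start_collection()
    for _ in range(3):                    # stand-in for a few training steps
        a = torch.randn(1024).cuda()      # non-model data allocated during the step
        collector.sample_memstats()       # appends model-data and overall cuda usage
    cuda_use, overall_use = collector.fetch_memstats()
    print(collector.non_model_data_cuda)  # per-sample overall minus model data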
@@ -1,7 +1,6 @@
 import torch
 from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

 from typing import Union
@@ -52,9 +51,7 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
     tgt_t_payload = tgt_t.data
     tgt_dev = tgt_t_payload.device

-    GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
     tgt_t_payload.copy_(src_t_payload)
-    GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)

     # remove payload of src_t
     if isinstance(src_t, ShardedTensor):
@@ -84,11 +81,7 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
     # deal with torch.device('cpu') and torch.device('cpu:0')
     if t_payload.device.type == target_device.type:
         return
-    if use_tracer:
-        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
     t_payload.data = t_payload.data.to(target_device)
-    if use_tracer:
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)


 def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
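The comment about torch.device('cpu') versus torch.device('cpu:0') is why the guard compares .type rather than the device objects themselves: the two devices compare unequal, yet share the same device type. A quick check:

    import torch

    assert torch.device('cpu') != torch.device('cpu:0')             # different objects: index None vs 0
    assert torch.device('cpu').type == torch.device('cpu:0').type   # both 'cpu'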
@@ -111,9 +104,7 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
         return

     # TODO() optimize the tensor moving with non-blocking
-    GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
     t_payload.data = t_payload.data.cpu()
-    GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)


 def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
@@ -129,5 +120,4 @@ def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device
     t_payload = t.payload if isinstance(t, ShardedTensor) else t

     ret = t_payload.to(target_device)
-    GLOBAL_MODEL_DATA_TRACER.add_tensor(ret)
     return ret