[zero] refactor model data tracing (#537)

Jiarui Fang
2022-03-28 16:38:18 +08:00
committed by GitHub
parent a590ed0ba3
commit 705f56107c
13 changed files with 98 additions and 132 deletions

View File

@@ -3,6 +3,7 @@ from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
from colossalai.utils import get_current_device
import torch
from typing import Tuple
class SamplingCounter:
@@ -40,6 +41,20 @@ class MemStatsCollector:
self._start_flag = False
@property
def overall_cuda(self):
return self._overall_cuda
@property
def model_data_cuda(self):
return self._model_data_cuda
@property
def non_model_data_cuda(self):
"""Non model data stats
"""
return [(v1 - v2) for v1, v2 in zip(self._overall_cuda, self._model_data_cuda)]
def start_collection(self):
self._start_flag = True
@@ -58,7 +73,7 @@ class MemStatsCollector:
self._overall_cuda.append(colo_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
self._sampling_cnter.advance()
def fetch_memstats(self) -> (int, int):
def fetch_memstats(self) -> Tuple[int, int]:
"""
Returns CUDA usage of model data and overall CUDA usage.
"""

View File

@@ -1,7 +1,8 @@
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
import torch
from typing import Union
from typing import Union, Tuple, Optional
from colossalai.logging import DistributedLogger
def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
@@ -12,60 +13,78 @@ def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
return target.numel() * target.element_size()
def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
"""
Trace the model memory usage.
Args:
model (torch.nn.Module): a torch model
Returns:
Tuple[int, int]: CUDA memory usage in bytes, CPU memory usage in bytes
"""
def _get_tensor_mem_use(t: Optional[torch.Tensor]) -> Tuple[int, int]:
# a None tensor (e.g. an unset gradient) contributes nothing
if t is None:
return 0, 0
assert isinstance(t, torch.Tensor)
_cpu_mem_usage, _cuda_mem_usage = 0, 0
if t.device.type == 'cpu':
_cpu_mem_usage += t.numel() * t.element_size()
elif t.device.type == 'cuda':
_cuda_mem_usage += t.numel() * t.element_size()
return _cuda_mem_usage, _cpu_mem_usage
cuda_mem_usage = 0
cpu_mem_usage = 0
for param in model.parameters():
if hasattr(param, 'col_attr'):
param_cuda, param_cpu = param.col_attr.get_memory_usage()
cuda_mem_usage += param_cuda
cpu_mem_usage += param_cpu
else:
t_cuda, t_cpu = _get_tensor_mem_use(param.data)
cuda_mem_usage += t_cuda
cpu_mem_usage += t_cpu
t_cuda, t_cpu = _get_tensor_mem_use(param.grad)
cuda_mem_usage += t_cuda
cpu_mem_usage += t_cpu
return cuda_mem_usage, cpu_mem_usage
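def _example_col_model_data_mem_usage():
    # Hypothetical sanity check, not part of the original file: a plain torch
    # module has no `col_attr`, so the fallback branch above is taken and its
    # CPU-resident parameters are counted in cpu_mem_usage; grads are None and
    # contribute nothing.
    model = torch.nn.Linear(128, 64)
    cuda_bytes, cpu_bytes = col_model_data_mem_usage(model)
    assert cuda_bytes == 0
    assert cpu_bytes == sum(p.numel() * p.element_size() for p in model.parameters())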
class ModelDataTracer(metaclass=SingletonMeta):
"""
A tracer singleton to trace model data usage during runtime.
The tracer is designed to trace how the memory layout changes as model-data tensors are allocated, released, and moved.
To achieve this, developers have to call `ModelDataTracer` explicitly in the corresponding code.
NOTE(): the class currently only traces CUDA memory usage.
You have to register a model on the singleton first.
"""
def __init__(self) -> None:
self._cuda_usage = 0
self._cpu_usage = 0
self._start_flag = False
self._logger = DistributedLogger("ModelDataTracer")
self._model = None
def start(self) -> None:
self._start_flag = True
def _get_mem_usage(self) -> Tuple[int, int]:
"""
Get the memory usage of the registered model.
Returns:
Tuple[int, int]: cuda, cpu mem usage
"""
if self._model is None:
self._logger.warning("The Global ModelDataTracer is using, but no model is registered on it.")
return 0, 0
return col_model_data_mem_usage(self._model)
def close(self) -> None:
self._start_flag = False
def add_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
if not self._start_flag:
return
t_payload = t.payload if isinstance(t, ShardedTensor) else t
mem_use = _col_tensor_mem_usage(t_payload)
if t_payload.device.type == 'cuda':
self._cuda_usage += mem_use
elif t_payload.device.type == 'cpu':
self._cpu_usage += mem_use
else:
raise TypeError(f"Unsupported device type: {t_payload.device.type}")
def delete_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
if not self._start_flag:
return
t_payload = t.payload if isinstance(t, ShardedTensor) else t
mem_use = _col_tensor_mem_usage(t_payload)
if t_payload.device.type == 'cuda':
self._cuda_usage -= mem_use
elif t_payload.device.type == 'cpu':
self._cpu_usage -= mem_use
else:
raise TypeError(f"Unsupported device type: {t_payload.device.type}")
def clear(self) -> None:
self._cuda_usage = 0
self._cpu_usage = 0
def register_model(self, model) -> None:
self._model = model
@property
def cpu_usage(self):
return self._cpu_usage
_, cpu_usage = self._get_mem_usage()
return cpu_usage
@property
def cuda_usage(self):
return self._cuda_usage
cuda_usage, _ = self._get_mem_usage()
return cuda_usage
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
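
The refactored tracer registers a model once and lets the cpu_usage/cuda_usage properties compute usage on demand through col_model_data_mem_usage. A small sketch of that flow, using a made-up CPU-only model:

import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

model = torch.nn.Linear(16, 16)                # hypothetical model, parameters on CPU
GLOBAL_MODEL_DATA_TRACER.register_model(model)
GLOBAL_MODEL_DATA_TRACER.start()
print(GLOBAL_MODEL_DATA_TRACER.cpu_usage)      # parameter bytes resident on CPU
print(GLOBAL_MODEL_DATA_TRACER.cuda_usage)     # 0 here, since nothing was moved to the GPU
GLOBAL_MODEL_DATA_TRACER.close()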

View File

@@ -1,5 +1,4 @@
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
import torch
@@ -14,7 +13,6 @@ def test_mem_collector():
collector.sample_memstats()
m_a = torch.randn(10).cuda()
GLOBAL_MODEL_DATA_TRACER.add_tensor(m_a)
b = torch.randn(10).cuda()
# sampling at time 1
@@ -35,8 +33,7 @@ def test_mem_collector():
cuda_use, overall_use = collector.fetch_memstats()
print(cuda_use, overall_use)
print(collector._model_data_cuda)
print(collector._overall_cuda)
print(collector.overall_cuda)
if __name__ == '__main__':

View File

@@ -1,7 +1,6 @@
import torch
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from typing import Union
@@ -52,9 +51,7 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
tgt_t_payload = tgt_t.data
tgt_dev = tgt_t_payload.device
GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
tgt_t_payload.copy_(src_t_payload)
GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)
# remove payload of src_t
if isinstance(src_t, ShardedTensor):
@@ -84,11 +81,7 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
# deal with torch.device('cpu') and torch.device('cpu:0')
if t_payload.device.type == target_device.type:
return
if use_tracer:
GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
t_payload.data = t_payload.data.to(target_device)
if use_tracer:
GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
@@ -111,9 +104,7 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
return
# TODO(): optimize the tensor move with a non-blocking copy
GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
t_payload.data = t_payload.data.cpu()
GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
@@ -129,5 +120,4 @@ def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device
t_payload = t.payload if isinstance(t, ShardedTensor) else t
ret = t_payload.to(target_device)
GLOBAL_MODEL_DATA_TRACER.add_tensor(ret)
return ret
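
To round things out, an illustrative call sequence for the move/clone helpers edited in this file; the import path is an assumption, since the diff does not name the file.

import torch
# assumed module path for these helpers
from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move_inline,
                                                 colo_model_data_move_to_cpu,
                                                 colo_model_tensor_clone)

t = torch.randn(1024, device='cuda')
clone = colo_model_tensor_clone(t, torch.device('cuda'))        # copy of the payload on the target device
colo_model_data_move_to_cpu(t)                                  # move the payload to CPU in place
colo_model_data_tensor_move_inline(t, torch.device('cuda:0'))   # move it back to the GPU in place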