[refactor] memory utils (#577)

Jiarui Fang
2022-04-01 09:22:33 +08:00
committed by GitHub
parent 104cbbb313
commit e956d93ac2
15 changed files with 261 additions and 202 deletions

View File

@@ -4,8 +4,8 @@ import pickle
import torch
from colossalai.utils.memory_utils.utils import colo_device_memory_used
from colossalai.utils import get_current_device
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
class AsyncMemoryMonitor:
@@ -82,7 +82,7 @@ class AsyncMemoryMonitor:
while self.keep_measuring:
max_usage = max(
max_usage,
colo_cuda_memory_used(),
colo_device_memory_used(get_current_device()),
)
sleep(self.interval)
return max_usage
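The loop above is the core of AsyncMemoryMonitor after this refactor: poll the allocated device memory at a fixed interval on a background worker and keep the maximum. Below is a self-contained sketch of that pattern, assuming a CUDA-enabled PyTorch build; the class and method names are illustrative only (not the AsyncMemoryMonitor API), and torch.cuda.memory_allocated() stands in for colo_device_memory_used, which needs an initialized Colossal-AI runtime.

import threading
from time import sleep

import torch


class PeakMemoryPoller:
    """Illustrative peak-memory poller (not part of Colossal-AI)."""

    def __init__(self, interval: float = 0.01) -> None:
        self.interval = interval
        self.keep_measuring = False
        self.max_usage = 0
        self._thread = None

    def start(self) -> None:
        # begin sampling in a background thread
        self.keep_measuring = True
        self.max_usage = 0
        self._thread = threading.Thread(target=self._measure, daemon=True)
        self._thread.start()

    def stop(self) -> int:
        # stop sampling and return the observed peak, in bytes
        self.keep_measuring = False
        self._thread.join()
        return self.max_usage

    def _measure(self) -> None:
        while self.keep_measuring:
            self.max_usage = max(self.max_usage, torch.cuda.memory_allocated())
            sleep(self.interval)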

View File

@@ -1,9 +1,9 @@
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
from colossalai.utils.memory_utils.utils import colo_device_memory_used
from colossalai.utils import get_current_device
import torch
from typing import Tuple
from typing import List
class SamplingCounter:
@@ -23,45 +23,71 @@ class SamplingCounter:
class MemStatsCollector:
"""
A memory statistics collector.
It works in two phases.
Phase 1. Collection Phase: collect memory usage statistics of CPU and GPU
during the first iteration of DNN training.
Phase 2. Runtime Phase: use the read-only collected stats
during the remaining iterations of DNN training.
It has a sampling counter which is reset after each DNN training iteration.
"""
def __init__(self) -> None:
"""
Collects memory statistics in two phases.
1. Collection Phase: collect memory usage statistics.
2. Runtime Phase: do not collect statistics.
"""
self._sampling_cnter = SamplingCounter()
self._model_data_cuda = []
self._overall_cuda = []
self._model_data_cuda_list = []
self._overall_cuda_list = []
# TODO(jiaruifang) CPU memory stats are not collected yet
self._model_data_cpu = []
self._overall_cpu = []
self._model_data_cpu_list = []
self._overall_cpu_list = []
self._start_flag = False
@property
def overall_cuda(self):
return self._overall_cuda
def overall_mem_stats(self, device_type: str):
if device_type == 'cuda':
return self._overall_cuda_list
elif device_type == 'cpu':
return self._overall_cpu_list
else:
raise TypeError
@property
def model_data_cuda_GB(self):
return [elem / 1e9 for elem in self._model_data_cuda]
def model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
scale = 1
if unit == 'GB':
scale = 1e9
elif unit == 'MB':
scale = 1e6
elif unit == 'KB':
scale = 1e3
elif unit == 'B':
scale = 1
else:
raise TypeError
@property
def model_data_cuda(self):
return self._model_data_cuda
if device_type == 'cuda':
return [elem / scale for elem in self._model_data_cuda_list]
elif device_type == 'cpu':
return [elem / scale for elem in self._model_data_cpu_list]
else:
raise TypeError
@property
def non_model_data_cuda_GB(self):
return [elem / 1e9 for elem in self.non_model_data_cuda]
@property
def non_model_data_cuda(self):
def non_model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
"""Non model data stats
"""
return [(v1 - v2) for v1, v2 in zip(self._overall_cuda, self._model_data_cuda)]
scale = 1
if unit == 'GB':
scale = 1e9
elif unit == 'MB':
scale = 1e6
elif unit == 'KB':
scale = 1e3
if device_type == 'cuda':
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cuda_list, self._model_data_cuda_list)]
elif device_type == 'cpu':
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cpu_list, self._model_data_cpu_list)]
else:
raise TypeError
def start_collection(self):
self._start_flag = True
@@ -73,32 +99,28 @@ class MemStatsCollector:
"""
Sampling memory statistics.
Record the current model data CUDA memory usage as well as system CUDA memory usage.
Advance the sampling counter.
"""
if self._start_flag:
sampling_cnt = self._sampling_cnter.sampling_cnt
assert sampling_cnt == len(self._overall_cuda)
self._model_data_cuda.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
self._overall_cuda.append(colo_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
self._sampling_cnter.advance()
assert sampling_cnt == len(self._overall_cuda_list)
self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
self._overall_cuda_list.append(colo_device_memory_used(get_current_device()))
def fetch_memstats(self) -> Tuple[int, int]:
"""
Returns the CUDA memory usage of model data and the overall CUDA memory usage.
"""
sampling_cnt = self._sampling_cnter.sampling_cnt
if len(self._model_data_cuda) < sampling_cnt:
raise RuntimeError
return (self._model_data_cuda[sampling_cnt], self._overall_cuda[sampling_cnt])
self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
self._sampling_cnter.advance()
def reset_sampling_cnter(self) -> None:
self._sampling_cnter.reset()
def clear(self) -> None:
self._model_data_cuda = []
self._overall_cuda = []
self._model_data_cuda_list = []
self._overall_cuda_list = []
self._model_data_cpu = []
self._overall_cpu = []
self._model_data_cpu_list = []
self._overall_cpu_list = []
self._start_flag = False
self._sampling_cnter.reset()
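A minimal usage sketch of the refactored MemStatsCollector API above. Hedged assumptions: the import path below is inferred from the repository layout (not shown in this diff), a CUDA device is available, and GLOBAL_MODEL_DATA_TRACER is maintained by the sharded-model runtime (otherwise the model-data entries simply stay at their default values).

import torch

# assumed module path; not confirmed by this diff
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector

collector = MemStatsCollector()
collector.start_collection()                         # phase 1: enable sampling

x = torch.empty(1024, 1024, device='cuda')           # some workload allocation
collector.sample_memstats()                          # record model data + overall usage
collector.sample_memstats()

model_data, overall = collector.fetch_memstats()     # usage at the current sampling step, in bytes
print(model_data, overall)
print(collector.model_data_cuda_list('cuda', 'MB'))  # model data per sample, in MB
print(collector.overall_mem_stats('cuda'))           # overall usage per sample, in bytes

collector.reset_sampling_cnter()                     # phase 2: replay the collected stats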

View File

@@ -30,10 +30,7 @@ def test_mem_collector():
collector.sample_memstats()
collector.sample_memstats()
cuda_use, overall_use = collector.fetch_memstats()
print(cuda_use, overall_use)
print(collector.overall_cuda)
print(collector.overall_mem_stats('cuda'))
if __name__ == '__main__':

View File

@@ -9,29 +9,6 @@ import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils.cuda import get_current_device
from typing import Optional
def colo_cuda_memory_used(device: Optional[torch.device] = None) -> int:
"""Get the memory usage of the device.
Args:
device (Optional[``torch.device``]): a torch device instance or None. Defaults to None.
Returns:
int: current memory usage, in bytes.
"""
if device:
assert device.type == 'cuda'
else:
device = torch.device(f'cuda:{get_current_device()}')
ret: int = torch.cuda.memory_allocated(device)
# reset the peak memory stats so that the next call reports correct data
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats(device)
return ret
def bytes_to_GB(val, decimal=2):

View File

@@ -1,29 +1,65 @@
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from typing import Tuple, Union
from collections import namedtuple
import psutil
from colossalai.core import global_context as gpc
_GLOBAL_CUDA_MEM_FRACTION = 1.0
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
if issubclass(type(tensor), StatefulTensor):
t = tensor.payload
elif isinstance(tensor, torch.Tensor):
t = tensor
else:
return 0, 0
# copy from PatrickStar
def _get_cpu_memory_info():
ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"])
try:
# psutil reads the memory info from /proc/meminfo,
# which reports the host memory rather than
# the container's.
# Here we try to read the container memory using the method in:
# https://stackoverflow.com/a/46213331/5163915
mems = {}
with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f:
for line in f:
fields = line.split()
mems[fields[0]] = int(fields[1]) * 1024
total = mems[b"MemTotal:"]
free = mems[b"MemFree:"]
cached = mems[b"Cached:"]
buffers = mems[b"Buffers:"]
used = total - free - cached - buffers
if used < 0:
used = total - free
mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used)
except FileNotFoundError:
mems = psutil.virtual_memory()
mem_info = ps_mem_info(
total=mems.total,
free=mems.free,
cached=mems.cached,
buffers=mems.buffers,
used=mems.used,
)
return mem_info
cuda_use, cpu_use = 0, 0
mem_use = t.numel() * t.element_size()
if t.device.type == 'cuda':
cuda_use += mem_use
elif t.device.type == 'cpu':
cpu_use += mem_use
return cuda_use, cpu_use
def colo_device_memory_used(device) -> int:
if not isinstance(device, torch.device):
device = torch.device(f"cuda:{device}")
if device.type == 'cpu':
mem_info = _get_cpu_memory_info()
# FIXME(jiaruifang) this only works for the 1-CPU multi-GPU case:
# the used CPU memory is shared by all processes, so it is divided evenly among them.
# Multi-GPU multi-CPU setups are not supported; a local_world_size is needed here.
ret = mem_info.used / gpc.get_world_size(ParallelMode.DATA)
return ret
elif device.type == 'cuda':
ret: int = torch.cuda.memory_allocated(device)
# reset the peak memory stats so that the next call reports correct data
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats(device)
return ret
def colo_set_process_memory_fraction(ratio: float) -> None:
@@ -44,97 +80,3 @@ def colo_cuda_memory_capacity() -> float:
Get the CUDA memory capacity of the current CUDA device.
"""
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
torch.Tensor]) -> None:
"""
A Colossal-AI API for moving model data between tensors.
The source and target tensors may reside on either CPU or GPU.
NOTE: the payload of the source tensor is released after this function.
The function records the communication volume between CPU and GPU.
Args:
src_t (Union[StatefulTensor, torch.Tensor]): source tensor
tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
"""
if issubclass(type(src_t), StatefulTensor):
src_t_payload = src_t.payload
else:
src_t_payload = src_t.data
src_dev = src_t_payload.device
if issubclass(type(tgt_t), StatefulTensor):
tgt_t_payload = tgt_t.payload
else:
tgt_t_payload = tgt_t.data
tgt_t_payload.copy_(src_t_payload)
# remove payload of src_t
if issubclass(type(src_t), StatefulTensor):
src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype))
else:
src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
int]) -> None:
"""
Move a tensor to the target device in place.
Args:
t (Union[StatefulTensor, torch.Tensor]): the tensor to be moved
target_device (Union[torch.device, int]): the target device; an int is treated as a CUDA device index
"""
if isinstance(t, torch.Tensor):
t_payload = t
elif issubclass(type(t), StatefulTensor):
t_payload = t.payload
else:
raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}')
if isinstance(target_device, int):
target_device = torch.device(f'cuda:{target_device}')
# deal with torch.device('cpu') and torch.device('cpu:0')
if t_payload.device.type == target_device.type:
return
t_payload.data = t_payload.data.to(target_device)
def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
"""colo_model_data_move_to_cpu
Move a model data tensor from GPU to CPU.
Args:
t (Union[StatefulTensor, torch.Tensor]): the model data tensor to be moved
"""
if issubclass(type(t), StatefulTensor):
t_payload = t.payload
elif isinstance(t, torch.Tensor):
t_payload = t
else:
raise TypeError(f'colo_model_data_move_to_cpu does not accept type {type(t)}')
if t_payload.device.type == 'cpu':
return
# TODO() optimize the tensor move with a non-blocking copy
t_payload.data = t_payload.data.cpu()
def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
"""
Clone a model data tensor
Args:
t (Union[StatefulTensor, torch.Tensor]): a model data tensor
target_device (torch.device): the target device
Returns:
torch.Tensor: a cloned torch tensor
"""
t_payload = t.payload if issubclass(type(t), StatefulTensor) else t
ret = t_payload.to(target_device)
return ret
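A short, hedged sketch of the refactored colo_device_memory_used added above. It assumes colossalai.launch has already initialized the global context (the CPU branch divides the used host memory by the data-parallel world size) and that a CUDA device is present; a bare device index is also accepted and wrapped into torch.device(f'cuda:{index}').

import torch
from colossalai.utils import get_current_device
from colossalai.utils.memory_utils.utils import colo_device_memory_used

cuda_used = colo_device_memory_used(get_current_device())   # bytes currently allocated by torch on this GPU
cpu_used = colo_device_memory_used(torch.device('cpu'))      # this process's share of used host memory, in bytes
print(cuda_used, cpu_used)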