mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[utils] correct cpu memory used and capacity in the context of multi-process (#726)
This commit is contained in:
@@ -8,6 +8,7 @@ from colossalai.utils import get_current_device
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.logging import get_dist_logger
|
||||
from packaging import version
|
||||
|
||||
_GLOBAL_CUDA_MEM_FRACTION = 1.0
|
||||
|
||||
@@ -106,7 +107,8 @@ def colo_device_memory_capacity(device: torch.device) -> int:
|
||||
assert isinstance(device, torch.device)
|
||||
if device.type == 'cpu':
|
||||
mem_info = _get_cpu_memory_info()
|
||||
return mem_info.info.total / gpc.get_world_size(ParallelMode.DATA)
|
||||
# In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
|
||||
return mem_info.total / gpc.num_processes_on_current_node
|
||||
if device.type == 'cuda':
|
||||
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
|
||||
|
||||
@@ -123,8 +125,9 @@ def colo_device_memory_used(device: torch.device) -> int:
|
||||
"""
|
||||
if device.type == 'cpu':
|
||||
mem_info = _get_cpu_memory_info()
|
||||
# FIXME(jiaruifang) we need get how many processes are using the CPU memory.
|
||||
ret = mem_info.used / gpc.get_world_size(ParallelMode.DATA)
|
||||
# In the context of 1-CPU-N-GPU, the memory usage of the current process is 1/N CPU memory used.
|
||||
# Each process consumes the same amount of memory.
|
||||
ret = mem_info.used / gpc.num_processes_on_current_node
|
||||
return ret
|
||||
elif device.type == 'cuda':
|
||||
ret: int = torch.cuda.memory_allocated(device)
|
||||
@@ -142,6 +145,10 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
|
||||
Args:
|
||||
ratio (float): a ratio between 0. ~ 1.
|
||||
"""
|
||||
if version.parse(torch.__version__) < version.parse('1.8'):
|
||||
logger = get_dist_logger('colo_set_process_memory_fraction')
|
||||
logger.warning('colo_set_process_memory_fraction failed because torch version is less than 1.8')
|
||||
return
|
||||
global _GLOBAL_CUDA_MEM_FRACTION
|
||||
_GLOBAL_CUDA_MEM_FRACTION = ratio
|
||||
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
|
||||
|
Reference in New Issue
Block a user