[zero] fix init device bug in zero init context unittest (#516)
@@ -2,26 +2,10 @@ from concurrent.futures import ThreadPoolExecutor
 from time import sleep, time
 import pickle
 
-from colossalai.utils import get_current_device
-import torch
-
-
-def get_cuda_memory_used(device: torch.device) -> int:
-    """
-    Get the free memory info of device.
-    :param device: device id
-    :type device: torch.device
-    :return: current memory usage, sized by MB
-    :rtype: int
-    """
-
-    assert device.type == 'cuda'
-
-    ret: int = torch.cuda.memory_allocated(device)
-    # get the peak memory to report correct data, so reset the counter for the next call
-    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
-        torch.cuda.reset_peak_memory_stats(device)
-    return ret
+from colossalai.utils import get_current_device
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 
 
 class AsyncMemoryMonitor:
@@ -97,7 +81,7 @@ class AsyncMemoryMonitor:
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
-                get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
+                colo_cuda_memory_used(),
             )
             sleep(self.interval)
         return max_usage
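With this change the measuring thread no longer builds a torch.device itself; it calls colo_cuda_memory_used(), which resolves the current device internally. Below is a minimal sketch of what the sampling loop boils down to, assuming a CUDA-capable process; the function name measure_peak_usage, the callable keep_measuring, and the default interval are illustrative and not taken from the diff.

from time import sleep

from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used


def measure_peak_usage(keep_measuring, interval: float = 0.01) -> int:
    # Poll allocated CUDA memory on the local device until the caller
    # clears the flag, keeping the largest sample seen (in bytes).
    max_usage = 0
    while keep_measuring():
        max_usage = max(max_usage, colo_cuda_memory_used())
        sleep(interval)
    return max_usage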
@@ -1,5 +1,5 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from .async_memtracer import get_cuda_memory_used
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.utils import get_current_device
 
 import torch
@@ -55,7 +55,7 @@ class MemStatsCollector:
         sampling_cnt = self._sampling_cnter.sampling_cnt
         assert sampling_cnt == len(self._overall_cuda)
         self._model_data_cuda.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
-        self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
+        self._overall_cuda.append(colo_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
         self._sampling_cnter.advance()
 
     def fetch_memstats(self) -> (int, int):
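For context, each sample records two parallel series: the model data tracked by GLOBAL_MODEL_DATA_TRACER and the overall memory allocated on the local GPU. A rough sketch of one sampling step follows; the variable names and the non_model subtraction are illustrative, only the two append calls above come from the diff.

import torch

from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used

device = torch.device(f'cuda:{get_current_device()}')

model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage   # bytes registered as model data on CUDA
overall = colo_cuda_memory_used(device)             # bytes currently allocated on this rank's GPU
non_model = overall - model_data                     # the rest: activations, buffers, temporaries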
@@ -44,6 +44,9 @@ class ModelDataTracer(metaclass=SingletonMeta):
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage -= mem_use
 
+    def clear(self) -> None:
+        self._cuda_usage = 0
+
     @property
     def cpu_usage(self):
         return self._cpu_usage
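The new clear() method zeroes the CUDA counter, presumably so unit tests (such as the zero init context test named in the commit title) can reset the singleton between runs instead of inheriting usage from a previous context. A short sketch, assuming nothing beyond what this diff shows; the prints and assert are illustrative.

from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

# read the counters the singleton tracer maintains for model data
print(GLOBAL_MODEL_DATA_TRACER.cuda_usage)   # bytes of model data on CUDA
print(GLOBAL_MODEL_DATA_TRACER.cpu_usage)    # bytes of model data on CPU

# reset the CUDA counter so the next test starts from zero
GLOBAL_MODEL_DATA_TRACER.clear()
assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0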
@@ -9,6 +9,28 @@ import torch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
+from colossalai.utils.cuda import get_current_device
+from typing import Optional
+
+
+def colo_cuda_memory_used(device: Optional[torch.device] = None) -> int:
+    """
+    Get the free memory info of device.
+    :param device: a torch device instance or None
+    :type device: Optional[torch.device]
+    :return: current memory usage, sized by Byte
+    :rtype: int
+    """
+    if device:
+        assert device.type == 'cuda'
+    else:
+        device = torch.device(f'cuda:{get_current_device()}')
+
+    ret: int = torch.cuda.memory_allocated(device)
+    # get the peak memory to report correct data, so reset the counter for the next call
+    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
+        torch.cuda.reset_peak_memory_stats(device)
+    return ret
 
 
 def bytes_to_GB(val, decimal=2):
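A short usage sketch for the new helper in colossalai.utils.memory_utils.memory_monitor, assuming a CUDA device is available; bytes_to_GB is the helper visible at the end of the hunk, and the cuda:0 device is only an example.

import torch

from colossalai.utils.memory_utils.memory_monitor import bytes_to_GB, colo_cuda_memory_used

# default: measure the GPU bound to the current process
used = colo_cuda_memory_used()

# or pin the measurement to an explicit CUDA device
used_dev0 = colo_cuda_memory_used(torch.device('cuda:0'))

# the return value is in bytes; the call also resets the peak-memory counter
print(f"allocated: {bytes_to_GB(used)} GB")

Note that the docstring still reads "Get the free memory info of device", but the value returned is torch.cuda.memory_allocated, i.e. the bytes currently allocated on that device.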
@@ -3,7 +3,7 @@ from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
-from typing import Union
+from typing import Union, Optional
 
 _GLOBAL_CUDA_MEM_FRACTION = 1.0
 