Mirror of https://github.com/hpcaitech/ColossalAI.git
[refactor] refactor the memory utils (#715)
Commit 193dc8dacb (parent dbd96fe90a)
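In short, this commit collapses the helpers under colossalai.utils.memory_utils (memory_monitor.py and utils.py, and it also removes bucket_tensor_copy.py together with its test) into a single colossalai.utils.memory module, and get_current_device() now returns a torch.device instead of a bare index. A minimal before/after sketch of the call-site migration implied by the hunks below; only the imports and function names come from the diff, the surrounding variable names are illustrative:

    # before: module-level helpers under colossalai.utils.memory_utils
    # from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
    # capacity = colo_cuda_memory_capacity()

    # after: device-aware helpers under colossalai.utils.memory
    from colossalai.utils import get_current_device
    from colossalai.utils.memory import colo_device_memory_capacity

    capacity = colo_device_memory_capacity(get_current_device())    # bytes visible to this process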
@@ -30,7 +30,7 @@ class ZeroHook(BaseOpHook):
         self.process_group = process_group

         # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
-        self.computing_device = torch.device(f'cuda:{get_current_device()}')
+        self.computing_device = get_current_device()

         self._memstarts_collector = memstarts_collector
         self._stateful_tensor_mgr = stateful_tensor_mgr
@@ -8,7 +8,7 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
                      sync_model_param, disposable)
 from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
-from .memory_utils.memory_monitor import report_memory_usage
+from .memory import report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, colo_device_memory_capacity
 from .timer import MultiTimer, Timer
 from .tensor_detector import TensorDetector

@@ -17,7 +17,8 @@ __all__ = [
     'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context',
     'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
-    'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
-    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
+    'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction',
+    'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler', 'get_dataloader',
+    'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
     'ensure_path_exists', 'disposable'
 ]
@@ -20,13 +20,15 @@ def set_to_cuda(models):
     return models.to(get_current_device())


-def get_current_device():
-    """Returns the index of a currently selected device (gpu/cpu).
+def get_current_device() -> torch.device:
+    """
+    Returns currently selected device (gpu/cpu).
+    If cuda available, return gpu, otherwise return cpu.
     """
     if torch.cuda.is_available():
-        return torch.cuda.current_device()
+        return torch.device(f'cuda:{torch.cuda.current_device()}')
     else:
-        return 'cpu'
+        return torch.device('cpu')


 def synchronize():
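Since get_current_device() now returns a torch.device (cuda:N or cpu) instead of an integer index or the string 'cpu', callers can pass and compare its result directly; the remaining hunks rely on exactly that. A small sketch under this assumption:

    import torch
    from colossalai.utils.cuda import get_current_device

    device = get_current_device()        # torch.device('cuda:N') if CUDA is available, else torch.device('cpu')
    x = torch.empty(4, device=device)    # no torch.device(f'cuda:{...}') wrapper needed any more
    assert x.device == device            # torch.device compares equal to torch.device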
colossalai/utils/memory.py (new file, 147 lines)
@@ -0,0 +1,147 @@
import torch
import gc
import psutil
from collections import namedtuple

from colossalai.context.parallel_mode import ParallelMode
from colossalai.utils import get_current_device
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from colossalai.logging import get_dist_logger

_GLOBAL_CUDA_MEM_FRACTION = 1.0


def _bytes_to_MB(val, decimal=2):
    """A byte-to-Megabyte converter, default using binary notation.

    :param val: X bytes to convert
    :return: X' MB
    """
    return round(val / (1024 * 1024), decimal)


# copy from PatrickStar
def _get_cpu_memory_info():
    ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"])
    try:
        # psutil reads the memory info from /proc/memory_info,
        # which results in returning the host memory instead of
        # that of container.
        # Here we try to read the container memory with method in:
        # https://stackoverflow.com/a/46213331/5163915
        mems = {}
        with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f:
            for line in f:
                fields = line.split()
                mems[fields[0]] = int(fields[1]) * 1024
        total = mems[b"MemTotal:"]
        free = mems[b"MemFree:"]
        cached = mems[b"Cached:"]
        buffers = mems[b"Buffers:"]
        used = total - free - cached - buffers
        if used < 0:
            used = total - free
        mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used)
    except FileNotFoundError:
        mems = psutil.virtual_memory()
        mem_info = ps_mem_info(
            total=mems.total,
            free=mems.free,
            cached=mems.cached,
            buffers=mems.buffers,
            used=mems.used,
        )
    return mem_info


def report_memory_usage(message, logger=None, report_cpu=False):
    """Calculate and print RAM usage (in GB)

    Args:
        message (str): A prefix message to add in the log.
        logger (:class:`colossalai.logging.DistributedLogger`): The logger used to record memory information.
        report_cpu (bool, optional): Whether to report CPU memory.

    Raises:
        EnvironmentError: Raise error if no distributed environment has been initialized.
    """
    if not gpc.is_initialized(ParallelMode.GLOBAL):
        raise EnvironmentError("No distributed environment is initialized")

    gpu_allocated = _bytes_to_MB(torch.cuda.memory_allocated())
    gpu_max_allocated = _bytes_to_MB(torch.cuda.max_memory_allocated())
    gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved())
    gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved())

    full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \
        + f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"

    if report_cpu:
        # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
        gc.collect()
        vm_stats = psutil.virtual_memory()
        vm_used = _bytes_to_MB(vm_stats.total - vm_stats.available)
        full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%"

    if logger is None:
        logger = get_dist_logger()
    logger.info(full_log)

    # get the peak memory to report correct data, so reset the counter for the next call
    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats()


def colo_device_memory_capacity(device: torch.device) -> int:
    """
    Get the capacity of the memory of the device

    Args:
        device (torch.device): a device

    Returns:
        int: size in byte
    """
    assert isinstance(device, torch.device)
    if device.type == 'cpu':
        mem_info = _get_cpu_memory_info()
        return mem_info.info.total / gpc.get_world_size(ParallelMode.DATA)
    if device.type == 'cuda':
        return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION


def colo_device_memory_used(device: torch.device) -> int:
    """
    Get the device memory on device belonging to the current process.

    Args:
        device (torch.device): a device

    Returns:
        int: memory size in bytes
    """
    if device.type == 'cpu':
        mem_info = _get_cpu_memory_info()
        # FIXME(jiaruifang) we need get how many processes are using the CPU memory.
        ret = mem_info.used / gpc.get_world_size(ParallelMode.DATA)
        return ret
    elif device.type == 'cuda':
        ret: int = torch.cuda.memory_allocated(device)
        # get the peak memory to report correct data, so reset the counter for the next call
        if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
            torch.cuda.reset_peak_memory_stats(device)
        return ret


def colo_set_process_memory_fraction(ratio: float) -> None:
    """colo_set_process_memory_fraction

    set how much cuda memory used on the gpu belonging to the current process.

    Args:
        ratio (float): a ratio between 0. ~ 1.
    """
    global _GLOBAL_CUDA_MEM_FRACTION
    _GLOBAL_CUDA_MEM_FRACTION = ratio
    torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
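A hedged usage sketch of the new colossalai.utils.memory API introduced above (it assumes colossalai.launch has already set up the distributed environment, since report_memory_usage raises EnvironmentError otherwise):

    from colossalai.utils import get_current_device
    from colossalai.utils.memory import (report_memory_usage, colo_device_memory_capacity,
                                         colo_device_memory_used, colo_set_process_memory_fraction)

    colo_set_process_memory_fraction(0.5)                    # cap this process at half of the GPU
    device = get_current_device()
    capacity = colo_device_memory_capacity(device)           # bytes visible to this process (scaled by the fraction)
    used = colo_device_memory_used(device)                   # bytes currently allocated on the device
    report_memory_usage("after init", report_cpu=True)       # logs GPU (and CPU) usage through the dist logger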
@@ -4,7 +4,7 @@ import pickle

 import torch

-from colossalai.utils.memory_utils.utils import colo_device_memory_used
+from colossalai.utils.memory import colo_device_memory_used
 from colossalai.utils import get_current_device

@@ -1,5 +1,5 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_device_memory_used
+from colossalai.utils.memory import colo_device_memory_used
 from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
 import torch
 import time
@@ -1,61 +0,0 @@ (file deleted)
import torch
from colossalai.zero.sharded_param import ShardedParamV2
from colossalai.utils import get_current_device
from typing import List


class BucketizedTensorCopy(object):

    def __init__(
        self,
        chunk_size: int,
    ):
        r"""
        torch.nn.Parameter CPU (fp32) -> ShardedParam GPU (fp16)
        TODO(jiaruifang) The class is a little bit hardcoded
        I will make it more general later.
        """
        self.chunk_size = chunk_size
        self._offset = 0
        self._cpu_buffer = torch.empty(chunk_size, dtype=torch.float, device=torch.device("cpu:0"), pin_memory=True)
        self._cuda_buffer = torch.empty(chunk_size,
                                        dtype=torch.half,
                                        device=torch.device(f"cuda:{get_current_device()}"))

        self._buffered_param_list: List[ShardedParamV2] = []
        self._numel_list = []

    def copy(self, src_param: torch.nn.Parameter, target_param: ShardedParamV2):
        assert isinstance(target_param, ShardedParamV2)
        assert isinstance(src_param, torch.nn.Parameter)

        numel = src_param.numel()

        if self._offset + numel > self.chunk_size:
            self.flush()

        assert src_param.data.device.type == 'cpu'
        self._cpu_buffer.narrow(0, self._offset, numel).copy_(src_param.data.view(-1))

        self._buffered_param_list.append(target_param)
        self._numel_list.append(numel)

        self._offset += numel

    def flush(self):
        """
        flush to cuda memory
        """
        self._cuda_buffer.copy_(self._cpu_buffer)
        flush_offset = 0
        for sparam, numel in zip(self._buffered_param_list, self._numel_list):
            sparam.sharded_data_tensor.copy_payload(self._cpu_buffer.narrow(0, flush_offset, numel))
            flush_offset += numel

        self.reset()

    def reset(self):
        self._buffered_param_list = []
        self._numel_list = []
        self._offset = 0
@@ -1,67 +0,0 @@ (file deleted)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import gc

import psutil
import torch

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger


def bytes_to_GB(val, decimal=2):
    """A byte-to-Gigabyte converter, default using binary notation.

    :param val: X bytes to convert
    :return: X' GB
    """
    return round(val / (1024 * 1024 * 1024), decimal)


def bytes_to_MB(val, decimal=2):
    """A byte-to-Megabyte converter, default using binary notation.

    :param val: X bytes to convert
    :return: X' MB
    """
    return round(val / (1024 * 1024), decimal)


def report_memory_usage(message, logger=None, report_cpu=False):
    """Calculate and print RAM usage (in GB)

    Args:
        message (str): A prefix message to add in the log.
        logger (:class:`colossalai.logging.DistributedLogger`): The logger used to record memory information.
        report_cpu (bool, optional): Whether to report CPU memory.

    Raises:
        EnvironmentError: Raise error if no distributed environment has been initialized.
    """
    if not gpc.is_initialized(ParallelMode.GLOBAL):
        raise EnvironmentError("No distributed environment is initialized")

    gpu_allocated = bytes_to_MB(torch.cuda.memory_allocated())
    gpu_max_allocated = bytes_to_MB(torch.cuda.max_memory_allocated())
    gpu_cached = bytes_to_MB(torch.cuda.memory_reserved())
    gpu_max_cached = bytes_to_MB(torch.cuda.max_memory_reserved())

    full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \
        + f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"

    if report_cpu:
        # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
        gc.collect()
        vm_stats = psutil.virtual_memory()
        vm_used = bytes_to_MB(vm_stats.total - vm_stats.available)
        full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%"

    if logger is None:
        logger = get_dist_logger()
    logger.info(full_log)

    # get the peak memory to report correct data, so reset the counter for the next call
    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats()
@@ -1,82 +0,0 @@ (file deleted)
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.utils import get_current_device

from collections import namedtuple
import psutil
from colossalai.core import global_context as gpc

_GLOBAL_CUDA_MEM_FRACTION = 1.0


# copy from PatrickStar
def _get_cpu_memory_info():
    ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"])
    try:
        # psutil reads the memory info from /proc/memory_info,
        # which results in returning the host memory instead of
        # that of container.
        # Here we try to read the container memory with method in:
        # https://stackoverflow.com/a/46213331/5163915
        mems = {}
        with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f:
            for line in f:
                fields = line.split()
                mems[fields[0]] = int(fields[1]) * 1024
        total = mems[b"MemTotal:"]
        free = mems[b"MemFree:"]
        cached = mems[b"Cached:"]
        buffers = mems[b"Buffers:"]
        used = total - free - cached - buffers
        if used < 0:
            used = total - free
        mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used)
    except FileNotFoundError:
        mems = psutil.virtual_memory()
        mem_info = ps_mem_info(
            total=mems.total,
            free=mems.free,
            cached=mems.cached,
            buffers=mems.buffers,
            used=mems.used,
        )
    return mem_info


def colo_device_memory_used(device) -> int:
    if not isinstance(device, torch.device):
        device = torch.device(f"cuda:{device}")
    if device.type == 'cpu':
        mem_info = _get_cpu_memory_info()
        # FIXME(jiaruifang) only work for 1-CPU multi-GPU
        # CPU memory is sharded with all processes
        # Not support multi-GPU multi-CPU
        # We need a local_world_size here
        ret = mem_info.used / gpc.get_world_size(ParallelMode.DATA)
        return ret
    elif device.type == 'cuda':
        ret: int = torch.cuda.memory_allocated(device)
        # get the peak memory to report correct data, so reset the counter for the next call
        if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
            torch.cuda.reset_peak_memory_stats(device)
        return ret


def colo_set_process_memory_fraction(ratio: float) -> None:
    """colo_set_process_memory_fraction

    set how much cuda memory used on the gpu belonging to the current process.

    Args:
        ratio (float): a ratio between 0. ~ 1.
    """
    global _GLOBAL_CUDA_MEM_FRACTION
    _GLOBAL_CUDA_MEM_FRACTION = ratio
    torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())


def colo_cuda_memory_capacity() -> float:
    """
    Get cuda memory capacity of the current cuda.
    """
    return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
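Note that the removed colo_device_memory_used() above accepted a bare GPU index and wrapped it into a torch.device itself, while the replacement in colossalai/utils/memory.py expects an explicit torch.device, consistent with get_current_device() now returning one. A minimal sketch of the migration at a hypothetical call site:

    from colossalai.utils import get_current_device
    from colossalai.utils.memory import colo_device_memory_used

    # old (removed): colo_device_memory_used(0) silently became torch.device('cuda:0')
    used = colo_device_memory_used(get_current_device())    # new: pass a torch.device explicitly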
@@ -5,7 +5,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
-from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
+from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Dict, List
 from colossalai.utils.memory_tracer import MemStatsCollector

@@ -64,7 +64,7 @@ class StatefulTensorMgr(object):
                 cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
             else:
                 raise RuntimeError
-        cuda_capacity = colo_cuda_memory_capacity()
+        cuda_capacity = colo_device_memory_capacity(get_current_device())

         if self._warmup:
             # We designate a part of CUDA memory for model data in warmup iterations.
@@ -33,7 +33,7 @@ class TensorShardStrategy(BaseShardStrategy):
         if t.is_sharded:
             return
         if t.payload.device.type == 'cuda':
-            assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
+            assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
                 f" but current cuda device is {get_current_device()}"
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.reset_payload(sharded_payload)
@@ -16,7 +16,7 @@ from colossalai.utils import get_current_device, disposable
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
+from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.shard_utils.tensor_utils import colo_model_data_move_to_cpu
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer

@@ -231,7 +231,7 @@ class ShardedModelV2(nn.Module):
             # the way to calculate margin space is based on the assumption that
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
-            self._cuda_margin_space = colo_cuda_memory_capacity() - max(
+            self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
                 self._memstats_collector.overall_mem_stats('cuda'))

         @torch.no_grad()
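The margin-space hunk above subtracts the peak CUDA usage observed by the MemStatsCollector from the per-process device capacity; the remainder is what the sharded optimizer may use for optimizer states (OS). A hedged restatement of that arithmetic, assuming a memstats_collector object with the overall_mem_stats('cuda') interface shown in the diff:

    from colossalai.utils import get_current_device
    from colossalai.utils.memory import colo_device_memory_capacity

    peak_cuda_usage = max(memstats_collector.overall_mem_stats('cuda'))    # memstats_collector assumed in scope
    cuda_margin_space = colo_device_memory_capacity(get_current_device()) - peak_cuda_usage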
@@ -41,7 +41,7 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
     logger = get_dist_logger("test_moe_zero_init")

     if init_device_type == 'cuda':
-        init_device = torch.device(f"cuda:{get_current_device()}")
+        init_device = get_current_device()
     elif init_device_type == 'cpu':
         init_device = torch.device("cpu")
     else:
@@ -62,8 +62,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
     _, train_dataloader, _, optimizer_class, criterion = get_components_func()

-    with ZeroInitContext(
-            target_device=torch.device('cpu') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
+    with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
                          shard_strategy=shard_strategy,
                          shard_param=True):
         zero_model = MoeModel()
@@ -1,39 +0,0 @@ (file deleted)
from colossalai.utils.memory_utils.bucket_tensor_copy import BucketizedTensorCopy
from colossalai.zero.sharded_param import ShardedParamV2
from colossalai.utils import free_port
import torch
import colossalai


def test_bucket_copy():
    # init dist env
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    copyer = BucketizedTensorCopy(20)

    shape_list = [(2, 3), (5), (8), (12)]
    src_param_list = []
    tgt_param_list = []
    for shape in shape_list:
        # on CPU
        src_param = torch.nn.Parameter(torch.randn(shape, dtype=torch.float, device=torch.device('cpu')))
        # on GPU
        tgt_param = ShardedParamV2(torch.nn.Parameter(torch.ones(shape, dtype=torch.half, device=torch.device('cuda'))))

        src_param_list.append(src_param)
        tgt_param_list.append(tgt_param)

        copyer.copy(src_param, tgt_param)

    copyer.flush()

    for src_param, tgt_param in zip(src_param_list, tgt_param_list):
        diff = src_param.cpu().float() - tgt_param.sharded_data_tensor.payload.cpu().float()
        assert torch.allclose(src_param.cpu().float(),
                              tgt_param.sharded_data_tensor.payload.cpu().float(),
                              rtol=1e-03,
                              atol=1e-03), f"diff {diff}"


if __name__ == '__main__':
    test_bucket_copy()
@@ -2,7 +2,7 @@ import pytest

 from colossalai.utils.cuda import get_current_device
 from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage, colo_model_data_tensor_move, colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, colo_model_tensor_clone
-from colossalai.utils.memory_utils.utils import colo_set_process_memory_fraction, colo_cuda_memory_capacity
+from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity
 from colossalai.utils import free_port
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 import colossalai
@@ -12,6 +12,7 @@ import torch
 from functools import partial
 import torch.multiprocessing as mp


 def _run_colo_tensor_mem_usage():
     for i in range(1):
         if i == 1:
@@ -29,37 +30,45 @@ def _run_colo_tensor_mem_usage():
     assert c1 * 4 == c2
     assert g1 * 4 == g2

-def _run_colo_set_process_memory_fraction_and_colo_cuda_memory_capacity():
-    frac1 = colo_cuda_memory_capacity()
+
+def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
+    frac1 = colo_device_memory_capacity(get_current_device())
     colo_set_process_memory_fraction(0.5)
-    frac2 = colo_cuda_memory_capacity()
+    frac2 = colo_device_memory_capacity(get_current_device())
     assert frac2 * 2 == frac1


 def _run_colo_model_data_tensor_move_inline():
     for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]:
-        colo_model_data_tensor_move_inline(t, torch.device(f"cuda:{get_current_device()}"))
-        assert t.device == torch.device(f"cuda:{get_current_device()}")
+        colo_model_data_tensor_move_inline(t, get_current_device())
+        assert t.device == get_current_device()


 def _run_colo_model_data_tensor_move():
-    for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).cuda(get_current_device()))),
-              (torch.ones(2, 3), torch.zeros(2, 3).cuda(get_current_device()))]:
+    for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(get_current_device()))),
+              (torch.ones(2, 3), torch.zeros(2, 3).to(get_current_device()))]:
         cpu_t, cuda_t = t
         colo_model_data_tensor_move(cpu_t, cuda_t)
-        assert cuda_t.device == torch.device(f"cuda:{get_current_device()}")
+        assert cuda_t.device == get_current_device()


 def _run_colo_model_data_move_to_cpu():
     for t in [StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4)]:
         colo_model_data_move_to_cpu(t)
         assert t.device == torch.device("cpu")


 def _run_colo_model_tensor_clone():
-    for t in [StatefulTensor(torch.randn(2,2).cuda(torch.cuda.current_device())), torch.randn(4,4).cuda(torch.cuda.current_device())]:
+    for t in [
+            StatefulTensor(torch.randn(2, 2).cuda(torch.cuda.current_device())),
+            torch.randn(4, 4).cuda(torch.cuda.current_device())
+    ]:
         if issubclass(type(t), StatefulTensor):
-            assert t.payload.device == torch.device(f"cuda:{get_current_device()}")
+            assert t.payload.device == get_current_device()
         else:
-            assert t.device == torch.device(f"cuda:{get_current_device()}")
-        p = colo_model_tensor_clone(t, torch.device(f"cuda:{get_current_device()}"))
-        assert p.device == torch.device(f"cuda:{get_current_device()}")
+            assert t.device == get_current_device()
+        p = colo_model_tensor_clone(t, get_current_device())
+        assert p.device == get_current_device()
         for i in range(2):
             for j in range(2):
                 if issubclass(type(t), StatefulTensor):

@@ -70,21 +79,22 @@ def _run_colo_model_tensor_clone():
                     assert t[i][j] == p[i][j]


 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    _run_colo_set_process_memory_fraction_and_colo_cuda_memory_capacity()
+    _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity()
     _run_colo_model_data_tensor_move_inline()
     _run_colo_model_data_tensor_move()
     _run_colo_tensor_mem_usage()
     _run_colo_model_data_move_to_cpu()
     _run_colo_model_tensor_clone()


 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4, 5])
 def test_tensor_move(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
     test_tensor_move(4)
@@ -13,7 +13,7 @@ from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     colo_model_mem_usage
-from colossalai.utils.memory_utils.utils import colo_device_memory_used
+from colossalai.utils.memory import colo_device_memory_used
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -29,7 +29,7 @@ def run_model_test(init_device_type, shard_strategy_class):
     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
         if init_device_type == 'cuda':
-            init_device = torch.device(f"cuda:{get_current_device()}")
+            init_device = get_current_device()
         elif init_device_type == 'cpu':
             init_device = torch.device("cpu")
         else:
@@ -57,8 +57,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

-    with ZeroInitContext(
-            target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
+    with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
                          shard_strategy=shard_strategy,
                          shard_param=True):
         zero_model = model_builder(checkpoint=True)
@@ -2,9 +2,10 @@ import torch
 import colossalai
 import pytest
 import torch.multiprocessing as mp
+from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity, colo_set_process_memory_fraction
+from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
 from colossalai.zero.shard_utils import StatefulTensorMgr
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import TensorState

@@ -26,7 +27,7 @@ class Net(torch.nn.Module):


 def run_stm():
-    cuda_capacity = colo_cuda_memory_capacity()
+    cuda_capacity = colo_device_memory_capacity(get_current_device())
     fraction = (1.4 * 1024**3) / cuda_capacity
     # limit max memory to 1.4GB
     # which means only 2 parameters can be on CUDA
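The test above turns a byte budget into a memory fraction before exercising StatefulTensorMgr; a short sketch of that pattern (the 1.4 GB budget is the value used in the hunk, the rest is illustrative):

    from colossalai.utils import get_current_device
    from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction

    budget_bytes = 1.4 * 1024**3                                  # limit max memory to 1.4GB
    fraction = budget_bytes / colo_device_memory_capacity(get_current_device())
    colo_set_process_memory_fraction(fraction)                    # later capacity queries reflect the cap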