diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py index a43cf1878..4dfe924dd 100644 --- a/colossalai/engine/ophooks/zero_hook.py +++ b/colossalai/engine/ophooks/zero_hook.py @@ -30,7 +30,7 @@ class ZeroHook(BaseOpHook): self.process_group = process_group # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU - self.computing_device = torch.device(f'cuda:{get_current_device()}') + self.computing_device = get_current_device() self._memstarts_collector = memstarts_collector self._stateful_tensor_mgr = stateful_tensor_mgr diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 5e46d215a..2fd5a7951 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -8,7 +8,7 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral sync_model_param, disposable) from .data_sampler import DataParallelSampler, get_dataloader from .gradient_accumulation import accumulate_gradient -from .memory_utils.memory_monitor import report_memory_usage +from .memory import report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, colo_device_memory_capacity from .timer import MultiTimer, Timer from .tensor_detector import TensorDetector @@ -17,7 +17,8 @@ __all__ = [ 'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context', 'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes', 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda', - 'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler', - 'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint', + 'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction', + 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler', 'get_dataloader', + 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint', 'ensure_path_exists', 'disposable' ] diff --git a/colossalai/utils/cuda.py b/colossalai/utils/cuda.py index b287fa276..60f3ccb60 100644 --- a/colossalai/utils/cuda.py +++ b/colossalai/utils/cuda.py @@ -20,13 +20,15 @@ def set_to_cuda(models): return models.to(get_current_device()) -def get_current_device(): - """Returns the index of a currently selected device (gpu/cpu). +def get_current_device() -> torch.device: + """ + Returns currently selected device (gpu/cpu). + If cuda available, return gpu, otherwise return cpu. 
""" if torch.cuda.is_available(): - return torch.cuda.current_device() + return torch.device(f'cuda:{torch.cuda.current_device()}') else: - return 'cpu' + return torch.device('cpu') def synchronize(): diff --git a/colossalai/utils/memory.py b/colossalai/utils/memory.py new file mode 100644 index 000000000..799e68f4d --- /dev/null +++ b/colossalai/utils/memory.py @@ -0,0 +1,147 @@ +import torch +import gc +import psutil +from collections import namedtuple + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.utils import get_current_device +from colossalai.core import global_context as gpc +from colossalai.context.parallel_mode import ParallelMode +from colossalai.logging import get_dist_logger + +_GLOBAL_CUDA_MEM_FRACTION = 1.0 + + +def _bytes_to_MB(val, decimal=2): + """A byte-to-Megabyte converter, default using binary notation. + + :param val: X bytes to convert + :return: X' MB + """ + return round(val / (1024 * 1024), decimal) + + +# copy from PatrickStar +def _get_cpu_memory_info(): + ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"]) + try: + # psutil reads the memory info from /proc/memory_info, + # which results in returning the host memory instead of + # that of container. + # Here we try to read the container memory with method in: + # https://stackoverflow.com/a/46213331/5163915 + mems = {} + with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f: + for line in f: + fields = line.split() + mems[fields[0]] = int(fields[1]) * 1024 + total = mems[b"MemTotal:"] + free = mems[b"MemFree:"] + cached = mems[b"Cached:"] + buffers = mems[b"Buffers:"] + used = total - free - cached - buffers + if used < 0: + used = total - free + mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used) + except FileNotFoundError: + mems = psutil.virtual_memory() + mem_info = ps_mem_info( + total=mems.total, + free=mems.free, + cached=mems.cached, + buffers=mems.buffers, + used=mems.used, + ) + return mem_info + + +def report_memory_usage(message, logger=None, report_cpu=False): + """Calculate and print RAM usage (in GB) + + Args: + message (str): A prefix message to add in the log. + logger (:class:`colossalai.logging.DistributedLogger`): The logger used to record memory information. + report_cpu (bool, optional): Whether to report CPU memory. + + Raises: + EnvironmentError: Raise error if no distributed environment has been initialized. 
+    """
+    if not gpc.is_initialized(ParallelMode.GLOBAL):
+        raise EnvironmentError("No distributed environment is initialized")
+
+    gpu_allocated = _bytes_to_MB(torch.cuda.memory_allocated())
+    gpu_max_allocated = _bytes_to_MB(torch.cuda.max_memory_allocated())
+    gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved())
+    gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved())
+
+    full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \
+        + f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"
+
+    if report_cpu:
+        # Python does not garbage-collect in real time, so collect explicitly to get accurate RAM reports
+        gc.collect()
+        vm_stats = psutil.virtual_memory()
+        vm_used = _bytes_to_MB(vm_stats.total - vm_stats.available)
+        full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%"
+
+    if logger is None:
+        logger = get_dist_logger()
+    logger.info(full_log)
+
+    # reset the peak memory counter so the next call reports fresh data
+    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
+        torch.cuda.reset_peak_memory_stats()
+
+
+def colo_device_memory_capacity(device: torch.device) -> int:
+    """
+    Get the memory capacity of the given device.
+
+    Args:
+        device (torch.device): a device
+
+    Returns:
+        int: capacity in bytes
+    """
+    assert isinstance(device, torch.device)
+    if device.type == 'cpu':
+        mem_info = _get_cpu_memory_info()
+        return mem_info.total / gpc.get_world_size(ParallelMode.DATA)
+    if device.type == 'cuda':
+        return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
+
+
+def colo_device_memory_used(device: torch.device) -> int:
+    """
+    Get the memory used on the given device by the current process.
+
+    Args:
+        device (torch.device): a device
+
+    Returns:
+        int: memory size in bytes
+    """
+    if device.type == 'cpu':
+        mem_info = _get_cpu_memory_info()
+        # FIXME(jiaruifang) we need to know how many processes share the CPU memory.
+        ret = mem_info.used / gpc.get_world_size(ParallelMode.DATA)
+        return ret
+    elif device.type == 'cuda':
+        ret: int = torch.cuda.memory_allocated(device)
+        # reset the peak memory counter so the next call reports fresh data
+        if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
+            torch.cuda.reset_peak_memory_stats(device)
+        return ret
+
+
+def colo_set_process_memory_fraction(ratio: float) -> None:
+    """colo_set_process_memory_fraction
+
+    Set the fraction of CUDA memory that the current process may use on its GPU.
+
+    Args:
+        ratio (float): a fraction between 0 and 1
+ """ + global _GLOBAL_CUDA_MEM_FRACTION + _GLOBAL_CUDA_MEM_FRACTION = ratio + torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) diff --git a/colossalai/utils/memory_tracer/async_memtracer.py b/colossalai/utils/memory_tracer/async_memtracer.py index 7ceaf9d80..3be66cd9c 100644 --- a/colossalai/utils/memory_tracer/async_memtracer.py +++ b/colossalai/utils/memory_tracer/async_memtracer.py @@ -4,7 +4,7 @@ import pickle import torch -from colossalai.utils.memory_utils.utils import colo_device_memory_used +from colossalai.utils.memory import colo_device_memory_used from colossalai.utils import get_current_device diff --git a/colossalai/utils/memory_tracer/memstats_collector.py b/colossalai/utils/memory_tracer/memstats_collector.py index 1c93998cc..5da971ab5 100644 --- a/colossalai/utils/memory_tracer/memstats_collector.py +++ b/colossalai/utils/memory_tracer/memstats_collector.py @@ -1,5 +1,5 @@ from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER -from colossalai.utils.memory_utils.utils import colo_device_memory_used +from colossalai.utils.memory import colo_device_memory_used from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor import torch import time diff --git a/colossalai/utils/memory_utils/__init__.py b/colossalai/utils/memory_utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/colossalai/utils/memory_utils/bucket_tensor_copy.py b/colossalai/utils/memory_utils/bucket_tensor_copy.py deleted file mode 100644 index f65a75a81..000000000 --- a/colossalai/utils/memory_utils/bucket_tensor_copy.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch -from colossalai.zero.sharded_param import ShardedParamV2 -from colossalai.utils import get_current_device -from typing import List - - -class BucketizedTensorCopy(object): - - def __init__( - self, - chunk_size: int, - ): - r""" - torch.nn.Parameter CPU (fp32) -> ShardedParam GPU (fp16) - TODO(jiaruifang) The class is a little bit hardcoded - I will make it more general later. 
- """ - - self.chunk_size = chunk_size - self._offset = 0 - self._cpu_buffer = torch.empty(chunk_size, dtype=torch.float, device=torch.device("cpu:0"), pin_memory=True) - self._cuda_buffer = torch.empty(chunk_size, - dtype=torch.half, - device=torch.device(f"cuda:{get_current_device()}")) - - self._buffered_param_list: List[ShardedParamV2] = [] - self._numel_list = [] - - def copy(self, src_param: torch.nn.Parameter, target_param: ShardedParamV2): - assert isinstance(target_param, ShardedParamV2) - assert isinstance(src_param, torch.nn.Parameter) - - numel = src_param.numel() - - if self._offset + numel > self.chunk_size: - self.flush() - - assert src_param.data.device.type == 'cpu' - self._cpu_buffer.narrow(0, self._offset, numel).copy_(src_param.data.view(-1)) - - self._buffered_param_list.append(target_param) - self._numel_list.append(numel) - - self._offset += numel - - def flush(self): - """ - flush to cuda memory - """ - self._cuda_buffer.copy_(self._cpu_buffer) - flush_offset = 0 - for sparam, numel in zip(self._buffered_param_list, self._numel_list): - sparam.sharded_data_tensor.copy_payload(self._cpu_buffer.narrow(0, flush_offset, numel)) - flush_offset += numel - - self.reset() - - def reset(self): - self._buffered_param_list = [] - self._numel_list = [] - self._offset = 0 diff --git a/colossalai/utils/memory_utils/memory_monitor.py b/colossalai/utils/memory_utils/memory_monitor.py deleted file mode 100644 index 38f6349a3..000000000 --- a/colossalai/utils/memory_utils/memory_monitor.py +++ /dev/null @@ -1,67 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import gc - -import psutil -import torch - -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger - - -def bytes_to_GB(val, decimal=2): - """A byte-to-Gigabyte converter, default using binary notation. - - :param val: X bytes to convert - :return: X' GB - """ - return round(val / (1024 * 1024 * 1024), decimal) - - -def bytes_to_MB(val, decimal=2): - """A byte-to-Megabyte converter, default using binary notation. - - :param val: X bytes to convert - :return: X' MB - """ - return round(val / (1024 * 1024), decimal) - - -def report_memory_usage(message, logger=None, report_cpu=False): - """Calculate and print RAM usage (in GB) - - Args: - message (str): A prefix message to add in the log. - logger (:class:`colossalai.logging.DistributedLogger`): The logger used to record memory information. - report_cpu (bool, optional): Whether to report CPU memory. - - Raises: - EnvironmentError: Raise error if no distributed environment has been initialized. 
- """ - if not gpc.is_initialized(ParallelMode.GLOBAL): - raise EnvironmentError("No distributed environment is initialized") - - gpu_allocated = bytes_to_MB(torch.cuda.memory_allocated()) - gpu_max_allocated = bytes_to_MB(torch.cuda.max_memory_allocated()) - gpu_cached = bytes_to_MB(torch.cuda.memory_reserved()) - gpu_max_cached = bytes_to_MB(torch.cuda.max_memory_reserved()) - - full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \ - + f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB" - - if report_cpu: - # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports - gc.collect() - vm_stats = psutil.virtual_memory() - vm_used = bytes_to_MB(vm_stats.total - vm_stats.available) - full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%" - - if logger is None: - logger = get_dist_logger() - logger.info(full_log) - - # get the peak memory to report correct data, so reset the counter for the next call - if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ - torch.cuda.reset_peak_memory_stats() diff --git a/colossalai/utils/memory_utils/utils.py b/colossalai/utils/memory_utils/utils.py deleted file mode 100644 index ca0e974bc..000000000 --- a/colossalai/utils/memory_utils/utils.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -from colossalai.context.parallel_mode import ParallelMode -from colossalai.utils import get_current_device - -from collections import namedtuple -import psutil -from colossalai.core import global_context as gpc - -_GLOBAL_CUDA_MEM_FRACTION = 1.0 - - -# copy from PatrickStar -def _get_cpu_memory_info(): - ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"]) - try: - # psutil reads the memory info from /proc/memory_info, - # which results in returning the host memory instead of - # that of container. 
- # Here we try to read the container memory with method in: - # https://stackoverflow.com/a/46213331/5163915 - mems = {} - with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f: - for line in f: - fields = line.split() - mems[fields[0]] = int(fields[1]) * 1024 - total = mems[b"MemTotal:"] - free = mems[b"MemFree:"] - cached = mems[b"Cached:"] - buffers = mems[b"Buffers:"] - used = total - free - cached - buffers - if used < 0: - used = total - free - mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used) - except FileNotFoundError: - mems = psutil.virtual_memory() - mem_info = ps_mem_info( - total=mems.total, - free=mems.free, - cached=mems.cached, - buffers=mems.buffers, - used=mems.used, - ) - return mem_info - - -def colo_device_memory_used(device) -> int: - if not isinstance(device, torch.device): - device = torch.device(f"cuda:{device}") - if device.type == 'cpu': - mem_info = _get_cpu_memory_info() - # FIXME(jiaruifang) only work for 1-CPU multi-GPU - # CPU memory is sharded with all processes - # Not support multi-GPU multi-CPU - # We need a local_world_size here - ret = mem_info.used / gpc.get_world_size(ParallelMode.DATA) - return ret - elif device.type == 'cuda': - ret: int = torch.cuda.memory_allocated(device) - # get the peak memory to report correct data, so reset the counter for the next call - if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ - torch.cuda.reset_peak_memory_stats(device) - return ret - - -def colo_set_process_memory_fraction(ratio: float) -> None: - """colo_set_process_memory_fraction - - set how much cuda memory used on the gpu belonging to the current process. - - Args: - ratio (float): a ratio between 0. ~ 1. - """ - global _GLOBAL_CUDA_MEM_FRACTION - _GLOBAL_CUDA_MEM_FRACTION = ratio - torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) - - -def colo_cuda_memory_capacity() -> float: - """ - Get cuda memory capacity of the current cuda. - """ - return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION diff --git a/colossalai/zero/shard_utils/stateful_tensor_mgr.py b/colossalai/zero/shard_utils/stateful_tensor_mgr.py index 877c0763c..817a383d8 100644 --- a/colossalai/zero/shard_utils/stateful_tensor_mgr.py +++ b/colossalai/zero/shard_utils/stateful_tensor_mgr.py @@ -5,7 +5,7 @@ from colossalai.utils.cuda import get_current_device from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage -from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity +from colossalai.utils.memory import colo_device_memory_capacity from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER from typing import Dict, List from colossalai.utils.memory_tracer import MemStatsCollector @@ -64,7 +64,7 @@ class StatefulTensorMgr(object): cuda_demand += colo_tensor_mem_usage(tensor.payload)[1] else: raise RuntimeError - cuda_capacity = colo_cuda_memory_capacity() + cuda_capacity = colo_device_memory_capacity(get_current_device()) if self._warmup: # We designate a part of CUDA memory for model data in warmup iterations. 
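For orientation, a minimal usage sketch (not part of the diff) of the relocated helpers now exported from colossalai.utils. It assumes a distributed environment has already been set up with colossalai.launch, since report_memory_usage raises EnvironmentError otherwise; the fraction value and log message are illustrative only.

    from colossalai.utils import (colo_device_memory_capacity, colo_device_memory_used,
                                  colo_set_process_memory_fraction, get_current_device,
                                  report_memory_usage)

    device = get_current_device()                    # now a torch.device, e.g. device('cuda:0')
    colo_set_process_memory_fraction(0.5)            # cap this process at half of its GPU memory
    capacity = colo_device_memory_capacity(device)   # bytes, scaled by the fraction set above
    used = colo_device_memory_used(device)           # bytes allocated on the device by this process
    print(f"{used / 1024**2:.1f} MB used of {capacity / 1024**2:.1f} MB available")
    report_memory_usage("after forward pass")        # logs allocated/cached GPU (and optionally CPU) memory

Note that colo_device_memory_capacity and colo_device_memory_used also accept a CPU device, in which case the reported figures are divided across the data-parallel world size.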
diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py index 01803fed0..761143cf3 100644 --- a/colossalai/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py @@ -33,7 +33,7 @@ class TensorShardStrategy(BaseShardStrategy): if t.is_sharded: return if t.payload.device.type == 'cuda': - assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\ + assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\ f" but current cuda device is {get_current_device()}" sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) t.reset_payload(sharded_payload) diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index ef2436b7f..2be5bac43 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -16,7 +16,7 @@ from colossalai.utils import get_current_device, disposable from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector from colossalai.utils.memory_tracer.model_data_memtracer import \ GLOBAL_MODEL_DATA_TRACER -from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity +from colossalai.utils.memory import colo_device_memory_capacity from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.shard_utils.tensor_utils import colo_model_data_move_to_cpu from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer @@ -231,7 +231,7 @@ class ShardedModelV2(nn.Module): # the way to calculate margin space is based on the assumption that # model data is fixed in cuda during training. # cuda margin space can be used to store OS. 
- self._cuda_margin_space = colo_cuda_memory_capacity() - max( + self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max( self._memstats_collector.overall_mem_stats('cuda')) @torch.no_grad() diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py index 309692285..4e62c77e9 100644 --- a/tests/test_moe/test_moe_zero_init.py +++ b/tests/test_moe/test_moe_zero_init.py @@ -41,7 +41,7 @@ def run_moe_zero_init(init_device_type, shard_strategy_class): logger = get_dist_logger("test_moe_zero_init") if init_device_type == 'cuda': - init_device = torch.device(f"cuda:{get_current_device()}") + init_device = get_current_device() elif init_device_type == 'cpu': init_device = torch.device("cpu") else: diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 8348e093c..7956f86f4 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -62,10 +62,9 @@ def _run_test_sharded_optim_v2(cpu_offload, get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module') _, train_dataloader, _, optimizer_class, criterion = get_components_func() - with ZeroInitContext( - target_device=torch.device('cpu') if cpu_offload else torch.device(f'cuda:{get_current_device()}'), - shard_strategy=shard_strategy, - shard_param=True): + with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True): zero_model = MoeModel() zero_model = ShardedModelV2(zero_model, diff --git a/tests/test_utils/test_bucket_tensor_copy.py b/tests/test_utils/test_bucket_tensor_copy.py deleted file mode 100644 index f190cb522..000000000 --- a/tests/test_utils/test_bucket_tensor_copy.py +++ /dev/null @@ -1,39 +0,0 @@ -from colossalai.utils.memory_utils.bucket_tensor_copy import BucketizedTensorCopy -from colossalai.zero.sharded_param import ShardedParamV2 -from colossalai.utils import free_port -import torch -import colossalai - - -def test_bucket_copy(): - # init dist env - colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') - - copyer = BucketizedTensorCopy(20) - - shape_list = [(2, 3), (5), (8), (12)] - src_param_list = [] - tgt_param_list = [] - for shape in shape_list: - # on CPU - src_param = torch.nn.Parameter(torch.randn(shape, dtype=torch.float, device=torch.device('cpu'))) - # on GPU - tgt_param = ShardedParamV2(torch.nn.Parameter(torch.ones(shape, dtype=torch.half, device=torch.device('cuda')))) - - src_param_list.append(src_param) - tgt_param_list.append(tgt_param) - - copyer.copy(src_param, tgt_param) - - copyer.flush() - - for src_param, tgt_param in zip(src_param_list, tgt_param_list): - diff = src_param.cpu().float() - tgt_param.sharded_data_tensor.payload.cpu().float() - assert torch.allclose(src_param.cpu().float(), - tgt_param.sharded_data_tensor.payload.cpu().float(), - rtol=1e-03, - atol=1e-03), f"diff {diff}" - - -if __name__ == '__main__': - test_bucket_copy() diff --git a/tests/test_utils/test_tensor_move.py b/tests/test_utils/test_tensor_move.py index 500e81f1f..62874d652 100644 --- a/tests/test_utils/test_tensor_move.py +++ b/tests/test_utils/test_tensor_move.py @@ -2,7 +2,7 @@ import pytest from colossalai.utils.cuda import get_current_device from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage, colo_model_data_tensor_move, colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, colo_model_tensor_clone 
-from colossalai.utils.memory_utils.utils import colo_set_process_memory_fraction, colo_cuda_memory_capacity +from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity from colossalai.utils import free_port from colossalai.zero.sharded_param.tensorful_state import StatefulTensor import colossalai @@ -12,54 +12,63 @@ import torch from functools import partial import torch.multiprocessing as mp + def _run_colo_tensor_mem_usage(): for i in range(1): if i == 1: - t1 = StatefulTensor(torch.randn(2,2)) - t2 = StatefulTensor(torch.randn(4,4)) - c1 , g1 = colo_tensor_mem_usage(t1) - c2 , g2 = colo_tensor_mem_usage(t2) - assert c1*4 == c2 - assert g1*4 == g2 + t1 = StatefulTensor(torch.randn(2, 2)) + t2 = StatefulTensor(torch.randn(4, 4)) + c1, g1 = colo_tensor_mem_usage(t1) + c2, g2 = colo_tensor_mem_usage(t2) + assert c1 * 4 == c2 + assert g1 * 4 == g2 else: - t1 = torch.randn(2,2) - t2 = torch.randn(4,4) - c1 , g1 = colo_tensor_mem_usage(t1) - c2 , g2 = colo_tensor_mem_usage(t2) - assert c1*4 == c2 - assert g1*4 == g2 + t1 = torch.randn(2, 2) + t2 = torch.randn(4, 4) + c1, g1 = colo_tensor_mem_usage(t1) + c2, g2 = colo_tensor_mem_usage(t2) + assert c1 * 4 == c2 + assert g1 * 4 == g2 -def _run_colo_set_process_memory_fraction_and_colo_cuda_memory_capacity(): - frac1 = colo_cuda_memory_capacity() + +def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): + frac1 = colo_device_memory_capacity(get_current_device()) colo_set_process_memory_fraction(0.5) - frac2 = colo_cuda_memory_capacity() - assert frac2*2 == frac1 + frac2 = colo_device_memory_capacity(get_current_device()) + assert frac2 * 2 == frac1 + def _run_colo_model_data_tensor_move_inline(): - for t in [StatefulTensor(torch.randn(2,3)), torch.randn(2,3)]: - colo_model_data_tensor_move_inline(t, torch.device(f"cuda:{get_current_device()}")) - assert t.device == torch.device(f"cuda:{get_current_device()}") + for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]: + colo_model_data_tensor_move_inline(t, get_current_device()) + assert t.device == get_current_device() + def _run_colo_model_data_tensor_move(): - for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).cuda(get_current_device()))), - (torch.ones(2, 3), torch.zeros(2, 3).cuda(get_current_device()))]: + for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(get_current_device()))), + (torch.ones(2, 3), torch.zeros(2, 3).to(get_current_device()))]: cpu_t, cuda_t = t colo_model_data_tensor_move(cpu_t, cuda_t) - assert cuda_t.device == torch.device(f"cuda:{get_current_device()}") + assert cuda_t.device == get_current_device() + def _run_colo_model_data_move_to_cpu(): - for t in [StatefulTensor(torch.randn(2,2)), torch.randn(4,4)]: + for t in [StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4)]: colo_model_data_move_to_cpu(t) assert t.device == torch.device("cpu") + def _run_colo_model_tensor_clone(): - for t in [StatefulTensor(torch.randn(2,2).cuda(torch.cuda.current_device())), torch.randn(4,4).cuda(torch.cuda.current_device())]: + for t in [ + StatefulTensor(torch.randn(2, 2).cuda(torch.cuda.current_device())), + torch.randn(4, 4).cuda(torch.cuda.current_device()) + ]: if issubclass(type(t), StatefulTensor): - assert t.payload.device == torch.device(f"cuda:{get_current_device()}") + assert t.payload.device == get_current_device() else: - assert t.device == torch.device(f"cuda:{get_current_device()}") - p = colo_model_tensor_clone(t, 
torch.device(f"cuda:{get_current_device()}")) - assert p.device == torch.device(f"cuda:{get_current_device()}") + assert t.device == get_current_device() + p = colo_model_tensor_clone(t, get_current_device()) + assert p.device == get_current_device() for i in range(2): for j in range(2): if issubclass(type(t), StatefulTensor): @@ -70,21 +79,22 @@ def _run_colo_model_tensor_clone(): assert t[i][j] == p[i][j] - def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_colo_set_process_memory_fraction_and_colo_cuda_memory_capacity() + _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity() _run_colo_model_data_tensor_move_inline() _run_colo_model_data_tensor_move() _run_colo_tensor_mem_usage() _run_colo_model_data_move_to_cpu() _run_colo_model_tensor_clone() + @pytest.mark.dist @pytest.mark.parametrize("world_size", [4, 5]) def test_tensor_move(world_size): run_func = partial(run_dist, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) + if __name__ == '__main__': test_tensor_move(4) diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py index 094227b6d..34777c6b8 100644 --- a/tests/test_zero_data_parallel/test_init_context.py +++ b/tests/test_zero_data_parallel/test_init_context.py @@ -13,7 +13,7 @@ from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device from colossalai.utils.memory_tracer.model_data_memtracer import \ colo_model_mem_usage -from colossalai.utils.memory_utils.utils import colo_device_memory_used +from colossalai.utils.memory import colo_device_memory_used from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from tests.components_to_test.registry import non_distributed_component_funcs @@ -29,7 +29,7 @@ def run_model_test(init_device_type, shard_strategy_class): for get_components_func in non_distributed_component_funcs: model_builder, _, _, _, _ = get_components_func() if init_device_type == 'cuda': - init_device = torch.device(f"cuda:{get_current_device()}") + init_device = get_current_device() elif init_device_type == 'cpu': init_device = torch.device("cpu") else: diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py index 01339cc34..0c8f8ea66 100644 --- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py +++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py @@ -57,10 +57,9 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - with ZeroInitContext( - target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'), - shard_strategy=shard_strategy, - shard_param=True): + with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True): zero_model = model_builder(checkpoint=True) zero_model = ShardedModelV2( zero_model, diff --git a/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py b/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py index 857dad5ea..af8165de2 100644 --- a/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py +++ 
b/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py @@ -2,9 +2,10 @@ import torch import colossalai import pytest import torch.multiprocessing as mp +from colossalai.utils.cuda import get_current_device from colossalai.utils.memory_tracer import MemStatsCollector from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER -from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity, colo_set_process_memory_fraction +from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction from colossalai.zero.shard_utils import StatefulTensorMgr from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from colossalai.zero.sharded_param.tensorful_state import TensorState @@ -26,7 +27,7 @@ class Net(torch.nn.Module): def run_stm(): - cuda_capacity = colo_cuda_memory_capacity() + cuda_capacity = colo_device_memory_capacity(get_current_device()) fraction = (1.4 * 1024**3) / cuda_capacity # limit max memory to 1.4GB # which means only 2 parameters can be on CUDA
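For reference, a short sketch (not part of the diff) of the capacity/fraction arithmetic the updated tests exercise: colo_device_memory_capacity reports the device's total memory scaled by the current process fraction, so limiting the process to a fixed byte budget amounts to passing budget / capacity, as run_stm does with its 1.4 GB limit. The budget value below is illustrative.

    from colossalai.utils import (colo_device_memory_capacity, colo_set_process_memory_fraction,
                                  get_current_device)

    device = get_current_device()
    full_capacity = colo_device_memory_capacity(device)    # fraction defaults to 1.0

    budget = 1.4 * 1024**3                                 # the 1.4 GB budget used in run_stm
    colo_set_process_memory_fraction(budget / full_capacity)
    limited = colo_device_memory_capacity(device)          # now approximately equal to budget,
    print(f"capacity limited to {limited / 1024**3:.2f} GB")    # mirroring the halving check in the test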