Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] fix init device bug in zero init context unittest (#516)
This commit is contained in:
parent  a513164379
commit  0bebda6ea5
@@ -2,26 +2,10 @@ from concurrent.futures import ThreadPoolExecutor
 from time import sleep, time
 import pickle
 
-from colossalai.utils import get_current_device
 import torch
 
-
-def get_cuda_memory_used(device: torch.device) -> int:
-    """
-    Get the free memory info of device.
-    :param device: device id
-    :type device: torch.device
-    :return: current memory usage, sized by MB
-    :rtype: int
-    """
-    assert device.type == 'cuda'
-
-    ret: int = torch.cuda.memory_allocated(device)
-    # get the peak memory to report correct data, so reset the counter for the next call
-    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
-        torch.cuda.reset_peak_memory_stats(device)
-    return ret
-
-
+from colossalai.utils import get_current_device
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 
 class AsyncMemoryMonitor:
@@ -97,7 +81,7 @@ class AsyncMemoryMonitor:
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
-                get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
+                colo_cuda_memory_used(),
             )
             sleep(self.interval)
         return max_usage
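Note: the sampling loop now calls the shared helper with no argument, so the query falls back to the current CUDA device. A minimal standalone sketch of the same polling pattern, using only documented PyTorch calls (the loop and its parameters are illustrative, not the actual AsyncMemoryMonitor thread):

import torch
from time import sleep

def poll_peak_cuda_usage(interval: float = 0.01, steps: int = 100) -> int:
    """Illustrative loop: track the peak allocated CUDA bytes over a short window."""
    assert torch.cuda.is_available()
    max_usage = 0
    for _ in range(steps):
        # torch.cuda.memory_allocated() reports bytes currently allocated by tensors
        max_usage = max(max_usage, torch.cuda.memory_allocated())
        sleep(interval)
    return max_usage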
@@ -1,5 +1,5 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from .async_memtracer import get_cuda_memory_used
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.utils import get_current_device
 
 import torch
@@ -55,7 +55,7 @@ class MemStatsCollector:
         sampling_cnt = self._sampling_cnter.sampling_cnt
         assert sampling_cnt == len(self._overall_cuda)
         self._model_data_cuda.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
-        self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
+        self._overall_cuda.append(colo_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
         self._sampling_cnter.advance()
 
     def fetch_memstats(self) -> (int, int):
@@ -44,6 +44,9 @@ class ModelDataTracer(metaclass=SingletonMeta):
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage -= mem_use
 
+    def clear(self) -> None:
+        self._cuda_usage = 0
+
     @property
     def cpu_usage(self):
         return self._cpu_usage
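The new clear() lets callers (for example the updated unittest further down) reset the singleton's CUDA-usage counter between parameterized runs. A toy sketch of that reset pattern; _ToyTracer is a hypothetical stand-in, not the ColossalAI class:

import torch

class _ToyTracer:
    """Hypothetical stand-in for the singleton tracer; only illustrates the reset pattern."""

    def __init__(self):
        self._cuda_usage = 0

    def add_tensor(self, t: torch.Tensor):
        # mirror the tracer's bookkeeping: bytes occupied by the tensor
        self._cuda_usage += t.numel() * t.element_size()

    def clear(self):
        self._cuda_usage = 0


tracer = _ToyTracer()
tracer.add_tensor(torch.zeros(1024))   # pretend this is a sharded payload
assert tracer._cuda_usage > 0
tracer.clear()                         # reset before the next parameterized case
assert tracer._cuda_usage == 0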
@@ -9,6 +9,28 @@ import torch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
+from colossalai.utils.cuda import get_current_device
+from typing import Optional
+
+
+def colo_cuda_memory_used(device: Optional[torch.device] = None) -> int:
+    """
+    Get the free memory info of device.
+    :param device: a torch device instance or None
+    :type device: Optional[torch.device]
+    :return: current memory usage, sized by Byte
+    :rtype: int
+    """
+    if device:
+        assert device.type == 'cuda'
+    else:
+        device = torch.device(f'cuda:{get_current_device()}')
+
+    ret: int = torch.cuda.memory_allocated(device)
+    # get the peak memory to report correct data, so reset the counter for the next call
+    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
+        torch.cuda.reset_peak_memory_stats(device)
+    return ret
 
 
 def bytes_to_GB(val, decimal=2):
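The relocated helper returns allocated bytes on the given (or current) CUDA device and, as a side effect, resets the peak-memory counter for the next call. A usage sketch; the import path is the one introduced by this commit, the device index and MB conversion are illustrative, and the snippet assumes a CUDA-enabled install:

import torch
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used

if torch.cuda.is_available():
    # default: query the current CUDA device
    used_bytes = colo_cuda_memory_used()
    # or pin the query to an explicit CUDA device
    used_bytes_dev0 = colo_cuda_memory_used(torch.device('cuda:0'))
    print(f"allocated: {used_bytes / 1e6:.2f} MB, device 0: {used_bytes_dev0 / 1e6:.2f} MB")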
@@ -3,7 +3,7 @@ from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
-from typing import Union
+from typing import Union, Optional
 
 _GLOBAL_CUDA_MEM_FRACTION = 1.0
@@ -6,16 +6,14 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from colossalai.logging import get_dist_logger, disable_existing_loggers
 
-# Inserts _post_init_method at the end of init method
-
-
-# for all sub classes of torch.nn.Module
+
 class InsertPostInitMethodToModuleSubClasses(object):
 
     def __init__(self):
@@ -144,8 +142,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         del self.initialized_param_list
         GLOBAL_MODEL_DATA_TRACER.close()
-        cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
-        self.logger.info(f"Existing ZeRO Context Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])
+        model_data_cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
+        self.logger.info(f"Existing ZeRO Context: Model Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
+        sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
+        self.logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
+        self.logger.info(f"Model Number Parameter {self.model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])
 
     def _post_init_method(self, module: torch.nn.Module):
         """
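The exit summary now reports three figures: tracked model data, total allocated CUDA memory via colo_cuda_memory_used(), and the parameter count, each scaled with a plain division by 1e6. A worked sketch of that arithmetic; the numbers below are illustrative, not output from an actual run:

# illustrative byte/element counts, not real measurements
model_data_bytes = 25_165_824   # bytes registered with the tracer
system_cuda_bytes = 41_943_040  # bytes reported by torch.cuda.memory_allocated
numel = 6_291_456               # total parameter elements

print(f"Model Data CUDA Memory {model_data_bytes / 1e6} MB")     # 25.165824 MB
print(f"System CUDA Memory Usage {system_cuda_bytes / 1e6} MB")  # 41.94304 MB
print(f"Model Number Parameter {numel / 1e6} M")                 # 6.291456 M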
@@ -178,8 +179,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         if self.shard_param:
             self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
         if param.col_attr.sharded_data_tensor.device.type == 'cuda':
             GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
         # if param.col_attr.grad and self.shard_grad:
         #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
         #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
@@ -23,6 +23,9 @@ class TensorShardStrategy(BaseShardStrategy):
     def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
         if t.is_sharded:
             return
+        if t.payload.device.type == 'cuda':
+            assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
+                f" but current cuda device is {get_current_device()}"
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.reset_payload(sharded_payload)
         t.is_sharded = True
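The added assertion rejects CUDA payloads that live on a device other than the one the calling process currently owns, surfacing the device mismatch early instead of inside the collective shard call. A standalone sketch of the same guard in plain PyTorch; the function name is illustrative and the snippet assumes at least one CUDA device:

import torch

def check_on_current_device(payload: torch.Tensor) -> None:
    # mirrors the guard added to _shard_tensor: CUDA payloads must sit on
    # the device this process has currently selected
    if payload.device.type == 'cuda':
        current = torch.cuda.current_device()
        assert payload.device.index == current, \
            f"shard tensor on cuda device index {payload.device.index}, " \
            f"but current cuda device is {current}"

if torch.cuda.is_available():
    check_on_current_device(torch.zeros(4, device=f'cuda:{torch.cuda.current_device()}'))  # passes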
@@ -19,17 +19,24 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
 
 
-@parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
+@parameterize("init_device_type", ['cpu', 'cuda'])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_model_test(init_device, shard_strategy_class):
+def run_model_test(init_device_type, shard_strategy_class):
     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
+        if init_device_type == 'cuda':
+            init_device = torch.device(f"cuda:{get_current_device()}")
+        elif init_device_type == 'cpu':
+            init_device = torch.device("cpu")
+        else:
+            continue
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
                              shard_strategy=shard_strategy_class(),
                              shard_param=True,
-                             model_numel_tensor=model_numel_tensor):
+                             model_numel_tensor=model_numel_tensor,
+                             rm_torch_payload_on_the_fly=False):
             model = model_builder(checkpoint=True)
 
             for param in model.parameters():
|
|||||||
assert param.col_attr.sharded_data_tensor.is_sharded
|
assert param.col_attr.sharded_data_tensor.is_sharded
|
||||||
assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
|
assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
|
||||||
f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
|
f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
|
||||||
|
if init_device.type == 'cuda':
|
||||||
print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
|
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
|
||||||
print(f'numel {model_numel_tensor}')
|
GLOBAL_MODEL_DATA_TRACER.clear()
|
||||||
if init_device.type == 'cuda':
|
|
||||||
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
|
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
|