[zero] fix init device bug in zero init context unittest (#516)

Jiarui Fang 2022-03-25 12:24:18 +08:00 committed by GitHub
parent a513164379
commit 0bebda6ea5
8 changed files with 55 additions and 37 deletions

View File

@@ -2,26 +2,10 @@ from concurrent.futures import ThreadPoolExecutor
 from time import sleep, time
 import pickle
 
+from colossalai.utils import get_current_device
 import torch
-from colossalai.utils import get_current_device
-
-
-def get_cuda_memory_used(device: torch.device) -> int:
-    """
-    Get the free memory info of device.
-    :param device: device id
-    :type device: torch.device
-    :return: current memory usage, sized by MB
-    :rtype: int
-    """
-    assert device.type == 'cuda'
-
-    ret: int = torch.cuda.memory_allocated(device)
-    # get the peak memory to report correct data, so reset the counter for the next call
-    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
-        torch.cuda.reset_peak_memory_stats(device)
-    return ret
+
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 
 
 class AsyncMemoryMonitor:
@@ -97,7 +81,7 @@ class AsyncMemoryMonitor:
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
-                get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
+                colo_cuda_memory_used(),
             )
             sleep(self.interval)
         return max_usage

View File

@@ -1,5 +1,5 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from .async_memtracer import get_cuda_memory_used
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.utils import get_current_device
 
 import torch

@@ -55,7 +55,7 @@ class MemStatsCollector:
         sampling_cnt = self._sampling_cnter.sampling_cnt
         assert sampling_cnt == len(self._overall_cuda)
         self._model_data_cuda.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
-        self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
+        self._overall_cuda.append(colo_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
         self._sampling_cnter.advance()
 
     def fetch_memstats(self) -> (int, int):

View File

@@ -44,6 +44,9 @@ class ModelDataTracer(metaclass=SingletonMeta):
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage -= mem_use
 
+    def clear(self) -> None:
+        self._cuda_usage = 0
+
     @property
     def cpu_usage(self):
         return self._cpu_usage
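
A usage sketch (illustrative, not part of the diff): GLOBAL_MODEL_DATA_TRACER is a process-wide singleton, so a test that asserts on cuda_usage should reset it before the next parametrized run; the new clear() method makes that possible, and the updated unit test further below calls it for exactly this reason.

    from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

    # ... run one ZeroInitContext test case ...
    usage = GLOBAL_MODEL_DATA_TRACER.cuda_usage       # usage accumulated by this run
    GLOBAL_MODEL_DATA_TRACER.clear()                  # reset the singleton before the next run
    assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0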

View File

@@ -9,6 +9,28 @@ import torch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
+from colossalai.utils.cuda import get_current_device
+from typing import Optional
+
+
+def colo_cuda_memory_used(device: Optional[torch.device] = None) -> int:
+    """
+    Get the memory currently allocated on a CUDA device.
+    :param device: a torch device instance or None; defaults to the current CUDA device
+    :type device: Optional[torch.device]
+    :return: current memory usage, in bytes
+    :rtype: int
+    """
+    if device:
+        assert device.type == 'cuda'
+    else:
+        device = torch.device(f'cuda:{get_current_device()}')
+
+    ret: int = torch.cuda.memory_allocated(device)
+    # get the peak memory to report correct data, so reset the counter for the next call
+    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
+        torch.cuda.reset_peak_memory_stats(device)
+    return ret
+
+
 def bytes_to_GB(val, decimal=2):
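
A usage sketch (illustrative, not part of the diff): with the Optional device argument, callers can either pass an explicit CUDA device, as MemStatsCollector does, or omit it and let the helper fall back to the current device, which is what AsyncMemoryMonitor and the ZeroInitContext exit logging now rely on.

    import torch
    from colossalai.utils.cuda import get_current_device
    from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used

    # explicit device
    used_explicit = colo_cuda_memory_used(torch.device(f'cuda:{get_current_device()}'))
    # no argument: defaults to the current CUDA device
    used_default = colo_cuda_memory_used()
    print(used_explicit, used_default)    # both report allocated bytes on this rank's GPU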

View File

@@ -3,7 +3,7 @@ from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from typing import Union
+from typing import Union, Optional
 
 _GLOBAL_CUDA_MEM_FRACTION = 1.0

View File

@@ -6,16 +6,14 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from colossalai.logging import get_dist_logger, disable_existing_loggers
 
-# Inserts _post_init_method at the end of init method
-# for all sub classes of torch.nn.Module
-
 class InsertPostInitMethodToModuleSubClasses(object):
 
     def __init__(self):
@@ -144,8 +142,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             del self.initialized_param_list
         GLOBAL_MODEL_DATA_TRACER.close()
-        cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
-        self.logger.info(f"Existing ZeRO Context Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])
+        model_data_cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
+        self.logger.info(f"Exiting ZeRO Context: Model Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
+        sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
+        self.logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
+        self.logger.info(f"Model Number Parameter {self.model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])
 
     def _post_init_method(self, module: torch.nn.Module):
         """

View File

@@ -23,6 +23,9 @@ class TensorShardStrategy(BaseShardStrategy):
     def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
         if t.is_sharded:
             return
+        if t.payload.device.type == 'cuda':
+            assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
+                f" but current cuda device is {get_current_device()}"
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.reset_payload(sharded_payload)
         t.is_sharded = True
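
An illustrative sketch of the failure the new guard catches (not part of the diff; assumes at least two visible GPUs): if a payload was created on a different CUDA index than the rank's current device, sharding it silently would place the shard on the wrong GPU, so the assertion fails fast instead.

    import torch
    from colossalai.utils import get_current_device

    t_payload = torch.empty(4, device='cuda:1')    # hypothetical payload on cuda:1
    current = get_current_device()                 # e.g. 0 on this rank
    assert t_payload.device.index == current, \
        f"shard tensor on cuda device index {t_payload.device.index}, but current cuda device is {current}"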

View File

@@ -19,17 +19,24 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
 
 
-@parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
+@parameterize("init_device_type", ['cpu', 'cuda'])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_model_test(init_device, shard_strategy_class):
+def run_model_test(init_device_type, shard_strategy_class):
     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
+        if init_device_type == 'cuda':
+            init_device = torch.device(f"cuda:{get_current_device()}")
+        elif init_device_type == 'cpu':
+            init_device = torch.device("cpu")
+        else:
+            continue
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
                              shard_strategy=shard_strategy_class(),
                              shard_param=True,
-                             model_numel_tensor=model_numel_tensor):
+                             model_numel_tensor=model_numel_tensor,
+                             rm_torch_payload_on_the_fly=False):
             model = model_builder(checkpoint=True)
 
         for param in model.parameters():
@@ -38,11 +45,9 @@ def run_model_test(init_device, shard_strategy_class):
             assert param.col_attr.sharded_data_tensor.is_sharded
             assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
                 f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
-        print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
-        print(f'numel {model_numel_tensor}')
         if init_device.type == 'cuda':
             assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+        GLOBAL_MODEL_DATA_TRACER.clear()
 
 
 def run_dist(rank, world_size, port):