[zero] fix init device bug in zero init context unittest (#516)

Author: Jiarui Fang
Date: 2022-03-25 12:24:18 +08:00
Committed by: GitHub
Parent: a513164379
Commit: 0bebda6ea5
8 changed files with 55 additions and 37 deletions


@@ -6,16 +6,14 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from colossalai.logging import get_dist_logger, disable_existing_loggers

 # Inserts _post_init_method at the end of init method
 # for all sub classes of torch.nn.Module
 class InsertPostInitMethodToModuleSubClasses(object):
     def __init__(self):
@@ -144,8 +142,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         del self.initialized_param_list
         GLOBAL_MODEL_DATA_TRACER.close()
-        cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
-        self.logger.info(f"Existing ZeRO Context Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])
+        model_data_cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
+        self.logger.info(f"Existing ZeRO Context: Model Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
+        sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
+        self.logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
+        self.logger.info(f"Model Number Parameter {self.model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])

     def _post_init_method(self, module: torch.nn.Module):
         """
@@ -178,8 +179,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-                if param.col_attr.sharded_data_tensor.device.type == 'cuda':
-                    GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            if param.col_attr.sharded_data_tensor.device.type == 'cuda':
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
             # if param.col_attr.grad and self.shard_grad:
             #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
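For context, the InsertPostInitMethodToModuleSubClasses helper in this file wraps the __init__ of torch.nn.Module subclasses so that a _post_init_method hook runs on every freshly constructed module; that hook is where the sharding and tracer calls above happen. The following is a minimal, self-contained sketch of that wrapping idea, not the ColossalAI implementation; PostInitHookContext and its printing hook body are hypothetical names used only for illustration.

import functools
import torch

class PostInitHookContext:
    """Wrap the __init__ of torch.nn.Module subclasses so a hook runs after construction."""

    def _post_init_method(self, module: torch.nn.Module):
        # Hypothetical hook body: report where each parameter of the new module lives.
        for name, param in module.named_parameters(recurse=False):
            print(f"{type(module).__name__}.{name} initialized on {param.device}")

    def __enter__(self):
        def wrap(init_fn):
            @functools.wraps(init_fn)
            def wrapper(module, *args, **kwargs):
                init_fn(module, *args, **kwargs)
                self._post_init_method(module)
            return wrapper

        # Only direct subclasses of torch.nn.Module are patched in this sketch;
        # a full implementation would walk the whole subclass tree.
        self._originals = {cls: cls.__init__ for cls in torch.nn.Module.__subclasses__()}
        for cls, init_fn in self._originals.items():
            cls.__init__ = wrap(init_fn)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the original constructors on exit.
        for cls, init_fn in self._originals.items():
            cls.__init__ = init_fn

with PostInitHookContext():
    torch.nn.Linear(4, 4)  # hook fires right after Linear.__init__ finishes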


@@ -23,6 +23,9 @@ class TensorShardStrategy(BaseShardStrategy):
     def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
         if t.is_sharded:
             return
+        if t.payload.device.type == 'cuda':
+            assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
+                f" but current cuda device is {get_current_device()}"
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.reset_payload(sharded_payload)
         t.is_sharded = True
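The guarded get_shard call above splits a tensor's payload across the ranks of the process group. As a rough illustration of what such a split does, here is a small hypothetical shard_payload helper, not ColossalAI's get_shard: it flattens the payload, pads it so the element count divides evenly by the world size, and keeps only the chunk belonging to the given rank.

import torch

def shard_payload(payload: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Flatten, pad to a multiple of world_size, then keep this rank's chunk.
    flat = payload.flatten()
    pad = (world_size - flat.numel() % world_size) % world_size
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    return flat.chunk(world_size)[rank].clone()

full = torch.randn(5, 3)                          # 15 elements
shard = shard_payload(full, rank=0, world_size=4)
print(shard.shape)                                # torch.Size([4]) after padding 15 -> 16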