mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] fix init device bug in zero init context unittest (#516)
@@ -6,16 +6,14 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from colossalai.logging import get_dist_logger, disable_existing_loggers


 # Inserts _post_init_method at the end of init method
 # for all sub classes of torch.nn.Module
 class InsertPostInitMethodToModuleSubClasses(object):

     def __init__(self):
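For context, the class shown above follows the usual ZeRO-style trick of temporarily patching __init__ on torch.nn.Module subclasses so that a _post_init_method hook runs as each module finishes construction. Below is a minimal sketch of that pattern, an illustration only and not ColossalAI's actual implementation (which walks subclasses more thoroughly and guards against repeated hook calls); the class name PostInitPatcher is made up for this example.

import functools
import torch

class PostInitPatcher:
    """Illustrative sketch only: wrap __init__ of existing torch.nn.Module
    subclasses so a hook runs right after each module instance is built."""

    def __enter__(self):
        self._saved = {}

        def make_wrapper(orig_init):
            @functools.wraps(orig_init)
            def wrapped(module, *args, **kwargs):
                orig_init(module, *args, **kwargs)
                # NOTE: only direct subclasses are patched here, and nested
                # super().__init__() calls can trigger the hook more than once;
                # real implementations handle both cases.
                self._post_init_method(module)
            return wrapped

        for cls in torch.nn.Module.__subclasses__():
            self._saved[cls] = cls.__init__
            cls.__init__ = make_wrapper(cls.__init__)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # restore the original constructors on exit
        for cls, init in self._saved.items():
            cls.__init__ = init

    def _post_init_method(self, module):
        print(f"constructed {module.__class__.__name__}")


if __name__ == "__main__":
    with PostInitPatcher():
        layer = torch.nn.Linear(4, 4)   # hook fires right after construction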
@@ -144,8 +142,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         del self.initialized_param_list
         GLOBAL_MODEL_DATA_TRACER.close()
-        cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
-        self.logger.info(f"Existing ZeRO Context Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])
+        model_data_cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
+        self.logger.info(f"Existing ZeRO Context: Model Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
+        sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
+        self.logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
+        self.logger.info(f"Model Number Parameter {self.model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])

     def _post_init_method(self, module: torch.nn.Module):
         """
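The replacement log lines separate three numbers reported when the context exits: model data resident on CUDA (accumulated by the tracer), total CUDA memory in use on the device, and the parameter count. A small sketch of the same unit conversions, using torch.cuda.memory_allocated() as a stand-in for colo_cuda_memory_used(); the helper name below is illustrative, not part of ColossalAI.

import torch

def report_zero_ctx_exit(model_data_bytes: int, numel: int) -> None:
    """Sketch of the exit-time report: byte counters become MB and parameter
    counts become millions, mirroring the logging added in the hunk above."""
    model_data_mb = model_data_bytes / 1e6
    sys_bytes = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
    print(f"Model Data CUDA Memory {model_data_mb:.1f} MB")
    print(f"System CUDA Memory Usage {sys_bytes / 1e6:.1f} MB")
    print(f"Model Number Parameter {numel / 1e6:.1f} M")


report_zero_ctx_exit(model_data_bytes=4 * 1024**2, numel=1_000_000)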
@@ -178,8 +179,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-                if param.col_attr.sharded_data_tensor.device.type == 'cuda':
-                    GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            if param.col_attr.sharded_data_tensor.device.type == 'cuda':
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
             # if param.col_attr.grad and self.shard_grad:
             #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
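After this change the tracer registration no longer depends on shard_param: any parameter whose data tensor already sits on CUDA is counted. A toy stand-in for the tracer and the registration check, for illustration only (ToyModelDataTracer and register_param are invented names, not the real GLOBAL_MODEL_DATA_TRACER API):

import torch

class ToyModelDataTracer:
    """Toy stand-in for a model-data memory tracer: it simply accumulates
    the byte size of every tensor it is handed."""

    def __init__(self) -> None:
        self.cuda_usage = 0  # bytes

    def add_tensor(self, t: torch.Tensor) -> None:
        self.cuda_usage += t.numel() * t.element_size()


def register_param(tracer: ToyModelDataTracer, payload: torch.Tensor) -> None:
    # Mirrors the moved check: register only payloads already resident on CUDA,
    # regardless of whether sharding was applied to them.
    if payload.device.type == 'cuda':
        tracer.add_tensor(payload)


tracer = ToyModelDataTracer()
register_param(tracer, torch.empty(1024))                      # CPU payload: ignored
if torch.cuda.is_available():
    register_param(tracer, torch.empty(1024, device='cuda'))   # counted: 1024 * 4 bytes
print(f"tracked model data: {tracer.cuda_usage / 1e6} MB")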
@@ -23,6 +23,9 @@ class TensorShardStrategy(BaseShardStrategy):
     def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
         if t.is_sharded:
             return
+        if t.payload.device.type == 'cuda':
+            assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
+                f" but current cuda device is {get_current_device()}"
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.reset_payload(sharded_payload)
         t.is_sharded = True
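The new guard insists that a CUDA-resident payload already lives on the rank's current device before sharding; otherwise the resulting shard would be materialized on the wrong GPU. For reference, a naive stand-in for the get_shard helper (the real helper may pad, flatten, or return metadata differently; naive_get_shard is an invented name) that returns the chunk owned by a given rank:

import torch

def naive_get_shard(payload: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Illustrative stand-in for get_shard(): flatten the payload, pad it to a
    multiple of world_size, and return the chunk owned by `rank`."""
    flat = payload.flatten()
    pad = (-flat.numel()) % world_size
    if pad:
        flat = torch.nn.functional.pad(flat, (0, pad))
    return flat.chunk(world_size)[rank].clone()


full = torch.arange(10, dtype=torch.float32)
shard = naive_get_shard(full, rank=1, world_size=4)
print(shard)  # tensor([3., 4., 5.]) -- 10 elements padded to 12, chunk size 3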