Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-03 20:59:49 +00:00
zero init ctx receives a dp process group (#471)
This commit is contained in:
parent 7e30068a22 · commit 3cb3fc275e
@@ -1,11 +1,15 @@
 import functools
+from typing import Optional
 
 import torch
 
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
+from torch.distributed import ProcessGroup
 
 # Inserts _post_init_method at the end of init method
@@ -103,8 +107,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  shard_strategy: BaseShardStrategy,
                  shard_param: bool = False,
                  shard_grad: bool = False,
-                 rm_torch_payload_on_the_fly=False,
-                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int)):
+                 rm_torch_payload_on_the_fly: bool = False,
+                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
+                 dp_process_group: Optional[ProcessGroup] = None):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.target_device = target_device
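A minimal usage sketch of the new keyword argument (the import paths, a no-argument TensorShardStrategy constructor, and an already-initialised torch.distributed backend are all assumptions here, not part of this commit). If dp_process_group is omitted, the context falls back to gpc.get_group(ParallelMode.DATA), as the next hunk shows.

import torch
import torch.distributed as dist

from colossalai.zero.init_ctx import ZeroInitContext         # assumed import path
from colossalai.zero.shard_utils import TensorShardStrategy  # assumed concrete strategy

# Assumed: torch.distributed is already initialised (e.g. by colossalai.launch).
dp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device('cuda'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     dp_process_group=dp_group):   # new keyword added by this commit
    model = torch.nn.Linear(1024, 1024)            # parameters are sharded over dp_group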
@@ -115,6 +120,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.rm_torch_payload_on_the_fly = False
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
+        self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
 
     def _post_context_exec(self):
         """The callback function when the context exits.
@@ -154,10 +160,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             self.initialized_param_list.append(param)
 
             if self.shard_param:
-                self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
+                self.shard_strategy.shard([param.col_attr._data_sharded_tensor], self.dp_process_group)
                 GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
             # if param.col_attr.grad and self.shard_grad:
-            #     self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
+            #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
             # We must cast buffers
             # If we use BN, buffers may be on CPU and Float
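For illustration only: after this change, the context passes self.dp_process_group as the second positional argument of shard(). The toy strategy below mirrors that call shape; LoggingShardStrategy and its behaviour are hypothetical stand-ins, not ColossalAI's shard strategies.

from typing import List, Optional

import torch.distributed as dist
from torch.distributed import ProcessGroup


class LoggingShardStrategy:
    """Hypothetical strategy: reports the group it would shard over instead of sharding."""

    def shard(self, tensor_list: List, process_group: Optional[ProcessGroup] = None):
        world_size = dist.get_world_size(process_group)   # group defaults to the global one
        for t in tensor_list:
            # Each entry is expected to expose a .payload tensor, as the tracer call above does.
            print(f'would shard {t.payload.numel()} elements across {world_size} ranks')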