zero init ctx receives a dp process group (#471)

ver217 committed via GitHub, 2022-03-21 11:18:55 +08:00
parent 7e30068a22
commit 3cb3fc275e


@@ -1,11 +1,15 @@
 import functools
+from typing import Optional
 
 import torch
 
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
+from torch.distributed import ProcessGroup
 
 # Inserts _post_init_method at the end of init method
@@ -103,8 +107,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  shard_strategy: BaseShardStrategy,
                  shard_param: bool = False,
                  shard_grad: bool = False,
-                 rm_torch_payload_on_the_fly=False,
-                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int)):
+                 rm_torch_payload_on_the_fly: bool = False,
+                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
+                 dp_process_group: Optional[ProcessGroup] = None):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.target_device = target_device
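With this hunk, the constructor can be handed an explicit data-parallel group. Below is a minimal usage sketch, not part of this commit: it assumes ZeroInitContext is importable from colossalai.zero.init_ctx and TensorShardStrategy from colossalai.zero.shard_utils; MyModel is a placeholder module, and the rank list passed to dist.new_group is likewise illustrative.

    import torch
    import torch.distributed as dist
    from colossalai.zero.init_ctx import ZeroInitContext  # assumed import path
    from colossalai.zero.shard_utils import TensorShardStrategy

    # Illustrative custom data-parallel group; omitting dp_process_group
    # falls back to gpc.get_group(ParallelMode.DATA) (see the next hunk).
    dp_group = dist.new_group(ranks=[0, 1, 2, 3])

    with ZeroInitContext(convert_fp16=True,
                         target_device=torch.device('cuda'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True,
                         dp_process_group=dp_group):
        model = MyModel()  # parameters are sharded across dp_group at init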
@@ -115,6 +120,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.rm_torch_payload_on_the_fly = False
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
+        self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
 
     def _post_context_exec(self):
         """The callback function when the context exits.
@@ -154,10 +160,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            self.initialized_param_list.append(param)
 
            if self.shard_param:
-               self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
+               self.shard_strategy.shard([param.col_attr._data_sharded_tensor], self.dp_process_group)
                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
            # if param.col_attr.grad and self.shard_grad:
-           #     self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
+           #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
            #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
 
            # We must cast buffers
            # If we use BN, buffers may be on CPU and Float
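Threading self.dp_process_group into shard lets the strategy decide, per group, which slice of each tensor the local rank keeps. The snippet below is a simplified sketch of that kind of logic, not ColossalAI's actual BaseShardStrategy implementation; it ignores the padding needed when a tensor does not divide evenly across the group.

    import torch
    import torch.distributed as dist

    def shard_slice(tensor: torch.Tensor, process_group) -> torch.Tensor:
        # Rank and world size come from the given group, not the global
        # default group -- which is the point of this commit.
        rank = dist.get_rank(group=process_group)
        world_size = dist.get_world_size(group=process_group)
        chunks = tensor.flatten().chunk(world_size)
        return chunks[rank].clone()  # keep only this rank's shard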