Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -39,7 +39,7 @@ class ZeroContextConfig:
             assert self.is_replicated, "Non-replicated parameters can't be sharded."
 
         if self.is_replicated and not self.shard_param:
-            assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda."
+            assert self.target_device.type == "cuda", "Replicated no-shard parameters should be located in cuda."
 
 
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
@@ -59,15 +59,16 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
     """
 
-    def __init__(self,
-                 target_device: torch.device,
-                 shard_strategy: BaseShardStrategy,
-                 seed: int = 2**10 - 1,
-                 shard_param: bool = False,
-                 default_dtype: Optional[torch.dtype] = None,
-                 bf16: bool = False,
-                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):
+    def __init__(
+        self,
+        target_device: torch.device,
+        shard_strategy: BaseShardStrategy,
+        seed: int = 2**10 - 1,
+        shard_param: bool = False,
+        default_dtype: Optional[torch.dtype] = None,
+        bf16: bool = False,
+        model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long),
+    ):
         super().__init__(default_dtype=default_dtype)
         self.shard_strategy = shard_strategy
         self.param_list = []
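For orientation, here is a minimal sketch of how the constructor reformatted in the hunk above is typically used: the model is built inside the context so its parameters are sharded at creation time. The import paths and the TensorShardStrategy choice are assumptions and vary across ColossalAI releases.

    import torch
    from colossalai.zero.legacy import ZeroInitContext                   # assumed import path
    from colossalai.zero.legacy.shard_utils import TensorShardStrategy   # assumed import path

    with ZeroInitContext(
        target_device=torch.device("cuda", torch.cuda.current_device()),
        shard_strategy=TensorShardStrategy(),
        shard_param=True,
    ):
        model = torch.nn.Linear(1024, 1024)  # any nn.Module; a Linear keeps the sketch small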
@@ -103,7 +104,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         assert isinstance(tensor, nn.Parameter), "Sharded tensor initialization is only allowed for parameters"
 
         # get correct shape of input tensor
-        if not hasattr(tensor, 'colo_attr') or not tensor.colo_attr.param_is_sharded:
+        if not hasattr(tensor, "colo_attr") or not tensor.colo_attr.param_is_sharded:
             tensor_shape = tensor.shape
         else:
             tensor_shape = tensor.colo_attr.sharded_data_tensor.origin_shape
@@ -137,13 +138,16 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         self.module_load_from_state_dict = nn.Module._load_from_state_dict
         shard_strategy = self.shard_strategy if self.config.shard_param else None
-        nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict,
-                                                                   shard_strategy=shard_strategy)
+        nn.Module._load_from_state_dict = functools.partialmethod(
+            ShardedModelV2._colo_load_from_state_dict, shard_strategy=shard_strategy
+        )
         self.module_state_dict = nn.Module.state_dict
-        nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict,
-                                                       shard_strategy=shard_strategy,
-                                                       state_dict_func=self.module_state_dict,
-                                                       process_group=self.dp_process_group)
+        nn.Module.state_dict = functools.partialmethod(
+            ShardedModelV2._colo_state_dict,
+            shard_strategy=shard_strategy,
+            state_dict_func=self.module_state_dict,
+            process_group=self.dp_process_group,
+        )
 
         # reserve rng states
         self.cpu_rng_state = torch.get_rng_state()
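The two patches in the hunk above rely on functools.partialmethod to pre-bind keyword arguments onto a function that is then installed as a class-level method. A standalone sketch of that mechanism, using a toy class rather than ColossalAI code:

    import functools

    class Toy:
        def _impl(self, x, scale=1):
            return x * scale

    # Install a partialmethod on the class after creation, pre-filling `scale`,
    # just as the context pre-fills shard_strategy into nn.Module.state_dict.
    Toy.scaled = functools.partialmethod(Toy._impl, scale=10)

    assert Toy().scaled(3) == 30  # `self` is still bound automatically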
@@ -152,16 +156,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         # set new seed for initialization, since we initialize sharded tensor separately
         # we don't want all processes have the same seed
         # otherwise all sharded tensors are same after init
-        offset = self.seed + 1    # we want to have more 1 in binary format seed
+        offset = self.seed + 1  # we want to have more 1 in binary format seed
         torch.manual_seed(self.seed + offset * dist.get_rank())
 
     def _post_context_exec(self):
-        """The callback function when exiting context.
-        """
+        """The callback function when exiting context."""
         # broadcast replicated no-shard parameters
         src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
         for param in self.param_list:
-            assert hasattr(param, 'colo_attr')
+            assert hasattr(param, "colo_attr")
             if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
                 dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
                 param.colo_attr.set_data_none()
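The seed arithmetic in the hunk above gives every rank its own RNG stream, so independently initialized shards do not end up identical. A small single-process illustration with stand-in rank numbers (no torch.distributed involved):

    import torch

    seed = 2**10 - 1        # the default seed from the constructor shown earlier
    offset = seed + 1
    for rank in range(4):   # stand-in for dist.get_rank() across four processes
        torch.manual_seed(seed + offset * rank)
        print(rank, torch.randn(3))  # different values for each simulated rank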
@@ -193,7 +196,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
-            if hasattr(param, 'colo_attr'):
+            if hasattr(param, "colo_attr"):
                 continue
 
             self.param_numel[param] = param.numel()
@@ -216,7 +219,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if self.shard_param:
                 self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
 
-            param.data = param.colo_attr.data_payload    # set param.data to payload
+            param.data = param.colo_attr.data_payload  # set param.data to payload
 
             # mark whether the param is replicated
             param.colo_attr.is_replicated = self.is_replicated
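Conceptually, the `param.data = param.colo_attr.data_payload` line above replaces the full weight with this rank's shard, so only a fraction of the tensor stays materialized. A rough standalone illustration of that idea, not the ColossalAI API:

    import torch
    import torch.nn as nn

    world_size, rank = 4, 1                       # pretend values for illustration
    param = nn.Parameter(torch.randn(8, 8))

    flat = param.detach().flatten()
    shard = flat.chunk(world_size)[rank].clone()  # keep only 1/world_size of the data
    param.data = shard                            # param now holds just the local payload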
@@ -251,15 +254,13 @@ class ZeroContextMgr(metaclass=SingletonMeta):
 
 
 def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
-    return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
-                                                  is_replicated=is_replicated,
-                                                  shard_param=False)
+    return ZeroContextMgr().hijack_context_config(
+        target_device=torch.device("cuda", torch.cuda.current_device()), is_replicated=is_replicated, shard_param=False
+    )
 
 
 def no_shard_zero_decrator(is_replicated: bool = True):
-
     def _wrapper(init_func):
-
         def _no_shard(*args, **kwargs):
             with no_shard_zero_context(is_replicated):
                 ret = init_func(*args, **kwargs)
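A hedged example of how the `no_shard_zero_decrator` shown above would typically be applied: wrapping a module's `__init__` keeps the parameters created inside it replicated rather than sharded, even when the module is built under a ZeroInitContext. The MyEmbedding class is hypothetical, and the decorator is assumed to be importable from the module in this diff.

    import torch
    import torch.nn as nn

    class MyEmbedding(nn.Module):
        @no_shard_zero_decrator(is_replicated=True)  # parameters created in __init__ stay unsharded
        def __init__(self, vocab_size: int, hidden: int):
            super().__init__()
            self.weight = nn.Parameter(torch.empty(vocab_size, hidden))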