[shardformer] supported fused qkv checkpoint (#4073)
@@ -12,11 +12,14 @@ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module

from colossalai.tensor.d_tensor import (
    distribute_tensor,
    distribute_tensor_with_customization,
    get_device_mesh,
    get_sharding_spec,
    is_customized_distributed_tensor,
    is_distributed_tensor,
    sharded_tensor_to_param,
    to_global,
    to_global_for_customized_distributed_tensor,
)

__all__ = ['ParallelModule']
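The customized-distributed-tensor imports (distribute_tensor_with_customization, is_customized_distributed_tensor, to_global_for_customized_distributed_tensor) cover parameters whose sharding is described by a user-supplied shard_fn/gather_fn pair rather than a sharding spec, which is what a fused qkv weight typically needs: its q, k and v blocks are sharded independently. Below is a minimal, self-contained sketch of such a pair in plain PyTorch; qkv_shard_fn and qkv_gather_fn are hypothetical names for illustration, not ColossalAI's API.

# Hypothetical shard/gather pair for a fused qkv weight of shape [3 * hidden, in_dim].
# Plain PyTorch for illustration only; not ColossalAI's actual shard_fn/gather_fn.
import torch


def qkv_shard_fn(full_weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Split each of the fused q, k, v blocks along dim 0 and keep this rank's slice.
    q, k, v = full_weight.chunk(3, dim=0)
    return torch.cat([t.chunk(world_size, dim=0)[rank] for t in (q, k, v)], dim=0)


def qkv_gather_fn(shards: list) -> torch.Tensor:
    # Re-interleave the per-rank slices back into the original fused q|k|v layout.
    qs, ks, vs = zip(*(s.chunk(3, dim=0) for s in shards))
    return torch.cat(list(qs) + list(ks) + list(vs), dim=0)


# Round trip: sharding then gathering reproduces the original fused weight.
full = torch.randn(6, 4)
shards = [qkv_shard_fn(full, r, 2) for r in range(2)]
assert torch.equal(qkv_gather_fn(shards), full)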
@@ -54,9 +57,10 @@ class ParallelModule(nn.Module, ABC):
        for name, param in self._parameters.items():
            if param is not None:
                param_ = param if keep_vars else param.detach()

                if is_distributed_tensor(param_):
                    destination[prefix + name] = to_global(param_)
                elif is_customized_distributed_tensor(param_):
                    destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
                else:
                    destination[prefix + name] = param_
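This is the save path (the destination, prefix and keep_vars names match a _save_to_state_dict override): every parameter is written in its full, global form, so the checkpoint on disk is independent of how the module was parallelized. A rough, self-contained illustration in plain PyTorch follows; a simple row-sharded weight stands in for a real distributed tensor, and a local list stands in for the collective that to_global(...) would run.

# Assumed setup: a weight row-sharded across two ranks; `all_shards` plays the
# role of what the collective inside to_global(...) would fetch from peer ranks.
import torch

world_size, rank = 2, 0
full_weight = torch.randn(8, 4)
all_shards = list(full_weight.chunk(world_size, dim=0))
local_param = all_shards[rank]                              # what this rank actually holds

destination, prefix, name = {}, "fc.", "weight"
destination[prefix + name] = torch.cat(all_shards, dim=0)   # the global tensor is saved, not the shard

assert torch.equal(destination["fc.weight"], full_weight)
assert local_param.shape[0] == full_weight.shape[0] // world_size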
@@ -124,6 +128,8 @@ class ParallelModule(nn.Module, ABC):
                    sharding_spec = get_sharding_spec(param)
                    sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
                    input_param = sharded_tensor_to_param(sharded_tensor)
                elif is_customized_distributed_tensor(param):
                    input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)

                # This is used to avoid copying uninitialized parameters into
                # non-lazy modules, since they dont have the hook to do the checks
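The load path does the inverse: the checkpoint entry arrives in its global shape and is re-sharded to the local layout before being copied into the parameter, via the stored sharding spec for regular distributed tensors or via the parameter's own shard_fn for customized ones such as fused qkv weights. Below is a self-contained sketch of the idea in plain PyTorch; chunk(...) stands in for distribute_tensor / param.shard_fn, and the explicit copy_ stands in for the copy performed later in the method.

# Assumed setup mirroring the save-side sketch: rank 1 of 2 holds the second
# row shard and receives the full tensor from the checkpoint.
import torch

world_size, rank = 2, 1
checkpoint_entry = torch.randn(8, 4)        # global tensor read from the state dict
local_param = torch.empty(4, 4)             # this rank's row shard

input_param = checkpoint_entry.chunk(world_size, dim=0)[rank]   # re-shard to the local layout
local_param.copy_(input_param)                                  # then the usual in-place copy

assert torch.equal(local_param, checkpoint_entry[4:])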