[shardformer] supported fused qkv checkpoint (#4073)

Author: Frank Lee
Date: 2023-06-23 16:07:09 +08:00
Parent: 0803a61412
Commit: 70c58cfd4f
10 changed files with 420 additions and 88 deletions


@@ -12,11 +12,14 @@ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
from colossalai.tensor.d_tensor import (
    distribute_tensor,
    distribute_tensor_with_customization,
    get_device_mesh,
    get_sharding_spec,
    is_customized_distributed_tensor,
    is_distributed_tensor,
    sharded_tensor_to_param,
    to_global,
    to_global_for_customized_distributed_tensor,
)
__all__ = ['ParallelModule']
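The imports above pull in the helpers for tensors whose layout cannot be expressed by a regular sharding spec, which is exactly the fused QKV case: the Q, K and V projections share one weight matrix, so each segment must be split separately instead of row-chunking the whole matrix. A minimal plain-PyTorch sketch of that difference follows; the shapes, the two-rank setup, and the split strategy are assumptions for illustration, not the library's code.

# Illustrative only: why a fused QKV weight cannot use a plain row split.
import torch

hidden = 4
world_size = 2
fused_qkv = torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(3 * hidden, hidden)

# Naive row split: rank 0 would receive all of Q plus half of K, which is
# wrong for tensor-parallel attention.
naive_shards = torch.chunk(fused_qkv, world_size, dim=0)

# Fused-QKV split: slice each of the Q, K, V segments separately, then
# re-concatenate the per-segment slices on every rank.
q, k, v = torch.chunk(fused_qkv, 3, dim=0)
per_rank = [
    torch.cat([t.chunk(world_size, dim=0)[r] for t in (q, k, v)], dim=0)
    for r in range(world_size)
]

print(naive_shards[0].shape, per_rank[0].shape)  # both (6, 4), but different rows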
@@ -54,9 +57,10 @@ class ParallelModule(nn.Module, ABC):
for name, param in self._parameters.items():
    if param is not None:
        param_ = param if keep_vars else param.detach()
        if is_distributed_tensor(param_):
            destination[prefix + name] = to_global(param_)
        elif is_customized_distributed_tensor(param_):
            destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
        else:
            destination[prefix + name] = param_
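In the save path above, a customized distributed tensor is converted back to its full, rank-independent form via to_global_for_customized_distributed_tensor before being written to the state dict, conceptually by running the gather function attached to the parameter. Below is a minimal sketch of what such a gather function could look like for the fused QKV layout from the previous snippet; the function name, shapes, and two-rank setup are assumptions, not ColossalAI's actual implementation.

# A stand-in for the gather function carried by a customized distributed
# tensor; the real ColossalAI implementation may differ.
import torch

def gather_fn(shards):
    # Each shard holds [q_slice; k_slice; v_slice] stacked along dim 0, so
    # split every shard back into its three segments and concatenate the
    # segments across ranks before fusing them again.
    segments = [torch.chunk(shard, 3, dim=0) for shard in shards]
    q = torch.cat([s[0] for s in segments], dim=0)
    k = torch.cat([s[1] for s in segments], dim=0)
    v = torch.cat([s[2] for s in segments], dim=0)
    return torch.cat([q, k, v], dim=0)

hidden, world_size = 4, 2
full = torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(3 * hidden, hidden)
q, k, v = torch.chunk(full, 3, dim=0)
shards = [
    torch.cat([t.chunk(world_size, dim=0)[r] for t in (q, k, v)], dim=0)
    for r in range(world_size)
]
assert torch.equal(gather_fn(shards), full)  # round-trips back to the global weight

The round trip is the point of the change: the saved tensor is the plain fused weight, so the checkpoint can be reloaded on any number of ranks.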
@@ -124,6 +128,8 @@ class ParallelModule(nn.Module, ABC):
    sharding_spec = get_sharding_spec(param)
    sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
    input_param = sharded_tensor_to_param(sharded_tensor)
elif is_customized_distributed_tensor(param):
    input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)
# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they don't have the hook to do the checks
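On the load path, the full tensor read from the checkpoint is re-sharded with distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn), so every rank keeps only its slice of the fused weight. A minimal sketch of a shard function such a parameter might carry is shown below; the rank and world size are hard-coded assumptions here, whereas a real shard function would query the process group.

# A stand-in for the shard function attached to a customized distributed
# tensor, mirroring the gather sketch above.
import torch

RANK, WORLD_SIZE = 0, 2

def shard_fn(full_weight):
    # Keep only this rank's slice of each Q/K/V segment of the fused weight.
    q, k, v = torch.chunk(full_weight, 3, dim=0)
    return torch.cat(
        [t.chunk(WORLD_SIZE, dim=0)[RANK] for t in (q, k, v)], dim=0
    )

hidden = 4
checkpoint_weight = torch.randn(3 * hidden, hidden)  # global tensor from the state dict
local_weight = shard_fn(checkpoint_weight)
print(local_weight.shape)  # (6, 4): half of each Q/K/V segment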