[shardformer] supported fused qkv checkpoint (#4073)

This commit is contained in:
Frank Lee
2023-06-23 16:07:09 +08:00
parent 0803a61412
commit 70c58cfd4f
10 changed files with 420 additions and 88 deletions

View File

@@ -1,10 +1,13 @@
from .api import (
compute_global_numel,
customized_distributed_tensor_to_param,
distribute_tensor,
distribute_tensor_with_customization,
get_device_mesh,
get_global_shape,
get_layout,
get_sharding_spec,
is_customized_distributed_tensor,
is_distributed_tensor,
is_sharded,
redistribute,
@@ -12,6 +15,7 @@ from .api import (
shard_rowwise,
sharded_tensor_to_param,
to_global,
to_global_for_customized_distributed_tensor,
)
from .layout import Layout
from .sharding_spec import ShardingSpec
@@ -19,6 +23,6 @@ from .sharding_spec import ShardingSpec
__all__ = [
'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
'redistribute', 'get_layout'
'Layout', 'ShardingSpec'
'redistribute', 'get_layout', 'is_customized_distributed_tensor', 'distribute_tensor_with_customization',
'to_global_for_customized_distributed_tensor', 'customized_distributed_tensor_to_param', 'Layout', 'ShardingSpec'
]