mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 22:10:37 +00:00
[shardformer] supported fused qkv checkpoint (#4073)
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
from .api import (
|
||||
compute_global_numel,
|
||||
customized_distributed_tensor_to_param,
|
||||
distribute_tensor,
|
||||
distribute_tensor_with_customization,
|
||||
get_device_mesh,
|
||||
get_global_shape,
|
||||
get_layout,
|
||||
get_sharding_spec,
|
||||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
is_sharded,
|
||||
redistribute,
|
||||
@@ -12,6 +15,7 @@ from .api import (
|
||||
shard_rowwise,
|
||||
sharded_tensor_to_param,
|
||||
to_global,
|
||||
to_global_for_customized_distributed_tensor,
|
||||
)
|
||||
from .layout import Layout
|
||||
from .sharding_spec import ShardingSpec
|
||||
@@ -19,6 +23,6 @@ from .sharding_spec import ShardingSpec
|
||||
__all__ = [
|
||||
'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
|
||||
'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
|
||||
'redistribute', 'get_layout'
|
||||
'Layout', 'ShardingSpec'
|
||||
'redistribute', 'get_layout', 'is_customized_distributed_tensor', 'distribute_tensor_with_customization',
|
||||
'to_global_for_customized_distributed_tensor', 'customized_distributed_tensor_to_param', 'Layout', 'ShardingSpec'
|
||||
]
|
||||
|
Reference in New Issue
Block a user