[shardformer] support module saving and loading (#4062)

* [shardformer] support module saving and loading

* polish code
This commit is contained in:
Frank Lee
2023-06-22 11:42:11 +08:00
parent 7740c55c55
commit 8eb09a4c69
19 changed files with 493 additions and 102 deletions

View File

@@ -0,0 +1,24 @@
from .api import (
compute_global_numel,
distribute_tensor,
get_device_mesh,
get_global_shape,
get_layout,
get_sharding_spec,
is_distributed_tensor,
is_sharded,
redistribute,
shard_colwise,
shard_rowwise,
sharded_tensor_to_param,
to_global,
)
from .layout import Layout
from .sharding_spec import ShardingSpec
__all__ = [
'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
'redistribute', 'get_layout'
'Layout', 'ShardingSpec'
]