[shardformer] support module saving and loading (#4062)

* [shardformer] support module saving and loading

* polish code
This commit is contained in:
Frank Lee
2023-06-22 11:42:11 +08:00
parent 7740c55c55
commit 8eb09a4c69
19 changed files with 493 additions and 102 deletions

View File

@@ -28,18 +28,6 @@ class LayoutConverterOptions:
pass
def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
layout_converter = LayoutConverter()
global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
global_layout = Layout(device_mesh=layout.device_mesh,
device_type=layout.device_type,
sharding_spec=global_sharding_spec,
entire_shape=layout.entire_shape)
with torch.no_grad():
global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
return global_tensor
def set_layout_converting_options(options: LayoutConverterOptions):
"""
Configure the shape consistency manager via function call.
@@ -553,4 +541,5 @@ class LayoutConverter(metaclass=SingletonMeta):
_, comm_action_sequence = self.layout_converting(source_layout, target_layout)
for comm_spec in comm_action_sequence:
tensor = comm_spec.covert_spec_to_action(tensor)
tensor.dist_layout = target_layout
return tensor