Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
[shardformer] support module saving and loading (#4062)
* [shardformer] support module saving and loading
* polish code
@@ -28,18 +28,6 @@ class LayoutConverterOptions:
     pass
 
 
-def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
-    layout_converter = LayoutConverter()
-    global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
-    global_layout = Layout(device_mesh=layout.device_mesh,
-                           device_type=layout.device_type,
-                           sharding_spec=global_sharding_spec,
-                           entire_shape=layout.entire_shape)
-    with torch.no_grad():
-        global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
-    return global_tensor
-
-
 def set_layout_converting_options(options: LayoutConverterOptions):
     """
     Configure the shape consistency manager via function call.
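For reference, the deleted `to_global` helper shows the pattern this hunk removes: describe the fully replicated target with an empty `dim_partition_dict`, then let `LayoutConverter.apply` run the communication needed to materialize the global tensor. A minimal self-contained sketch of that pattern (the import paths are assumptions about where these classes lived at the time of this commit):

```python
import torch

# NOTE: import paths are assumptions and may differ across ColossalAI versions.
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec


def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
    # An empty dim_partition_dict means no dimension is sharded, so this
    # spec describes the fully replicated (global) tensor.
    global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
    global_layout = Layout(device_mesh=layout.device_mesh,
                           device_type=layout.device_type,
                           sharding_spec=global_sharding_spec,
                           entire_shape=layout.entire_shape)
    layout_converter = LayoutConverter()
    with torch.no_grad():
        # apply() emits and runs the communication actions (here a gather)
        # needed to move between the two layouts.
        return layout_converter.apply(distributed_tensor, layout, global_layout)
```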
@@ -553,4 +541,5 @@ class LayoutConverter(metaclass=SingletonMeta):
         _, comm_action_sequence = self.layout_converting(source_layout, target_layout)
         for comm_spec in comm_action_sequence:
             tensor = comm_spec.covert_spec_to_action(tensor)
+        tensor.dist_layout = target_layout
         return tensor
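The one added line tags the converted tensor with the layout it now satisfies, so the module save/load path this PR introduces can recover a tensor's distribution later without threading the layout alongside it. A hypothetical caller-side sketch of the resulting invariant (`convert_layout` and its arguments are illustrative, not part of the diff):

```python
import torch

# NOTE: import paths are assumptions and may differ across ColossalAI versions.
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter


def convert_layout(tensor: torch.Tensor, source_layout: Layout,
                   target_layout: Layout) -> torch.Tensor:
    # LayoutConverter uses SingletonMeta, so constructing it repeatedly
    # returns the same instance.
    converter = LayoutConverter()
    out = converter.apply(tensor, source_layout, target_layout)
    # New invariant from this commit: the returned tensor carries the
    # layout it now satisfies.
    assert out.dist_layout is target_layout
    return out
```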