[DTensor] refactor dtensor with new components (#3089)

* [DTensor] refactor dtensor with new components

* polish
YuliangLiu0306
2023-03-14 16:25:47 +08:00
committed by GitHub
parent ed8f60b93b
commit 2eca4cd376
3 changed files with 20 additions and 41 deletions


@@ -22,21 +22,21 @@ __all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_o
 @dataclass
 class LayoutConverterOptions:
     """
-    LayoutConverterOptions is a dataclass which specifies the preferences for shape consistency.
+    LayoutConverterOptions is a dataclass which specifies the preferences for layout converting.
     """
     # TODO: layout converter option is not implemented yet
     pass


 def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
-    shape_consistency_manager = LayoutConverter()
+    layout_converter = LayoutConverter()
     global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
     global_layout = Layout(device_mesh=layout.device_mesh,
                            device_type=layout.device_type,
                            sharding_spec=global_sharding_spec,
                            entire_shape=layout.entire_shape)
     with torch.no_grad():
-        global_tensor = shape_consistency_manager.apply(distributed_tensor, layout, global_layout)
+        global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
     return global_tensor
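
For orientation, a minimal usage sketch of `to_global` after the rename follows. It is an illustration only, not part of this commit: the `mesh` device mesh, the `{0: [0]}` partition dict, the tensor shapes, and the CUDA device type are assumptions, and imports are left out because the diff does not show the module path; the `ShardingSpec` and `Layout` arguments simply mirror the ones visible in the hunk above.

```python
# Hypothetical usage sketch (assumptions noted inline), with Layout,
# ShardingSpec, and to_global imported from the dtensor module.
import torch

local_shard = torch.randn(4, 8)                            # this rank's shard of an (8, 8) tensor (assumed shapes)
sharded_spec = ShardingSpec(local_shard.dim(), {0: [0]})   # tensor dim 0 split along mesh axis 0 (assumed partition dict)
sharded_layout = Layout(device_mesh=mesh,                  # `mesh`: a DeviceMesh built elsewhere (assumption)
                        device_type=torch.device('cuda'),
                        sharding_spec=sharded_spec,
                        entire_shape=torch.Size((8, 8)))   # global shape before sharding

# to_global builds an unsharded target Layout internally and asks the
# LayoutConverter to transform the local shard into the full global tensor.
full_tensor = to_global(local_shard, sharded_layout)
assert full_tensor.shape == torch.Size((8, 8))
```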