[DTensor] refactor dtensor with new components (#3089)
* [DTensor] refactor dtensor with new components

* polish
@@ -22,21 +22,21 @@ __all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_o
 @dataclass
 class LayoutConverterOptions:
     """
-    LayoutConverterOptions is a dataclass which specifies the preferences for shape consistency.
+    LayoutConverterOptions is a dataclass which specifies the preferences for layout converting.
     """
     # TODO: layout converter option is not implemented yet
     pass


 def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
-    shape_consistency_manager = LayoutConverter()
+    layout_converter = LayoutConverter()
     global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
     global_layout = Layout(device_mesh=layout.device_mesh,
                            device_type=layout.device_type,
                            sharding_spec=global_sharding_spec,
                            entire_shape=layout.entire_shape)
     with torch.no_grad():
-        global_tensor = shape_consistency_manager.apply(distributed_tensor, layout, global_layout)
+        global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
     return global_tensor
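For orientation, a minimal usage sketch of to_global follows. It is not part of this commit: the import paths, the device_mesh construction (elided below), and the reading of {0: [0]} as "tensor dim 0 sharded along mesh axis 0" are assumptions for illustration; only the call signatures themselves appear in the diff above.

import torch
# Assumed import paths -- only the calls below are taken from the diff.
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.tensor.d_tensor.layout_converter import to_global

# `device_mesh` is assumed to be an existing 1D DeviceMesh over 4 ranks;
# its construction (and process-group setup) is omitted here.
sharding_spec = ShardingSpec(2, {0: [0]})        # 2-D tensor, dim 0 sharded on mesh axis 0 (assumed notation)
layout = Layout(device_mesh=device_mesh,
                device_type=torch.device('cuda'),
                sharding_spec=sharding_spec,
                entire_shape=torch.Size((8, 8)))

local_shard = torch.randn(2, 8, device='cuda')   # this rank's (8 / 4, 8) slice
full_tensor = to_global(local_shard, layout)     # (8, 8), identical on every rank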