[autoparallel] add numerical test for node strategies (#1760)

* [autoparallel] add numerical test for node strategies

* polish code

* polish code
This commit is contained in:
YuliangLiu0306
2022-10-27 10:42:54 +08:00
committed by GitHub
parent 25952b67d7
commit b4cc59b61e
10 changed files with 283 additions and 60 deletions

View File

@@ -28,6 +28,15 @@ class ShapeConsistencyOptions:
pass
def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
shape_consistency_manager = ShapeConsistencyManager()
global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
with torch.no_grad():
global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec,
global_sharding_spec)
return global_tensor
def set_shape_consistency_options(options: ShapeConsistencyOptions):
"""
Configure the shape consistency manager via function call.