mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[autoparallel] add numerical test for node strategies (#1760)
* [autoparallel] add numerical test for node strategies * polish code * polish code
This commit is contained in:
@@ -28,6 +28,15 @@ class ShapeConsistencyOptions:
|
||||
pass
|
||||
|
||||
|
||||
def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
|
||||
with torch.no_grad():
|
||||
global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec,
|
||||
global_sharding_spec)
|
||||
return global_tensor
|
||||
|
||||
|
||||
def set_shape_consistency_options(options: ShapeConsistencyOptions):
|
||||
"""
|
||||
Configure the shape consistency manager via function call.
|
||||
|
Reference in New Issue
Block a user