[autoparallel] add numerical test for node strategies (#1760)
* [autoparallel] add numerical test for node strategies
* polish code
* polish code
@@ -28,6 +28,15 @@ class ShapeConsistencyOptions:
    pass


def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
    shape_consistency_manager = ShapeConsistencyManager()
    global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
    with torch.no_grad():
        global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec,
                                                                                 global_sharding_spec)
    return global_tensor


def set_shape_consistency_options(options: ShapeConsistencyOptions):
    """
    Configure the shape consistency manager via function call.
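For context, a minimal sketch of how the new to_global helper could be used in a numerical test. The import path, mesh layout, tensor shapes, and DeviceMesh constructor arguments below are illustrative assumptions rather than part of the diff, and a real run needs the mesh's process groups initialized:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import to_global
from colossalai.tensor.sharding_spec import ShardingSpec

# assumed 2x2 logical mesh over 4 devices; init_process_group=True is an assumption
# about the constructor and requires torch.distributed to be initialized
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=True)

# full tensor shape before sharding; dim 0 sharded over mesh axis 0 -> sequence [S0, R]
entire_shape = torch.Size((8, 8))
sharding_spec = ShardingSpec(device_mesh, entire_shape, {0: [0]})

# each rank holds a (4, 8) shard; to_global gathers it back to the full (8, 8) tensor
# so the sharded result can be compared against a single-device reference output
local_shard = torch.randn(4, 8)
global_tensor = to_global(local_shard, sharding_spec)
assert global_tensor.shape == entire_shape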
@@ -6,7 +6,6 @@ from functools import reduce

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)

__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
@@ -23,7 +22,7 @@ class _DimSpec:
    This class is used internally in ShardingSpec.

    Argument:
        shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
            Otherwise, the element in shard_list means the data will be sharded in that dimension.
    '''
@@ -62,7 +61,7 @@ class _DimSpec:

    def build_difference_2d_dict(self):
        '''
        Build a difference mapping for the 2D device mesh case. It will be used to
        compute the difference between DimSpec pairs.
        '''
@@ -159,9 +158,9 @@ class ShardingNotDivisibleError(ShardingSpecException):

class ShardingSpec:
    '''
    Sharding spec for a tensor. It contains the logical device mesh this tensor belongs
    to, the entire shape of the tensor before sharding, and a sharding sequence that
    looks like [R, R, S0, S1].

    Argument:
        device_mesh(DeviceMesh): A logical view of a physical mesh.
        entire_shape(torch.Size): The entire shape of the tensor before sharding.
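A small sketch of how the [R, R, S0, S1] sequence mentioned in this docstring would be expressed through the dim partition dict; the mesh layout and shapes are assumptions for illustration:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

# assumed 2x2 logical mesh over 4 devices
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

# 4D tensor: dims 0 and 1 replicated, dim 2 sharded over mesh axis 0, dim 3 over mesh axis 1
entire_shape = torch.Size((4, 4, 8, 8))
dim_partition_dict = {2: [0], 3: [1]}
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
# the resulting sharding sequence is [R, R, S0, S1]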
@@ -260,10 +259,10 @@ class ShardingSpec:
            # device_mesh_shape: (4, 4)
            sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
            print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))

        Output:
            25

        Argument:
            other(ShardingSpec): The ShardingSpec to compare with.
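For completeness, a hedged sketch of the setup such a comparison assumes. The dim partition dicts below are illustrative, and the resulting cost depends on them and on the difference mapping, so no particular value is asserted here:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

# assumed 4x4 logical mesh over 16 devices, matching device_mesh_shape: (4, 4) above
device_mesh = DeviceMesh(torch.arange(0, 16), (4, 4))

# two specs for the same tensor, differing only in how dim 0 is sharded
entire_shape = torch.Size((64, 64))
sharding_spec = ShardingSpec(device_mesh, entire_shape, {0: [0]})                # [S0, R]
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, {0: [0, 1]})  # [S01, R]

# returns a numeric cost describing how far apart the two sharding sequences are
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))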