mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[autoparallel] add numerical test for node strategies (#1760)
* [autoparallel] add numerical test for node strategies * polish code * polish code
This commit is contained in:
@@ -6,7 +6,6 @@ from functools import reduce
|
||||
import torch
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)
|
||||
|
||||
__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
|
||||
|
||||
@@ -23,7 +22,7 @@ class _DimSpec:
|
||||
This class is used internally in ShardingSpec.
|
||||
|
||||
Argument:
|
||||
shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
|
||||
shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
|
||||
Otherwise, the element in shard_list means the data will be sharded in that dimension.
|
||||
'''
|
||||
|
||||
@@ -62,7 +61,7 @@ class _DimSpec:
|
||||
|
||||
def build_difference_2d_dict(self):
|
||||
'''
|
||||
Build a difference mapping for 2D device mesh case. It will be used to
|
||||
Build a difference mapping for 2D device mesh case. It will be used to
|
||||
compute the difference between DimSpec pairs.
|
||||
'''
|
||||
|
||||
@@ -159,9 +158,9 @@ class ShardingNotDivisibleError(ShardingSpecException):
|
||||
class ShardingSpec:
|
||||
'''
|
||||
Sharding spec for a tensor, it contains info of the logical device mesh this tensor belongs
|
||||
to, the entire shape of the tensor before sharded, and the sharding sequence looks like
|
||||
to, the entire shape of the tensor before sharded, and the sharding sequence looks like
|
||||
[R, R, S0, S1].
|
||||
|
||||
|
||||
Argument:
|
||||
device_mesh(DeviceMesh): A logical view of a physical mesh.
|
||||
entire_shape(torch.Size): The entire shape of tensor before sharded.
|
||||
@@ -260,10 +259,10 @@ class ShardingSpec:
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
|
||||
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
|
||||
|
||||
|
||||
Output:
|
||||
25
|
||||
|
||||
|
||||
Argument:
|
||||
other(ShardingSpec): The ShardingSpec to compare with.
|
||||
|
||||
|
Reference in New Issue
Block a user