[autoparallel] add numerical test for node strategies (#1760)

* [autoparallel] add numerical test for node strategies

* polish code

* polish code
Author: YuliangLiu0306
Date: 2022-10-27 10:42:54 +08:00
Committed by: GitHub
Parent: 25952b67d7
Commit: b4cc59b61e
10 changed files with 283 additions and 60 deletions


@@ -1,5 +1,6 @@
import operator
from functools import reduce
import torch
import torch.distributed as dist
@@ -11,7 +12,7 @@ class DeviceMesh:
can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its
own latency and bandwidth. We use the alpha-beta model to model the
communication cost.
Arguments:
physical_mesh_id (torch.Tensor): physical view of the devices, given by their global ranks.
mesh_shape (torch.Size): shape of logical view.
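
For orientation, a minimal sketch of how these two arguments fit together, e.g. viewing 16 devices as the 4x4 logical mesh the docstring mentions. The import path and the init_process_group keyword (referenced in the note on create_process_groups_for_logical_mesh below) are assumptions, not taken from this diff:

import torch

from colossalai.device.device_mesh import DeviceMesh  # assumed import path

# 16 physical devices, identified by their global ranks 0..15.
physical_mesh_id = torch.arange(0, 16)

# View the same devices as a 4x4 logical mesh; each mesh dimension gets its own
# latency/bandwidth (alpha-beta) parameters for the communication-cost model.
mesh_shape = torch.Size((4, 4))

# Assumption: init_process_group defaults to False and can be enabled here.
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)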
@@ -64,6 +65,18 @@ class DeviceMesh:
def logical_mesh_id(self):
return self._logical_mesh_id
def __deepcopy__(self, memo):
    cls = self.__class__
    result = cls.__new__(cls)
    memo[id(self)] = result
    for k, v in self.__dict__.items():
        # Process groups wrap live communicator handles and cannot be deep-copied,
        # so they are shared by reference; every other attribute is copied recursively.
        if k != 'process_groups_dict':
            setattr(result, k, __import__("copy").deepcopy(v, memo))
        else:
            setattr(result, k, v)
    return result
def flatten(self):
"""
Flatten the logical mesh into an effective 1d logical mesh,
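
The __deepcopy__ hook added above follows the standard copy protocol: every attribute is copied recursively except the process-group dictionary, which is shared by reference because live communicator objects cannot be duplicated. A self-contained sketch of the same pattern with a toy class (illustrative only, not the real DeviceMesh):

import copy


class MeshLike:
    """Toy stand-in for DeviceMesh: one un-copyable attribute, shared by reference."""

    def __init__(self):
        self.mesh_shape = (4, 4)
        self.process_groups_dict = {0: object()}  # stands in for live process groups

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != 'process_groups_dict':
                setattr(result, k, copy.deepcopy(v, memo))
            else:
                setattr(result, k, v)  # share, don't copy
        return result


original = MeshLike()
clone = copy.deepcopy(original)
assert clone.mesh_shape == original.mesh_shape
assert clone.process_groups_dict is original.process_groups_dict  # same object, shared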
@@ -90,7 +103,7 @@ class DeviceMesh:
def create_process_groups_for_logical_mesh(self):
'''
This method initializes the logical process groups that will be used for communication
on the logical device mesh.
Note: if init_process_group is set to False, you have to call this method manually. Otherwise,
communication-related functions, such as ShapeConsistencyManager.apply, will raise errors.
'''
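
Putting that note into code, a hedged usage sketch: the method and attribute names come from this diff, but the import path and the exact call pattern (including whether the returned value must be assigned back) are assumptions, and the call requires torch.distributed to be initialized:

import torch

from colossalai.device.device_mesh import DeviceMesh  # assumed import path

physical_mesh_id = torch.arange(0, 4)
# Constructed without eager group creation (init_process_group left at its assumed default of False).
device_mesh = DeviceMesh(physical_mesh_id, torch.Size((2, 2)))

# Per the note above, create the logical process groups manually before calling any
# communication-dependent routine such as ShapeConsistencyManager.apply.
# Assumption: the method returns the process-group dict rather than setting it internally.
device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()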