[tensor] support runtime ShardingSpec apply (#1453)

* [tensor] support runtime ShardingSpec apply

* polish code

* polish code
Author: YuliangLiu0306
Date: 2022-08-19 13:39:51 +08:00
Committed by: GitHub
Parent: 177d3f5718
Commit: b73fb7a077
5 changed files with 485 additions and 11 deletions


@@ -1,6 +1,7 @@
from functools import reduce
import operator
import torch
import torch.distributed as dist
class DeviceMesh:
@@ -18,9 +19,13 @@ class DeviceMesh:
            communication cost (default: None)
        mesh_beta (List[float], optional): coefficients used for computing
            communication cost (default: None)
        init_process_group (bool, optional): initialize the logical process groups
            while constructing the DeviceMesh instance if init_process_group is set to True.
            Otherwise, users need to call create_process_groups_for_logical_mesh manually
            to initialize the logical process groups. (default: False)
    """
    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None, init_process_group=False):
        self.physical_mesh_id = physical_mesh_id
        self.mesh_shape = mesh_shape
        self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
@@ -34,6 +39,8 @@ class DeviceMesh:
            mesh_beta = [1] * len(self.mesh_shape)
        self.mesh_alpha = tuple(mesh_alpha)
        self.mesh_beta = tuple(mesh_beta)

        if init_process_group:
            self.process_groups_dict = self.create_process_groups_for_logical_mesh()

    @property
    def shape(self):
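A minimal usage sketch of the eager path introduced above (not part of this commit): with init_process_group=True, the logical process groups are built at construction time. The colossalai.device.device_mesh import path, the NCCL backend, and the 2x2 mesh are assumptions; torch.distributed must already be initialized on every participating process (e.g. one process per GPU launched via torchrun) before DeviceMesh is created.

    import torch
    import torch.distributed as dist
    from colossalai.device.device_mesh import DeviceMesh  # assumed module path

    dist.init_process_group(backend='nccl')  # must run on every participating process

    physical_mesh_id = torch.arange(4)  # 4 global ranks
    mesh_shape = (2, 2)                 # viewed as a 2x2 logical mesh
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    # device_mesh.process_groups_dict now holds the logical process groups per mesh axis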
@@ -57,6 +64,28 @@ class DeviceMesh:
            else:
                self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index])
    def create_process_groups_for_logical_mesh(self):
        '''
        This method initializes the logical process groups which will be used for communication
        within the logical device mesh.

        Note: if init_process_group is set to False, this method has to be called manually. Otherwise,
        communication-related functions, such as ShapeConsistencyManager.apply, will raise errors.
        '''
        process_groups_dict = {}
        check_duplicate_list = []
        global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist()
        for global_rank in global_rank_flatten_list:
            # collect, for every mesh axis, the rank list this global rank belongs to
            process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank)
            for axis, process_group in process_groups.items():
                if axis not in process_groups_dict:
                    process_groups_dict[axis] = []
                # create each process group only once, even though several ranks map to it
                if process_group not in check_duplicate_list:
                    check_duplicate_list.append(process_group)
                    process_group_handler = dist.new_group(process_group)
                    process_groups_dict[axis].append((process_group, process_group_handler))

        return process_groups_dict
    def global_rank_to_logical_rank(self, rank):
        return self.convert_map[rank]
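A complementary sketch of the lazy path described in the new docstring (not part of this commit): with the default init_process_group=False, the groups are created by an explicit call. The import path and the 2x2 mesh are assumptions, and torch.distributed still needs to be initialized first, since dist.new_group is a collective call. The returned dict maps each mesh axis to a list of (rank_list, process_group_handler) tuples, matching what create_process_groups_for_logical_mesh builds above.

    import torch
    from colossalai.device.device_mesh import DeviceMesh  # assumed module path

    device_mesh = DeviceMesh(torch.arange(4), (2, 2))  # init_process_group defaults to False
    process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()

    for axis, groups in process_groups_dict.items():
        for rank_list, handler in groups:
            print(axis, rank_list, handler)  # e.g. 0 [0, 2] <ProcessGroup>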