[tensor] support runtime ShardingSpec apply (#1453)
* [tensor] support runtime ShardingSpec apply
* polish code
* polish code
@@ -1,6 +1,7 @@
 from functools import reduce
 import operator
 import torch
+import torch.distributed as dist
 
 
 class DeviceMesh:
@@ -18,9 +19,13 @@ class DeviceMesh:
             communication cost (default: None)
         mesh_beta (List[float], optional): coefficients used for computing
             communication cost (default: None)
+        init_process_group (bool, optional): initialize the logical process groups
+            during the initialization of the DeviceMesh instance if init_process_group is set to True.
+            Otherwise, users need to call create_process_groups_for_logical_mesh manually to initialize them.
+            (default: False)
     """
 
-    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
+    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None, init_process_group=False):
         self.physical_mesh_id = physical_mesh_id
         self.mesh_shape = mesh_shape
         self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
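A minimal usage sketch of the new eager path follows. The import path, the 2 x 2 mesh shape, and the rank layout are assumptions for illustration, not part of the commit; init_process_group=True also requires torch.distributed to already be initialized on every rank.

import torch
from colossalai.device.device_mesh import DeviceMesh    # assumed import path

physical_mesh_id = torch.arange(0, 4)    # 4 global ranks: 0, 1, 2, 3
mesh_shape = (2, 2)                      # logical mesh [[0, 1], [2, 3]]

# init_process_group=True builds the logical process groups inside __init__
# by calling create_process_groups_for_logical_mesh().
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)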
@@ -34,6 +39,8 @@ class DeviceMesh:
             mesh_beta = [1] * len(self.mesh_shape)
         self.mesh_alpha = tuple(mesh_alpha)
         self.mesh_beta = tuple(mesh_beta)
+        if init_process_group:
+            self.process_groups_dict = self.create_process_groups_for_logical_mesh()
 
     @property
     def shape(self):
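With the default init_process_group=False, the branch above is skipped and the groups must be created manually. A sketch of that lazy path, mirroring the assignment __init__ performs when the flag is True (variable names continue the example above):

# Lazy path: the flag was left at its default, so no groups exist yet.
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
# create_process_groups_for_logical_mesh() only returns the dict, so store it
# on the mesh the same way __init__ does under the flag.
device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()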
@@ -57,6 +64,28 @@
             else:
                 self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index])
 
+    def create_process_groups_for_logical_mesh(self):
+        '''
+        This method initializes the logical process groups which will be used for communication
+        within the logical device mesh.
+        Note: if init_process_group is set to False, you have to call this method manually; otherwise,
+        communication-related functions, such as ShapeConsistencyManager.apply, will raise errors.
+        '''
+        process_groups_dict = {}
+        check_duplicate_list = []
+        global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist()
+        for global_rank in global_rank_flatten_list:
+            process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank)
+            for axis, process_group in process_groups.items():
+                if axis not in process_groups_dict:
+                    process_groups_dict[axis] = []
+                if process_group not in check_duplicate_list:
+                    check_duplicate_list.append(process_group)
+                    process_group_handler = dist.new_group(process_group)
+                    process_groups_dict[axis].append((process_group, process_group_handler))
+
+        return process_groups_dict
+
     def global_rank_to_logical_rank(self, rank):
         return self.convert_map[rank]
 
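For reference, a sketch of the structure this method is expected to return, assuming global_rank_to_process_groups_with_global_rank maps each mesh axis to the list of global ranks sharing that axis with the given rank (the 2 x 2 layout continues the example above):

# For physical ranks [0, 1, 2, 3] reshaped to the logical mesh [[0, 1], [2, 3]],
# the returned dict is expected to look roughly like:
#
#   {
#       0: [([0, 2], <ProcessGroup>), ([1, 3], <ProcessGroup>)],   # groups along mesh axis 0
#       1: [([0, 1], <ProcessGroup>), ([2, 3], <ProcessGroup>)],   # groups along mesh axis 1
#   }
#
# where each <ProcessGroup> is the handle returned by dist.new_group for the
# rank list next to it, and check_duplicate_list ensures each rank list is
# turned into a process group only once.
process_groups_dict = device_mesh.process_groups_dict

Note that dist.new_group must be entered by every process in the default group, which is why the loop walks over every global rank in the flattened physical mesh rather than only the local rank.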