[autoparallel] mix gather (#1977)
* Add mix-gather
* Add comments
* Add comments
* Polish comments
* Change the global rank assumption
* Add tests
* Add two-step tests
* Fix 10 and 01
* Skip test because the number of GPUs
@@ -52,6 +52,9 @@ class DeviceMesh:
         self.process_groups_dict = self.create_process_groups_for_logical_mesh()
         if self.need_flatten:
             self.flatten_device_mesh = self.flatten()
+            # Create a new member `flatten_device_meshes` to distinguish it from the original flatten() method, in case other functions still rely on self.flatten()
+            self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
+                                                           self.mesh_beta)
 
     @property
     def shape(self):
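For orientation, a minimal usage sketch of the branch above (a sketch, not part of the commit: the import path and the 2x4 mesh values are assumptions, and no distributed backend should be needed since init_process_group defaults to False):

import torch
from colossalai.device.device_mesh import DeviceMesh  # import path assumed

# Ranks 0..7 arranged as a 2x4 logical mesh; need_flatten=True takes the branch above,
# eagerly creating both flatten_device_mesh and the new flatten_device_meshes member.
device_mesh = DeviceMesh(torch.arange(8), (2, 4), need_flatten=True)
flat_mesh = device_mesh.flatten_device_meshes  # a FlattenDeviceMesh over the same 8 ranks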
@@ -199,3 +202,38 @@ class DeviceMesh:
         penalty_factor = num_devices / 2.0
         return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
                 (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
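For reference, the all_gather_cost context lines above follow the usual alpha-beta estimate. Restated in LaTeX (a rereading of the existing code, not something added by this commit), with n devices along mesh dim d, message size B bytes, and penalty_factor = n/2:

t_{\text{all\_gather}}(B) = \alpha_d + \beta_d \cdot \frac{n-1}{n^2} \cdot B \cdot \frac{n}{2} + 0.001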
+
+
+class FlattenDeviceMesh(DeviceMesh):
+
+    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
+        super().__init__(physical_mesh_id,
+                         mesh_shape,
+                         mesh_alpha,
+                         mesh_beta,
+                         init_process_group=False,
+                         need_flatten=False)
+        # Unlike flatten(), mesh_shape is left unchanged; mesh_alpha and mesh_beta are reduced to scalars
+        self.mesh_alpha = max(self.mesh_alpha)
+        self.mesh_beta = min(self.mesh_beta)
+        # Unlike the original process_groups_dict, no rank_list is stored
+        self.process_number_dict = self.create_process_numbers_for_logical_mesh()
+
+    def create_process_numbers_for_logical_mesh(self):
+        '''
+        Build the 1d device orderings in column-major (key 0) and row-major (key 1).
+        For example:
+            mesh_shape = (2, 4)
+            # [[0, 1, 2, 3],
+            #  [4, 5, 6, 7]]
+            # returns {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
+        '''
+        num_devices = reduce(operator.mul, self.mesh_shape, 1)
+        process_numbers_dict = {}
+        process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
+        process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
+        return process_numbers_dict
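The docstring's orderings can be checked standalone with plain torch, mirroring the two lines above (an illustrative snippet; no ColossalAI import required):

import torch

mesh_shape = (2, 4)
ranks = torch.arange(8).reshape(mesh_shape)
assert ranks.transpose(1, 0).flatten().tolist() == [0, 4, 1, 5, 2, 6, 3, 7]  # column-major, key 0
assert ranks.flatten().tolist() == [0, 1, 2, 3, 4, 5, 6, 7]                  # row-major, key 1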
+
+    def mix_gather_cost(self, num_bytes):
+        num_devices = reduce(operator.mul, self.mesh_shape, 1)
+        return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
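And a quick worked example of the new mix_gather_cost (the alpha/beta values and message size below are illustrative assumptions, not values from the source): on a (2, 4) mesh, FlattenDeviceMesh has already collapsed mesh_alpha to its max and mesh_beta to its min, so for n = 8 devices:

n = 8                      # total devices in the (2, 4) mesh
alpha, beta = 1e-5, 1e-11  # assumed scalar latency (s) and per-byte cost (s/byte)
num_bytes = 4 * 1024**2    # assumed 4 MiB message
cost = alpha + beta * (n - 1) / n * num_bytes + 0.1  # same expression as mix_gather_cost above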