[autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize
* add megatron solution
* update gpt autoparallel examples with latest api
* adapt beta value to fit the current computation cost
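Before the diff itself, a conceptual sketch (not the library's code) of what "device mesh initialization" amounts to: choose a logical shape for the flat list of global ranks and reshape it into a logical grid of devices. This is exactly the mapping the constructor changes below formalize.

import torch

# Conceptual sketch, not ColossalAI code: build a logical device grid
# from a flat list of global ranks.
world_size = 8
physical_mesh_id = torch.arange(world_size)    # global ranks 0..7
mesh_shape = (2, world_size // 2)              # e.g. 2 nodes x 4 GPUs each
logical_mesh_id = physical_mesh_id.reshape(mesh_shape)
print(logical_mesh_id)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])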
@@ -1,5 +1,6 @@
 import operator
 from functools import reduce
+from typing import List, Tuple
 
 import torch
 import torch.distributed as dist
@@ -15,7 +16,8 @@ class DeviceMesh:
 
     Arguments:
         physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
-        mesh_shape (torch.Size): shape of logical view.
+        logical_mesh_id (torch.Tensor): logical view of the devices in global rank.
+        mesh_shape (torch.Size, optional): shape of logical view.
         mesh_alpha (List[float], optional): coefficients used for computing
             communication cost (default: None)
         mesh_beta (List[float], optional): coefficients used for computing
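The mesh_alpha and mesh_beta coefficients feed an alpha-beta communication cost model: a fixed latency term plus a per-byte bandwidth term per mesh dimension. As an illustration only (the class's exact formulas live elsewhere in the file and are not shown in this hunk), an all-reduce estimate might look like:

def estimate_all_reduce_cost(num_bytes: float, num_devices: int,
                             alpha: float, beta: float) -> float:
    # Latency term plus bandwidth term; a ring all-reduce transfers
    # roughly 2 * (n - 1) / n of the payload across each link.
    return alpha + 2 * (num_devices - 1) / num_devices * beta * num_bytes

# 4 MiB all-reduce over 4 devices with illustrative coefficients.
print(estimate_all_reduce_cost(4 * 1024**2, 4, alpha=1e-5, beta=1e-10))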
@@ -28,15 +30,21 @@ class DeviceMesh:
     """
 
     def __init__(self,
-                 physical_mesh_id,
-                 mesh_shape,
-                 mesh_alpha=None,
-                 mesh_beta=None,
-                 init_process_group=False,
-                 need_flatten=True):
+                 physical_mesh_id: torch.Tensor,
+                 mesh_shape: torch.Size = None,
+                 logical_mesh_id: torch.Tensor = None,
+                 mesh_alpha: List[float] = None,
+                 mesh_beta: List[float] = None,
+                 init_process_group: bool = False,
+                 need_flatten: bool = True):
         self.physical_mesh_id = physical_mesh_id
-        self.mesh_shape = mesh_shape
-        self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
+        if logical_mesh_id is None:
+            self.mesh_shape = mesh_shape
+            self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
+        else:
+            self._logical_mesh_id = logical_mesh_id
+            self.mesh_shape = self._logical_mesh_id.shape
 
         # map global rank into logical rank
         self.convert_map = {}
         self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
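The hunk above gives the constructor two initialization paths: derive the logical mesh from mesh_shape, or accept a prebuilt logical_mesh_id and infer the shape from it. A minimal usage sketch (the import path is assumed from the repository layout at the time of this commit):

import torch
from colossalai.device.device_mesh import DeviceMesh

physical_mesh_id = torch.arange(4)

# Path 1: pass a shape; the logical mesh is derived by reshaping the physical ids.
mesh_a = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))

# Path 2 (new in this commit): pass a logical mesh directly;
# mesh_shape is then inferred from logical_mesh_id.shape.
mesh_b = DeviceMesh(physical_mesh_id, logical_mesh_id=physical_mesh_id.reshape(2, 2))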
@@ -54,8 +62,8 @@ class DeviceMesh:
         if self.need_flatten and self._logical_mesh_id.dim() > 1:
             self.flatten_device_mesh = self.flatten()
             # Create a new member `flatten_device_meshes` to distinguish it from the original
             # flatten method (it is unclear whether existing code relies on self.flatten()).
-            self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
-                                                           self.mesh_beta)
+            # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
+            #                                                self.mesh_beta)
 
     @property
     def shape(self):
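For context on the flatten branch above: flattening collapses a multi-dimensional logical mesh into a 1-D mesh over the same devices, which is what collectives spanning the whole mesh operate on. A sketch of the id-level effect (assuming flatten() mirrors a tensor reshape; the real method also rebuilds the DeviceMesh object around the flattened view):

import torch

# A (2, 2) logical mesh and its 1-D counterpart cover the same four devices.
logical_mesh_id = torch.arange(4).reshape(2, 2)
flat_mesh_id = logical_mesh_id.reshape(-1)
print(logical_mesh_id.tolist(), flat_mesh_id.tolist())
# [[0, 1], [2, 3]] [0, 1, 2, 3]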