mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize * add megatron solution * update gpt autoparallel examples with latest api * adapt beta value to fit the current computation cost
This commit is contained in:
@@ -59,18 +59,6 @@ def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader,
|
||||
pass
|
||||
|
||||
|
||||
def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
|
||||
'''
|
||||
This method is used to search the best logical mesh shape for the given world size
|
||||
based on the alpha_beta_dict.
|
||||
|
||||
For example:
|
||||
if the world_size is 8, and the possible logical shape will be (1, 8), (2, 4), (4, 2), (8, 1).
|
||||
'''
|
||||
# TODO: implement this function
|
||||
return (world_size, 1)
|
||||
|
||||
|
||||
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
|
||||
'''
|
||||
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
|
||||
@@ -127,39 +115,56 @@ def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh
|
||||
|
||||
|
||||
def initialize_device_mesh(world_size: int = -1,
|
||||
physical_devices: List[int] = None,
|
||||
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
|
||||
logical_mesh_shape: Tuple[int] = None):
|
||||
logical_mesh_shape: Tuple[int] = None,
|
||||
logical_mesh_id: torch.Tensor = None):
|
||||
'''
|
||||
This method is used to initialize the device mesh.
|
||||
|
||||
Args:
|
||||
world_size(optional): the size of device mesh. If the world_size is -1,
|
||||
world_size: the size of device mesh. If the world_size is -1,
|
||||
the world size will be set to the number of GPUs in the current machine.
|
||||
physical_devices: the physical devices used to initialize the device mesh.
|
||||
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
|
||||
for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be
|
||||
generated by profile_alpha_beta function.
|
||||
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
|
||||
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
|
||||
generated by search_best_logical_mesh_shape function.
|
||||
mesh shape.
|
||||
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
|
||||
'''
|
||||
# if world_size is not set, use the world size from torch.distributed
|
||||
if world_size == -1:
|
||||
world_size = dist.get_world_size()
|
||||
device1d = [i for i in range(world_size)]
|
||||
|
||||
if physical_devices is None:
|
||||
physical_devices = [i for i in range(world_size)]
|
||||
physical_mesh = torch.tensor(physical_devices)
|
||||
|
||||
if alpha_beta_dict is None:
|
||||
# if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
|
||||
alpha_beta_dict = profile_alpha_beta(device1d)
|
||||
ab_profiler = AlphaBetaProfiler(physical_devices)
|
||||
alpha_beta_dict = ab_profiler.alpha_beta_dict
|
||||
else:
|
||||
ab_profiler = AlphaBetaProfiler(physical_devices, alpha_beta_dict=alpha_beta_dict)
|
||||
|
||||
if logical_mesh_shape is None:
|
||||
if logical_mesh_shape is None and logical_mesh_id is None:
|
||||
# search for the best logical mesh shape
|
||||
logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
|
||||
logical_mesh_id = ab_profiler.search_best_logical_mesh()
|
||||
logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
|
||||
logical_mesh_shape = logical_mesh_id.shape
|
||||
|
||||
# extract alpha and beta values for the chosen logical mesh shape
|
||||
mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
|
||||
|
||||
elif logical_mesh_shape is not None and logical_mesh_id is None:
|
||||
logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)
|
||||
|
||||
# extract alpha and beta values for the chosen logical mesh shape
|
||||
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
|
||||
|
||||
# extract alpha and beta values for the chosen logical mesh shape
|
||||
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
|
||||
physical_mesh = torch.tensor(device1d)
|
||||
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
|
||||
mesh_shape=logical_mesh_shape,
|
||||
logical_mesh_id=logical_mesh_id,
|
||||
mesh_alpha=mesh_alpha,
|
||||
mesh_beta=mesh_beta,
|
||||
init_process_group=True)
|
||||
@@ -224,6 +229,7 @@ def autoparallelize(model: nn.Module,
|
||||
data_process_func: callable = None,
|
||||
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
|
||||
logical_mesh_shape: Tuple[int] = None,
|
||||
logical_mesh_id: torch.Tensor = None,
|
||||
save_solver_solution: bool = False,
|
||||
load_solver_solution: bool = False,
|
||||
solver_solution_path: str = None,
|
||||
@@ -245,6 +251,7 @@ def autoparallelize(model: nn.Module,
|
||||
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
|
||||
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
|
||||
generated by search_best_logical_mesh_shape function.
|
||||
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
|
||||
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
|
||||
to the solution_path.
|
||||
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
|
||||
@@ -254,7 +261,9 @@ def autoparallelize(model: nn.Module,
|
||||
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
|
||||
the memory budget will be infinity.
|
||||
'''
|
||||
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
|
||||
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
|
||||
logical_mesh_shape=logical_mesh_shape,
|
||||
logical_mesh_id=logical_mesh_id)
|
||||
if meta_args is None:
|
||||
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
|
||||
|
||||
@@ -263,7 +272,7 @@ def autoparallelize(model: nn.Module,
|
||||
device_mesh,
|
||||
save_solver_solution=save_solver_solution,
|
||||
load_solver_solution=load_solver_solution,
|
||||
solver_solution_path=solver_solution_path,
|
||||
solution_path=solver_solution_path,
|
||||
return_solution=return_solution,
|
||||
memory_budget=memory_budget)
|
||||
|
||||
|
Reference in New Issue
Block a user