[autoparallel] integrate device mesh initialization into autoparallelize (#2393)

* [autoparallel] integrate device mesh initialization into autoparallelize

* add megatron solution

* update gpt autoparallel examples with latest api

* adapt beta value to fit the current computation cost
Author: YuliangLiu0306, 2023-01-11 14:03:49 +08:00 (committed by GitHub)
Parent: c72c827e95
Commit: 2731531bc2
7 changed files with 64 additions and 51 deletions

View File

@@ -59,18 +59,6 @@ def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader,
     pass
 
-def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
-    '''
-    This method is used to search the best logical mesh shape for the given world size
-    based on the alpha_beta_dict.
-    For example:
-        if the world_size is 8, and the possible logical shape will be (1, 8), (2, 4), (4, 2), (8, 1).
-    '''
-    # TODO: implement this function
-    return (world_size, 1)
-
 def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
     '''
     This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
@@ -127,39 +115,56 @@ def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh
 def initialize_device_mesh(world_size: int = -1,
+                           physical_devices: List[int] = None,
                            alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
-                           logical_mesh_shape: Tuple[int] = None):
+                           logical_mesh_shape: Tuple[int] = None,
+                           logical_mesh_id: torch.Tensor = None):
     '''
     This method is used to initialize the device mesh.
 
     Args:
-        world_size(optional): the size of device mesh. If the world_size is -1,
+        world_size: the size of device mesh. If the world_size is -1,
            the world size will be set to the number of GPUs in the current machine.
+        physical_devices: the physical devices used to initialize the device mesh.
        alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
            for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be
            generated by profile_alpha_beta function.
        logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
-            mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
-            generated by search_best_logical_mesh_shape function.
+            mesh shape.
+        logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
     '''
     # if world_size is not set, use the world size from torch.distributed
     if world_size == -1:
         world_size = dist.get_world_size()
-    device1d = [i for i in range(world_size)]
+    if physical_devices is None:
+        physical_devices = [i for i in range(world_size)]
+    physical_mesh = torch.tensor(physical_devices)
 
     if alpha_beta_dict is None:
         # if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
-        alpha_beta_dict = profile_alpha_beta(device1d)
+        ab_profiler = AlphaBetaProfiler(physical_devices)
+        alpha_beta_dict = ab_profiler.alpha_beta_dict
+    else:
+        ab_profiler = AlphaBetaProfiler(physical_devices, alpha_beta_dict=alpha_beta_dict)
 
-    if logical_mesh_shape is None:
+    if logical_mesh_shape is None and logical_mesh_id is None:
         # search for the best logical mesh shape
-        logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
-
-    # extract alpha and beta values for the chosen logical mesh shape
-    mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
-    physical_mesh = torch.tensor(device1d)
+        logical_mesh_id = ab_profiler.search_best_logical_mesh()
+        logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
+        logical_mesh_shape = logical_mesh_id.shape
+
+        # extract alpha and beta values for the chosen logical mesh shape
+        mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
+    elif logical_mesh_shape is not None and logical_mesh_id is None:
+        logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)
+        # extract alpha and beta values for the chosen logical mesh shape
+        mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
 
     device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
-                             mesh_shape=logical_mesh_shape,
+                             logical_mesh_id=logical_mesh_id,
                              mesh_alpha=mesh_alpha,
                              mesh_beta=mesh_beta,
                              init_process_group=True)
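
Not part of the diff: a minimal usage sketch of the reworked initialize_device_mesh, assuming a 4-GPU process group has already been launched (e.g. via launch_from_torch); the concrete values are illustrative only.

    # let the profiler measure alpha/beta on all visible devices and search for the best logical mesh
    device_mesh = initialize_device_mesh()

    # or fix the topology explicitly: a 2 x 2 logical mesh over physical devices 0-3
    device_mesh = initialize_device_mesh(world_size=4,
                                         physical_devices=[0, 1, 2, 3],
                                         logical_mesh_shape=(2, 2))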
@@ -224,6 +229,7 @@ def autoparallelize(model: nn.Module,
                     data_process_func: callable = None,
                     alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
                     logical_mesh_shape: Tuple[int] = None,
+                    logical_mesh_id: torch.Tensor = None,
                     save_solver_solution: bool = False,
                     load_solver_solution: bool = False,
                     solver_solution_path: str = None,
@@ -245,6 +251,7 @@ def autoparallelize(model: nn.Module,
         logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
             mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
             generated by search_best_logical_mesh_shape function.
+        logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
             to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -254,7 +261,9 @@ def autoparallelize(model: nn.Module,
         memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
             the memory budget will be infinity.
     '''
-    device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
+    device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
+                                         logical_mesh_shape=logical_mesh_shape,
+                                         logical_mesh_id=logical_mesh_id)
 
     if meta_args is None:
         meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
@@ -263,7 +272,7 @@ def autoparallelize(model: nn.Module,
                                       device_mesh,
                                       save_solver_solution=save_solver_solution,
                                       load_solver_solution=load_solver_solution,
-                                      solver_solution_path=solver_solution_path,
+                                      solution_path=solver_solution_path,
                                       return_solution=return_solution,
                                       memory_budget=memory_budget)
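
As an illustration only (model and meta_input_sample are placeholders for any module and its meta-device sample inputs), the updated autoparallelize can be driven fully automatically or constrained with an explicit logical mesh shape:

    # fully automatic: profile devices, search the logical mesh, solve and shard
    gm, solution = autoparallelize(model, meta_input_sample, return_solution=True)

    # or constrain the search with a user-chosen 2 x 2 logical topology
    gm, solution = autoparallelize(model,
                                   meta_input_sample,
                                   logical_mesh_shape=(2, 2),
                                   return_solution=True)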

View File

@@ -381,6 +381,8 @@ class AlphaBetaProfiler:
         first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
         second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
         mesh_alpha = [first_latency, second_latency]
-        mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]
+        # The beta values have been enlarged by 1e10 times temporarilly because the computation cost
+        # is still estimated in the unit of TFLOPs instead of time. We will remove this factor in future.
+        mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]
 
         return mesh_alpha, mesh_beta
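
For intuition (the bandwidth figure below is assumed, not measured): a raw 1/bandwidth beta is on the order of 1e-11 seconds per byte, which would be drowned out by a compute cost expressed in TFLOPs rather than time, hence the temporary 1e10 factor.

    bandwidth = 100e9                # bytes/s, an assumed figure for illustration
    beta_in_seconds = 1 / bandwidth  # ~1e-11, negligible next to a TFLOPs-scale compute term
    beta_used = 1e10 / bandwidth     # ~0.1, the temporarily enlarged value stored in mesh_beta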

View File

@@ -1,5 +1,6 @@
 import operator
 from functools import reduce
+from typing import List, Tuple
 
 import torch
 import torch.distributed as dist
@@ -15,7 +16,8 @@ class DeviceMesh:
     Arguments:
         physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
-        mesh_shape (torch.Size): shape of logical view.
+        logical_mesh_id (torch.Tensor): logical view of the devices in global rank.
+        mesh_shape (torch.Size, optional): shape of logical view.
         mesh_alpha (List[float], optional): coefficients used for computing
             communication cost (default: None)
         mesh_beta (List[float], optional): coefficients used for computing
@@ -28,15 +30,21 @@ class DeviceMesh:
     """
 
     def __init__(self,
-                 physical_mesh_id,
-                 mesh_shape,
-                 mesh_alpha=None,
-                 mesh_beta=None,
-                 init_process_group=False,
-                 need_flatten=True):
+                 physical_mesh_id: torch.Tensor,
+                 mesh_shape: torch.Size = None,
+                 logical_mesh_id: torch.Tensor = None,
+                 mesh_alpha: List[float] = None,
+                 mesh_beta: List[float] = None,
+                 init_process_group: bool = False,
+                 need_flatten: bool = True):
         self.physical_mesh_id = physical_mesh_id
-        self.mesh_shape = mesh_shape
-        self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
+        if logical_mesh_id is None:
+            self.mesh_shape = mesh_shape
+            self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
+        else:
+            self._logical_mesh_id = logical_mesh_id
+            self.mesh_shape = self._logical_mesh_id.shape
+
         # map global rank into logical rank
         self.convert_map = {}
         self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
@@ -54,8 +62,8 @@ class DeviceMesh:
         if self.need_flatten and self._logical_mesh_id.dim() > 1:
             self.flatten_device_mesh = self.flatten()
             # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
-            self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
-                                                           self.mesh_beta)
+            # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
+            #                                                self.mesh_beta)
 
     @property
     def shape(self):
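
A small construction sketch (not from the commit) of the two paths the reworked __init__ now accepts; the tensors are arbitrary and init_process_group is left at its default so the snippet can be tried without a launched process group.

    import torch

    from colossalai.device.device_mesh import DeviceMesh

    physical_mesh_id = torch.arange(0, 4)

    # original path: give a shape and let DeviceMesh reshape the physical ids
    mesh_from_shape = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))

    # new path: give an explicit logical mesh id; mesh_shape is derived from it
    mesh_from_id = DeviceMesh(physical_mesh_id, logical_mesh_id=physical_mesh_id.reshape(2, 2))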

View File

@@ -16,14 +16,14 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch_from_torch
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 
-BATCH_SIZE = 8
-SEQ_LENGTH = 128
-HIDDEN_DIM = 3072
+BATCH_SIZE = 16
+SEQ_LENGTH = 1024
+HIDDEN_DIM = 4096
 NUM_HEADS = 16
-NUM_LAYERS = 1
+NUM_LAYERS = 4
 VOCAB_SIZE = 50257
 NUM_STEPS = 10
-FP16 = False
+FP16 = True
 
 
 def get_cpu_mem():
@@ -40,7 +40,7 @@ def get_mem_info(prefix=''):
 
 def get_tflops(model_numel, batch_size, seq_len, step_time):
     # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
-    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
+    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 8
 
 
 # Randomly Generated Data
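
Quick sanity check of the new denominator, with assumed figures (a ~1e9-parameter model, the updated BATCH_SIZE and SEQ_LENGTH, a 1 s step, 8 GPUs):

    # 1e9 * 16 * 1024 * 8 / 1e12 / 1.0 / 8  ->  about 16.4 TFLOPs per GPU
    print(get_tflops(1e9, 16, 1024, 1.0))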
@@ -66,13 +66,7 @@ def main():
         'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
     }
 
-    # Both device mesh initialization and model initialization will be integrated into autoparallelize
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
-    # Enable auto-parallel
-    gm, solution = initialize_model(model, meta_input_sample, device_mesh, return_solution=True)
+    gm, solution = autoparallelize(model, meta_input_sample, return_solution=True)
 
     # print solution on rank 0
     if gpc.get_global_rank() == 0: