diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py
index 79cddeb7b..0dce2564c 100644
--- a/colossalai/auto_parallel/tensor_shard/initialize.py
+++ b/colossalai/auto_parallel/tensor_shard/initialize.py
@@ -16,8 +16,8 @@ from colossalai.auto_parallel.tensor_shard.solver import (
     SolverOptions,
     StrategiesConstructor,
 )
+from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.device.profile_alpha_beta import profile_alpha_beta
 from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec
 
diff --git a/colossalai/device/__init__.py b/colossalai/device/__init__.py
index 879b60c06..689189998 100644
--- a/colossalai/device/__init__.py
+++ b/colossalai/device/__init__.py
@@ -1,4 +1,4 @@
+from .alpha_beta_profiler import AlphaBetaProfiler
 from .calc_pipeline_strategy import alpa_dp
-from .profile_alpha_beta import profile_alpha_beta
 
-__all__ = ['profile_alpha_beta', 'alpa_dp']
+__all__ = ['AlphaBetaProfiler', 'alpa_dp']
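The two hunks above only re-route the public entry point: `colossalai.device` now exports the `AlphaBetaProfiler` class in place of the removed `profile_alpha_beta` function. A minimal sketch of how a call site changes; illustrative only, not part of the patch:

```python
# Old functional API (removed in this PR): spawned its own processes and
# returned a single (alpha, beta) pair for the whole device list.
# from colossalai.device import profile_alpha_beta
# alpha, beta = profile_alpha_beta([0, 1, 2, 3])

# New class-based API (added in this PR): runs inside an existing
# distributed job and returns a per-device-pair dict.
from colossalai.device import AlphaBetaProfiler

profiler = AlphaBetaProfiler([0, 1, 2, 3])
ab_dict = profiler.profile_ab()
```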
diff --git a/colossalai/device/alpha_beta_profiler.py b/colossalai/device/alpha_beta_profiler.py
new file mode 100644
index 000000000..324acacb8
--- /dev/null
+++ b/colossalai/device/alpha_beta_profiler.py
@@ -0,0 +1,199 @@
+import math
+import time
+from typing import Dict, List, Tuple
+
+import torch
+import torch.distributed as dist
+
+from colossalai.logging import get_dist_logger
+
+GB = int((1 << 30))
+BYTE = 4
+FRAMEWORK_LATENCY = 0
+
+
+class AlphaBetaProfiler:
+    '''
+    Profile the alpha and beta values for a given device list.
+
+    Usage:
+        # Note: the environment of execution is supposed to be
+        # multi-process with multi-gpu in mpi style.
+        >>> physical_devices = [0, 1, 4, 5]
+        >>> ab_profiler = AlphaBetaProfiler(physical_devices)
+        >>> ab_dict = ab_profiler.profile_ab()
+        >>> print(ab_dict)
+        {(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
+         (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
+         (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
+         (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
+    '''
+
+    def __init__(self,
+                 physical_devices: List[int],
+                 ctype: str = 'a',
+                 warmup: int = 5,
+                 repeat: int = 25,
+                 latency_iters: int = 5):
+        '''
+        Args:
+            physical_devices: A list of device ids; each element is the global rank of that device.
+            ctype: 'a' for all-reduce, 'b' for broadcast.
+            warmup: Number of warmup iterations.
+            repeat: Number of iterations to measure.
+            latency_iters: Number of message sizes used to measure latency.
+        '''
+        self.physical_devices = physical_devices
+        self.ctype = ctype
+        self.world_size = len(physical_devices)
+        self.warmup = warmup
+        self.repeat = repeat
+        self.latency_iters = latency_iters
+        self.process_group_dict = None
+        self._init_profiling()
+
+    def _init_profiling(self):
+        # Create the list of pairwise process groups based on the global ranks
+        process_group_list = []
+        for f_index in range(self.world_size - 1):
+            for b_index in range(f_index + 1, self.world_size):
+                process_group_list.append((self.physical_devices[f_index], self.physical_devices[b_index]))
+
+        # Create the dict which maps each process group to its handler
+        process_group_dict = {}
+        for process_group in process_group_list:
+            pg_handler = dist.new_group(process_group)
+            process_group_dict[process_group] = pg_handler
+
+        self.process_group_dict = process_group_dict
+
+    def _profile(self, process_group, pg_handler, nbytes):
+        logger = get_dist_logger()
+        rank = dist.get_rank()
+        src_device_num = process_group[0]
+        world_size = len(process_group)
+
+        device = torch.cuda.current_device()
+        buf = torch.randn(nbytes // 4).to(device)
+
+        torch.cuda.synchronize()
+        # warmup
+        for _ in range(self.warmup):
+            if self.ctype == "a":
+                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
+            elif self.ctype == "b":
+                dist.broadcast(buf, src=src_device_num, group=pg_handler)
+            torch.cuda.synchronize()
+
+        dist.barrier(group=pg_handler)
+        begin = time.perf_counter()
+        for _ in range(self.repeat):
+            if self.ctype == "a":
+                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
+            elif self.ctype == "b":
+                dist.broadcast(buf, src=src_device_num, group=pg_handler)
+            torch.cuda.synchronize()
+        end = time.perf_counter()
+        dist.barrier(group=pg_handler)
+
+        if rank == src_device_num:
+            avg_time_s = (end - begin) / self.repeat - FRAMEWORK_LATENCY
+            alg_band = nbytes / avg_time_s
+            if self.ctype == "a":
+                # convert the all-reduce algorithm bandwidth into the bus bandwidth of the hardware.
+                bus_band = 2 * (world_size - 1) / world_size * alg_band
+            elif self.ctype == "b":
+                bus_band = alg_band
+
+            logger.info(
+                f"GPU:{rank}, Bytes: {nbytes} B, Time: {round(avg_time_s * 1e6, 2)} us, Bus bandwidth: {round(bus_band / GB, 2)} GB/s"
+            )
+            return (avg_time_s, alg_band)
+        else:
+            # Just a placeholder
+            return (None, None)
+
+    def profile_latency(self, process_group, pg_handler):
+        '''
+        This function is used to profile the latency of the given process group with a series of small message sizes.
+
+        Args:
+            process_group: A tuple of the global ranks in the process group.
+            pg_handler: The handler of the process group.
+
+        Returns:
+            latency: None if the latency is not measured on this rank, otherwise the median of latency_list.
+        '''
+        latency_list = []
+        for i in range(self.latency_iters):
+            nbytes = int(BYTE << i)
+            (t, _) = self._profile(process_group, pg_handler, nbytes)
+            latency_list.append(t)
+
+        if latency_list[0] is None:
+            latency = None
+        else:
+            median_index = math.floor(self.latency_iters / 2)
+            latency = latency_list[median_index]
+
+        return latency
+
+    def profile_bandwidth(self, process_group, pg_handler, maxbytes):
+        '''
+        This function is used to profile the bandwidth of the given process group.
+
+        Args:
+            process_group: A tuple of the global ranks in the process group.
+            pg_handler: The handler of the process group.
+            maxbytes: The message size, in bytes, used for the bandwidth measurement.
+        '''
+        (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
+        return bandwidth
+
+    def profile_ab(self):
+        '''
+        This method is used to profile the alpha and beta values for the given device list.
+
+        Returns:
+            alpha_beta_dict: A dict which maps each process group to its alpha and beta values.
+        '''
+        alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
+        rank = dist.get_rank()
+
+        def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
+            assert rank in process_group
+            device = torch.cuda.current_device()
+            rank_max_nbytes = torch.cuda.mem_get_info(device)[0]
+            rank_max_nbytes = torch.tensor(rank_max_nbytes, device=device)
+            dist.all_reduce(rank_max_nbytes, op=dist.ReduceOp.MIN, group=pg_handler)
+            max_nbytes = min(int(1 * GB), int(GB << int(math.log2(rank_max_nbytes.item() / GB))))
+            return max_nbytes
+
+        for process_group, pg_handler in self.process_group_dict.items():
+            if rank not in process_group:
+                max_nbytes = None
+                alpha = None
+                bandwidth = None
+            else:
+                max_nbytes = get_max_nbytes(process_group, pg_handler)
+                alpha = self.profile_latency(process_group, pg_handler)
+                bandwidth = self.profile_bandwidth(process_group, pg_handler, maxbytes=max_nbytes)
+
+            if bandwidth is None:
+                beta = None
+            else:
+                beta = 1 / bandwidth
+
+            broadcast_list = [alpha, beta]
+            dist.broadcast_object_list(broadcast_list, src=process_group[0])
+            alpha_beta_dict[process_group] = tuple(broadcast_list)
+
+        # add the symmetric pairs to the alpha_beta_dict
+        symmetry_ab_dict = {}
+        for process_group, alpha_beta_pair in alpha_beta_dict.items():
+            symmetry_process_group = (process_group[1], process_group[0])
+            symmetry_ab_dict[symmetry_process_group] = alpha_beta_pair
+
+        alpha_beta_dict.update(symmetry_ab_dict)
+
+        return alpha_beta_dict
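For reference, a minimal driver for the new profiler, assembled from the class docstring above and the updated test at the end of this diff; the one-process-per-GPU launch, the port number, and the four-device list are assumptions for the sketch, not part of the patch:

```python
# Sketch only: start one process per GPU (the "mpi style" noted in the docstring),
# e.g. via mp.spawn below or torchrun, so that every rank runs this function.
import torch.distributed as dist
import torch.multiprocessing as mp

from colossalai.device import AlphaBetaProfiler
from colossalai.initialize import launch


def run(rank: int, world_size: int = 4, port: int = 29500):
    # initialize the NCCL process group the same way the updated test does
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    profiler = AlphaBetaProfiler(physical_devices=[0, 1, 2, 3])
    ab_dict = profiler.profile_ab()

    # profile_ab() broadcasts every measurement, so each rank ends up with the full dict;
    # alpha is in seconds, beta (= 1 / bandwidth) is in seconds per byte.
    if dist.get_rank() == 0:
        for pair, (alpha, beta) in ab_dict.items():
            print(f"{pair}: alpha={alpha:.2e} s, beta={beta:.2e} s/B")


if __name__ == '__main__':
    mp.spawn(run, nprocs=4)
```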
diff --git a/colossalai/device/profile_alpha_beta.py b/colossalai/device/profile_alpha_beta.py
deleted file mode 100644
index 2d053ddbe..000000000
--- a/colossalai/device/profile_alpha_beta.py
+++ /dev/null
@@ -1,120 +0,0 @@
-import fcntl
-import math
-import os
-import time
-
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-
-MB = int((1 << 10) * 1e3)
-GB = int((1 << 20) * 1e3)
-Byte = 4
-FRAMEWORK = 0
-NON_SENSE = (0.1, 0.1)
-
-
-def printflock(*msgs):
-    """ solves multi-process interleaved print problem """
-    with open(__file__, "r") as fh:
-        fcntl.flock(fh, fcntl.LOCK_EX)
-        try:
-            print(*msgs)
-        finally:
-            fcntl.flock(fh, fcntl.LOCK_UN)
-
-
-def profile(device1d, nbytes, ctype):
-    warmup = 5
-    repeat = 25
-    rank = dist.get_rank()
-    src_device_num = device1d[0]
-    wsize = len(device1d)
-    group = dist.new_group(device1d)
-
-    torch.cuda.set_device(rank)
-    device = torch.device("cuda", rank)
-    buf = torch.randn(nbytes // 4).to(device)
-
-    torch.cuda.synchronize()
-    # warmup
-    for _ in range(warmup):
-        if ctype == "a":
-            dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
-        elif ctype == "b":
-            dist.broadcast(buf, src=src_device_num, group=group)
-        torch.cuda.synchronize()
-
-    dist.barrier()
-    begin = time.perf_counter()
-    for _ in range(repeat):
-        if ctype == "a":
-            dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
-        elif ctype == "b":
-            dist.broadcast(buf, src=src_device_num, group=group)
-        torch.cuda.synchronize()
-    end = time.perf_counter()
-    dist.barrier()
-
-    if rank == src_device_num:
-        avg_time_s = (end - begin) / repeat - FRAMEWORK
-        alg_band = nbytes / avg_time_s
-        if ctype == "b":
-            bus_band = alg_band
-        elif ctype == "a":
-            bus_band = 2 * (wsize - 1) / wsize * alg_band
-        print(
-            f"GPU:{rank}, Bytes: {nbytes} B,Time: {round(avg_time_s * 1e6,2)} us, Bus bandwidth: {round(bus_band / GB,2)} GB/s"
-        )
-        return (avg_time_s, alg_band)
-    else:
-        return NON_SENSE    # Just a placeholder
-
-
-def profile_latency(device1d, it=3, ctype="a"):
-    latency = []
-    for i in range(it):
-        nbytes = int(Byte << i)
-        (t, _) = profile(device1d, nbytes, ctype)
-        latency.append(t)
-    return min(latency)
-
-
-def profile_bandwidth(device1d, maxbytes, ctype="a"):
-    (_, bandwidth) = profile(device1d, maxbytes, ctype)
-    return bandwidth
-
-
-def profile_ab(rank, *args):
-    wsize = int(torch.cuda.device_count())
-    device1d = args[0]
-    return_dict = args[1]
-    ctype = args[2]
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '29020'
-    dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://', world_size=wsize, rank=rank)
-
-    device = torch.device("cuda", rank)
-    max_nbytes = torch.tensor(torch.cuda.mem_get_info(device)[0]).to(device)
-    max_nbytes = min(int(4 * GB), int(GB << int(math.log2(max_nbytes.item() / GB))))
-
-    if rank == device1d[0]:
-        print(f"max_nbytes: {max_nbytes} B")
-
-    alpha = profile_latency(device1d, it=5, ctype=ctype)
-    beta = 1 / profile_bandwidth(device1d, maxbytes=max_nbytes, ctype=ctype)
-
-    if rank == device1d[0]:
-        print(f"alpha(us): {round(alpha * 1e6,2)}, beta(us/GB): {round(beta * 1e6 * GB,2)}")
-    return_dict[rank] = (alpha, beta)
-
-
-def profile_alpha_beta(device1d):
-    assert torch.cuda.is_available()
-    assert len(device1d) > 0 and len(device1d) <= int(torch.cuda.device_count())
-
-    manager = mp.Manager()
-    return_dict = manager.dict()
-    ctype = "a"
-    mp.spawn(profile_ab, args=[device1d, return_dict, ctype], nprocs=int(torch.cuda.device_count()))
-    return return_dict[device1d[0]]
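Both the legacy script removed above and the new `AlphaBetaProfiler` report the same two quantities of the standard alpha-beta communication model: alpha is the fixed per-message latency in seconds, and beta is the inverse bandwidth in seconds per byte. A small worked sketch of how a consumer (for example, a communication cost model in the auto-parallel solver) might turn a profiled pair into a transfer-time estimate; the helper name and the message size are illustrative, not part of the patch:

```python
def estimate_comm_time(nbytes: int, alpha: float, beta: float) -> float:
    """Classic alpha-beta estimate: fixed latency plus bytes times inverse bandwidth."""
    return alpha + beta * nbytes


# With the (0, 1) pair from the class docstring, alpha ~= 1.96e-05 s and beta ~= 4.74e-12 s/B,
# so a 64 MiB buffer costs roughly 1.96e-05 + 4.74e-12 * 64 * 2**20 ~= 3.4e-04 s.
print(estimate_comm_time(64 * 2**20, 1.96e-05, 4.74e-12))
```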
diff --git a/tests/test_device/test_alpha_beta.py b/tests/test_device/test_alpha_beta.py
index 5b076fdf0..99abacd13 100644
--- a/tests/test_device/test_alpha_beta.py
+++ b/tests/test_device/test_alpha_beta.py
@@ -1,13 +1,32 @@
-import pytest
+from functools import partial
 
-from colossalai.device import profile_alpha_beta
+import pytest
+import torch.multiprocessing as mp
+
+from colossalai.device import AlphaBetaProfiler
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+
+def check_alpha_beta(rank, physical_devices, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    profiler = AlphaBetaProfiler(physical_devices)
+    ab_dict = profiler.profile_ab()
+    for _, (alpha, beta) in ab_dict.items():
+        assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10
 
 
 @pytest.mark.skip(reason="Skip because assertion fails for CI devices")
-def test_profile_alpha_beta():
-    physical_devices = [0, 1, 2, 3]
-    (alpha, beta) = profile_alpha_beta(physical_devices)
-    assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10
+@pytest.mark.dist
+@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
+@rerun_if_address_is_in_use()
+def test_profile_alpha_beta(physical_devices):
+    world_size = 4
+    run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':