[device] alpha beta profiler (#2311)

* [device] alpha beta profiler

* add usage

* fix variable name
YuliangLiu0306 2023-01-05 16:39:55 +08:00 committed by GitHub
parent f1bc2418c4
commit 9c9246c0d9
5 changed files with 227 additions and 129 deletions

View File

@@ -16,8 +16,8 @@ from colossalai.auto_parallel.tensor_shard.solver import (
     SolverOptions,
     StrategiesConstructor,
 )
+from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.device.profile_alpha_beta import profile_alpha_beta
 from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec
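
This file now imports AlphaBetaProfiler next to DeviceMesh and the solver classes, which suggests the measured (alpha, beta) pairs are meant to parameterize the device mesh handed to the sharding solver. A rough sketch of that wiring, assuming DeviceMesh accepts per-axis mesh_alpha and mesh_beta lists (an assumption; the actual call sites are not shown in this diff):

    import torch

    from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
    from colossalai.device.device_mesh import DeviceMesh

    physical_devices = [0, 1, 2, 3]
    ab_profiler = AlphaBetaProfiler(physical_devices)
    ab_dict = ab_profiler.profile_ab()

    # Hypothetical reduction of the pairwise measurements to per-axis costs
    # of a 2x2 logical mesh: pair (0, 1) spans one axis, pair (0, 2) the other.
    alpha_0, beta_0 = ab_dict[(0, 1)]
    alpha_1, beta_1 = ab_dict[(0, 2)]
    device_mesh = DeviceMesh(torch.arange(4), (2, 2),
                             mesh_alpha=[alpha_0, alpha_1],
                             mesh_beta=[beta_0, beta_1])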

View File

@@ -1,4 +1,4 @@
+from .alpha_beta_profiler import AlphaBetaProfiler
 from .calc_pipeline_strategy import alpa_dp
-from .profile_alpha_beta import profile_alpha_beta

-__all__ = ['profile_alpha_beta', 'alpa_dp']
+__all__ = ['AlphaBetaProfiler', 'alpa_dp']
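
With this commit the package re-exports the new class instead of the removed helper, so downstream code migrates from the one-shot function to the profiler object. Unlike the old function, the profiler assumes torch.distributed is already initialized in every participating process (as the updated test does via launch). A minimal before/after sketch:

    # Before this commit (removed API):
    # from colossalai.device import profile_alpha_beta
    # alpha, beta = profile_alpha_beta([0, 1, 2, 3])

    # After this commit, inside an initialized distributed environment:
    from colossalai.device import AlphaBetaProfiler

    ab_profiler = AlphaBetaProfiler([0, 1, 2, 3])
    ab_dict = ab_profiler.profile_ab()    # {(rank_i, rank_j): (alpha, beta), ...}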

View File

@@ -0,0 +1,199 @@
import math
import time
from typing import Dict, List, Tuple

import torch
import torch.distributed as dist

from colossalai.logging import get_dist_logger

GB = int((1 << 30))
BYTE = 4
FRAMEWORK_LATENCY = 0


class AlphaBetaProfiler:
    '''
    Profile the alpha and beta values for a given device list.

    Usage:
        # Note: the environment of execution is supposed to be
        # multi-process with multi-gpu in mpi style.
        >>> physical_devices = [0, 1, 4, 5]
        >>> ab_profiler = AlphaBetaProfiler(physical_devices)
        >>> ab_dict = ab_profiler.profile_ab()
        >>> print(ab_dict)
        {(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
         (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
         (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
         (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
    '''

    def __init__(self,
                 physical_devices: List[int],
                 ctype: str = 'a',
                 warmup: int = 5,
                 repeat: int = 25,
                 latency_iters: int = 5):
        '''
        Args:
            physical_devices: A list of device ids; each element is the global rank of that device.
            ctype: 'a' for all-reduce, 'b' for broadcast.
            warmup: Number of warmup iterations.
            repeat: Number of iterations to measure.
            latency_iters: Number of iterations to measure latency.
        '''
        self.physical_devices = physical_devices
        self.ctype = ctype
        self.world_size = len(physical_devices)
        self.warmup = warmup
        self.repeat = repeat
        self.latency_iters = latency_iters
        self.process_group_dict = None
        self._init_profiling()

    def _init_profiling(self):
        # Create the list of pairwise process groups based on global ranks.
        process_group_list = []
        for f_index in range(self.world_size - 1):
            for b_index in range(f_index + 1, self.world_size):
                process_group_list.append((self.physical_devices[f_index], self.physical_devices[b_index]))

        # Create a dict which maps each process group to its handler.
        process_group_dict = {}
        for process_group in process_group_list:
            pg_handler = dist.new_group(process_group)
            process_group_dict[process_group] = pg_handler

        self.process_group_dict = process_group_dict

    def _profile(self, process_group, pg_handler, nbytes):
        logger = get_dist_logger()
        rank = dist.get_rank()
        src_device_num = process_group[0]
        world_size = len(process_group)

        device = torch.cuda.current_device()
        buf = torch.randn(nbytes // 4).to(device)

        torch.cuda.synchronize()
        # warmup
        for _ in range(self.warmup):
            if self.ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
            elif self.ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=pg_handler)
        torch.cuda.synchronize()

        dist.barrier(group=pg_handler)
        begin = time.perf_counter()
        for _ in range(self.repeat):
            if self.ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
            elif self.ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=pg_handler)
        torch.cuda.synchronize()
        end = time.perf_counter()
        dist.barrier(group=pg_handler)

        if rank == src_device_num:
            avg_time_s = (end - begin) / self.repeat - FRAMEWORK_LATENCY
            alg_band = nbytes / avg_time_s
            if self.ctype == "a":
                # Convert the algorithm bandwidth of all-reduce to the bus bandwidth of the hardware.
                bus_band = 2 * (world_size - 1) / world_size * alg_band
            elif self.ctype == "b":
                bus_band = alg_band

            logger.info(
                f"GPU:{rank}, Bytes: {nbytes} B, Time: {round(avg_time_s * 1e6, 2)} us, Bus bandwidth: {round(bus_band / GB, 2)} GB/s"
            )
            return (avg_time_s, alg_band)
        else:
            # Just a placeholder
            return (None, None)

    def profile_latency(self, process_group, pg_handler):
        '''
        This function is used to profile the latency of the given process group with a series of byte sizes.

        Args:
            process_group: A tuple of the global ranks in the process group.
            pg_handler: The handler of the process group.

        Returns:
            latency: None if the latency is not measured, otherwise the median of latency_list.
        '''
        latency_list = []
        for i in range(self.latency_iters):
            nbytes = int(BYTE << i)
            (t, _) = self._profile(process_group, pg_handler, nbytes)
            latency_list.append(t)

        if latency_list[0] is None:
            latency = None
        else:
            median_index = math.floor(self.latency_iters / 2)
            latency = latency_list[median_index]

        return latency

    def profile_bandwidth(self, process_group, pg_handler, maxbytes):
        '''
        This function is used to profile the bandwidth of the given process group.

        Args:
            process_group: A tuple of the global ranks in the process group.
            pg_handler: The handler of the process group.
            maxbytes: The size of the test message in bytes.
        '''
        (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
        return bandwidth

    def profile_ab(self):
        '''
        This method is used to profile the alpha and beta values for a given device list.

        Returns:
            alpha_beta_dict: A dict which maps each process group to its alpha and beta values.
        '''
        alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
        rank = dist.get_rank()

        def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
            assert rank in process_group
            device = torch.cuda.current_device()
            rank_max_nbytes = torch.cuda.mem_get_info(device)[0]
            rank_max_nbytes = torch.tensor(rank_max_nbytes, device=device)
            dist.all_reduce(rank_max_nbytes, op=dist.ReduceOp.MIN, group=pg_handler)
            max_nbytes = min(int(1 * GB), int(GB << int(math.log2(rank_max_nbytes.item() / GB))))
            return max_nbytes

        for process_group, pg_handler in self.process_group_dict.items():
            if rank not in process_group:
                max_nbytes = None
                alpha = None
                bandwidth = None
            else:
                max_nbytes = get_max_nbytes(process_group, pg_handler)
                alpha = self.profile_latency(process_group, pg_handler)
                bandwidth = self.profile_bandwidth(process_group, pg_handler, maxbytes=max_nbytes)

            if bandwidth is None:
                beta = None
            else:
                beta = 1 / bandwidth

            broadcast_list = [alpha, beta]
            dist.broadcast_object_list(broadcast_list, src=process_group[0])
            alpha_beta_dict[process_group] = tuple(broadcast_list)

        # Add the symmetric pair to the alpha_beta_dict.
        symmetry_ab_dict = {}
        for process_group, alpha_beta_pair in alpha_beta_dict.items():
            symmetry_process_group = (process_group[1], process_group[0])
            symmetry_ab_dict[symmetry_process_group] = alpha_beta_pair

        alpha_beta_dict.update(symmetry_ab_dict)

        return alpha_beta_dict
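
Each entry of the returned dict is an (alpha, beta) pair in the usual alpha-beta communication model: moving n bytes within a pair group is estimated as t(n) ≈ alpha + n * beta, where alpha is the measured latency in seconds and beta the inverse bandwidth in seconds per byte. A small sketch of consuming the profiled dict this way (the helper below is illustrative, not part of this commit):

    def estimate_comm_time(ab_dict, src, dst, nbytes):
        # alpha: latency in seconds; beta: seconds per byte (1 / bandwidth)
        alpha, beta = ab_dict[(src, dst)]
        return alpha + nbytes * beta

    # e.g. rough cost of moving 64 MB between ranks 0 and 1:
    # t = estimate_comm_time(ab_dict, 0, 1, 64 * 1024 * 1024)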

View File

@@ -1,120 +0,0 @@
import fcntl
import math
import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

MB = int((1 << 10) * 1e3)
GB = int((1 << 20) * 1e3)
Byte = 4
FRAMEWORK = 0
NON_SENSE = (0.1, 0.1)


def printflock(*msgs):
    """ solves multi-process interleaved print problem """
    with open(__file__, "r") as fh:
        fcntl.flock(fh, fcntl.LOCK_EX)
        try:
            print(*msgs)
        finally:
            fcntl.flock(fh, fcntl.LOCK_UN)


def profile(device1d, nbytes, ctype):
    warmup = 5
    repeat = 25
    rank = dist.get_rank()
    src_device_num = device1d[0]
    wsize = len(device1d)
    group = dist.new_group(device1d)

    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
    buf = torch.randn(nbytes // 4).to(device)

    torch.cuda.synchronize()
    # warmup
    for _ in range(warmup):
        if ctype == "a":
            dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
        elif ctype == "b":
            dist.broadcast(buf, src=src_device_num, group=group)
    torch.cuda.synchronize()

    dist.barrier()
    begin = time.perf_counter()
    for _ in range(repeat):
        if ctype == "a":
            dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
        elif ctype == "b":
            dist.broadcast(buf, src=src_device_num, group=group)
    torch.cuda.synchronize()
    end = time.perf_counter()
    dist.barrier()

    if rank == src_device_num:
        avg_time_s = (end - begin) / repeat - FRAMEWORK
        alg_band = nbytes / avg_time_s
        if ctype == "b":
            bus_band = alg_band
        elif ctype == "a":
            bus_band = 2 * (wsize - 1) / wsize * alg_band

        print(
            f"GPU:{rank}, Bytes: {nbytes} B,Time: {round(avg_time_s * 1e6,2)} us, Bus bandwidth: {round(bus_band / GB,2)} GB/s"
        )
        return (avg_time_s, alg_band)
    else:
        return NON_SENSE    # Just a placeholder


def profile_latency(device1d, it=3, ctype="a"):
    latency = []
    for i in range(it):
        nbytes = int(Byte << i)
        (t, _) = profile(device1d, nbytes, ctype)
        latency.append(t)

    return min(latency)


def profile_bandwidth(device1d, maxbytes, ctype="a"):
    (_, bandwidth) = profile(device1d, maxbytes, ctype)
    return bandwidth


def profile_ab(rank, *args):
    wsize = int(torch.cuda.device_count())
    device1d = args[0]
    return_dict = args[1]
    ctype = args[2]

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29020'

    dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://', world_size=wsize, rank=rank)
    device = torch.device("cuda", rank)
    max_nbytes = torch.tensor(torch.cuda.mem_get_info(device)[0]).to(device)
    max_nbytes = min(int(4 * GB), int(GB << int(math.log2(max_nbytes.item() / GB))))

    if rank == device1d[0]:
        print(f"max_nbytes: {max_nbytes} B")

    alpha = profile_latency(device1d, it=5, ctype=ctype)
    beta = 1 / profile_bandwidth(device1d, maxbytes=max_nbytes, ctype=ctype)

    if rank == device1d[0]:
        print(f"alpha(us): {round(alpha * 1e6,2)}, beta(us/GB): {round(beta * 1e6 * GB,2)}")
    return_dict[rank] = (alpha, beta)


def profile_alpha_beta(device1d):
    assert torch.cuda.is_available()
    assert len(device1d) > 0 and len(device1d) <= int(torch.cuda.device_count())

    manager = mp.Manager()
    return_dict = manager.dict()
    ctype = "a"
    mp.spawn(profile_ab, args=[device1d, return_dict, ctype], nprocs=int(torch.cuda.device_count()))
    return return_dict[device1d[0]]

View File

@@ -1,13 +1,32 @@
-import pytest
-
-from colossalai.device import profile_alpha_beta
+from functools import partial
+
+import pytest
+import torch.multiprocessing as mp
+
+from colossalai.device import AlphaBetaProfiler
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+
+def check_alpha_beta(rank, physical_devices, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    profiler = AlphaBetaProfiler(physical_devices)
+    ab_dict = profiler.profile_ab()
+    for _, (alpha, beta) in ab_dict.items():
+        assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10
+
+
 @pytest.mark.skip(reason="Skip because assertion fails for CI devices")
-def test_profile_alpha_beta():
-    physical_devices = [0, 1, 2, 3]
-    (alpha, beta) = profile_alpha_beta(physical_devices)
-    assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10
+@pytest.mark.dist
+@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
+@rerun_if_address_is_in_use()
+def test_profile_alpha_beta(physical_devices):
+    world_size = 4
+    run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
 if __name__ == '__main__':