Mirror of https://github.com/hpcaitech/ColossalAI.git
[device] alpha beta profiler (#2311)

* [device] alpha beta profiler
* add usage
* fix variable name

parent f1bc2418c4
commit 9c9246c0d9
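This commit replaces the old spawn-based profile_alpha_beta utility with a class-based AlphaBetaProfiler that measures an (alpha, beta) pair for every pair of devices in a given device list. For orientation only, here is a minimal sketch that is not part of the commit: alpha is the measured per-message latency in seconds and beta is the per-byte transfer time (the reciprocal of the measured bandwidth), so a profiled pair plugs into the usual alpha-beta communication cost model. The helper name estimated_comm_time is hypothetical.

# Illustrative sketch, not part of the commit: interpreting a profiled
# (alpha, beta) pair with the standard alpha-beta cost model.
# alpha: fixed per-message latency in seconds
# beta: per-byte transfer time (1 / measured bandwidth)

def estimated_comm_time(alpha: float, beta: float, nbytes: int) -> float:
    """Rough estimate of the time to move `nbytes` bytes between two devices."""
    return alpha + beta * nbytes

# With values of the same order as the docstring example further below:
alpha, beta = 1.96e-05, 4.74e-12
print(estimated_comm_time(alpha, beta, 1 << 30))    # ~5.1e-03 s for 1 GiB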
@@ -16,8 +16,8 @@ from colossalai.auto_parallel.tensor_shard.solver import (
     SolverOptions,
     StrategiesConstructor,
 )
+from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.device.profile_alpha_beta import profile_alpha_beta
 from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec

colossalai/device/__init__.py
@@ -1,4 +1,4 @@
+from .alpha_beta_profiler import AlphaBetaProfiler
 from .calc_pipeline_strategy import alpa_dp
-from .profile_alpha_beta import profile_alpha_beta
 
-__all__ = ['profile_alpha_beta', 'alpa_dp']
+__all__ = ['AlphaBetaProfiler', 'alpa_dp']
colossalai/device/alpha_beta_profiler.py (new file)
@@ -0,0 +1,199 @@
import math
import time
from typing import Dict, List, Tuple

import torch
import torch.distributed as dist

from colossalai.logging import get_dist_logger

GB = int((1 << 30))
BYTE = 4
FRAMEWORK_LATENCY = 0


class AlphaBetaProfiler:
    '''
    Profile the alpha and beta values for a given device list.

    Usage:
        # Note: the environment of execution is supposed to be
        # multi-process with multi-GPU in MPI style.
        >>> physical_devices = [0, 1, 4, 5]
        >>> ab_profiler = AlphaBetaProfiler(physical_devices)
        >>> ab_dict = ab_profiler.profile_ab()
        >>> print(ab_dict)
        {(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
         (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
         (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
         (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
    '''

    def __init__(self,
                 physical_devices: List[int],
                 ctype: str = 'a',
                 warmup: int = 5,
                 repeat: int = 25,
                 latency_iters: int = 5):
        '''
        Args:
            physical_devices: A list of device ids; each element is the global rank of that device.
            ctype: 'a' for all-reduce, 'b' for broadcast.
            warmup: Number of warmup iterations.
            repeat: Number of iterations to measure.
            latency_iters: Number of iterations to measure latency.
        '''
        self.physical_devices = physical_devices
        self.ctype = ctype
        self.world_size = len(physical_devices)
        self.warmup = warmup
        self.repeat = repeat
        self.latency_iters = latency_iters
        self.process_group_dict = None
        self._init_profiling()

    def _init_profiling(self):
        # Create the process group list based on the global ranks
        process_group_list = []
        for f_index in range(self.world_size - 1):
            for b_index in range(f_index + 1, self.world_size):
                process_group_list.append((self.physical_devices[f_index], self.physical_devices[b_index]))

        # Create the process group dict which maps each process group to its handler
        process_group_dict = {}
        for process_group in process_group_list:
            pg_handler = dist.new_group(process_group)
            process_group_dict[process_group] = pg_handler

        self.process_group_dict = process_group_dict

    def _profile(self, process_group, pg_handler, nbytes):
        logger = get_dist_logger()
        rank = dist.get_rank()
        src_device_num = process_group[0]
        world_size = len(process_group)

        device = torch.cuda.current_device()
        buf = torch.randn(nbytes // 4).to(device)

        torch.cuda.synchronize()
        # warmup
        for _ in range(self.warmup):
            if self.ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
            elif self.ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=pg_handler)
            torch.cuda.synchronize()

        dist.barrier(group=pg_handler)
        begin = time.perf_counter()
        for _ in range(self.repeat):
            if self.ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
            elif self.ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=pg_handler)
            torch.cuda.synchronize()
        end = time.perf_counter()
        dist.barrier(group=pg_handler)

        if rank == src_device_num:
            avg_time_s = (end - begin) / self.repeat - FRAMEWORK_LATENCY
            alg_band = nbytes / avg_time_s
            if self.ctype == "a":
                # convert the bandwidth of the all-reduce algorithm to the bandwidth of the hardware.
                bus_band = 2 * (world_size - 1) / world_size * alg_band
                bus_band = alg_band
            elif self.ctype == "b":
                bus_band = alg_band

            logger.info(
                f"GPU:{rank}, Bytes: {nbytes} B, Time: {round(avg_time_s * 1e6, 2)} us, Bus bandwidth: {round(bus_band / GB, 2)} GB/s"
            )
            return (avg_time_s, alg_band)
        else:
            # Just a placeholder
            return (None, None)

    def profile_latency(self, process_group, pg_handler):
        '''
        Profile the latency of the given process group with a series of message sizes.

        Args:
            process_group: A tuple of the global ranks of the process group.
            pg_handler: The handler of the process group.

        Returns:
            latency: None if the latency is not measured, otherwise the median of the latency_list.
        '''
        latency_list = []
        for i in range(self.latency_iters):
            nbytes = int(BYTE << i)
            (t, _) = self._profile(process_group, pg_handler, nbytes)
            latency_list.append(t)

        if latency_list[0] is None:
            latency = None
        else:
            median_index = math.floor(self.latency_iters / 2)
            latency = latency_list[median_index]

        return latency

    def profile_bandwidth(self, process_group, pg_handler, maxbytes):
        '''
        Profile the bandwidth of the given process group.

        Args:
            process_group: A tuple of the global ranks of the process group.
            pg_handler: The handler of the process group.
            maxbytes: The message size (in bytes) used for the bandwidth measurement.
        '''
        (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
        return bandwidth

    def profile_ab(self):
        '''
        Profile the alpha and beta values for the given device list.

        Returns:
            alpha_beta_dict: A dict which maps each process group to its alpha and beta values.
        '''
        alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
        rank = dist.get_rank()

        def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
            assert rank in process_group
            device = torch.cuda.current_device()
            rank_max_nbytes = torch.cuda.mem_get_info(device)[0]
            rank_max_nbytes = torch.tensor(rank_max_nbytes, device=device)
            dist.all_reduce(rank_max_nbytes, op=dist.ReduceOp.MIN, group=pg_handler)
            max_nbytes = min(int(1 * GB), int(GB << int(math.log2(rank_max_nbytes.item() / GB))))
            return max_nbytes

        for process_group, pg_handler in self.process_group_dict.items():
            if rank not in process_group:
                max_nbytes = None
                alpha = None
                bandwidth = None
            else:
                max_nbytes = get_max_nbytes(process_group, pg_handler)
                alpha = self.profile_latency(process_group, pg_handler)
                bandwidth = self.profile_bandwidth(process_group, pg_handler, maxbytes=max_nbytes)

            if bandwidth is None:
                beta = None
            else:
                beta = 1 / bandwidth

            broadcast_list = [alpha, beta]
            dist.broadcast_object_list(broadcast_list, src=process_group[0])
            alpha_beta_dict[process_group] = tuple(broadcast_list)

        # add the symmetric pairs to the alpha_beta_dict
        symmetry_ab_dict = {}
        for process_group, alpha_beta_pair in alpha_beta_dict.items():
            symmetry_process_group = (process_group[1], process_group[0])
            symmetry_ab_dict[symmetry_process_group] = alpha_beta_pair

        alpha_beta_dict.update(symmetry_ab_dict)

        return alpha_beta_dict
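As a side note, here is a minimal sketch (not part of the commit) of the pair enumeration that _init_profiling performs for the docstring's example device list; it is equivalent to itertools.combinations over the physical devices, and profile_ab() later adds the mirrored pairs, which is why the example dict holds 12 entries for 4 devices.

# Illustrative sketch only: reproduce the process-group enumeration above.
from itertools import combinations

physical_devices = [0, 1, 4, 5]
process_group_list = list(combinations(physical_devices, 2))
print(process_group_list)
# [(0, 1), (0, 4), (0, 5), (1, 4), (1, 5), (4, 5)]
# profile_ab() additionally stores the symmetric pairs (1, 0), (4, 0), ...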
colossalai/device/profile_alpha_beta.py (deleted)
@@ -1,120 +0,0 @@
import fcntl
import math
import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

MB = int((1 << 10) * 1e3)
GB = int((1 << 20) * 1e3)
Byte = 4
FRAMEWORK = 0
NON_SENSE = (0.1, 0.1)


def printflock(*msgs):
    """ solves multi-process interleaved print problem """
    with open(__file__, "r") as fh:
        fcntl.flock(fh, fcntl.LOCK_EX)
        try:
            print(*msgs)
        finally:
            fcntl.flock(fh, fcntl.LOCK_UN)


def profile(device1d, nbytes, ctype):
    warmup = 5
    repeat = 25
    rank = dist.get_rank()
    src_device_num = device1d[0]
    wsize = len(device1d)
    group = dist.new_group(device1d)

    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
    buf = torch.randn(nbytes // 4).to(device)

    torch.cuda.synchronize()
    # warmup
    for _ in range(warmup):
        if ctype == "a":
            dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
        elif ctype == "b":
            dist.broadcast(buf, src=src_device_num, group=group)
        torch.cuda.synchronize()

    dist.barrier()
    begin = time.perf_counter()
    for _ in range(repeat):
        if ctype == "a":
            dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
        elif ctype == "b":
            dist.broadcast(buf, src=src_device_num, group=group)
        torch.cuda.synchronize()
    end = time.perf_counter()
    dist.barrier()

    if rank == src_device_num:
        avg_time_s = (end - begin) / repeat - FRAMEWORK
        alg_band = nbytes / avg_time_s
        if ctype == "b":
            bus_band = alg_band
        elif ctype == "a":
            bus_band = 2 * (wsize - 1) / wsize * alg_band
        print(
            f"GPU:{rank}, Bytes: {nbytes} B, Time: {round(avg_time_s * 1e6, 2)} us, Bus bandwidth: {round(bus_band / GB, 2)} GB/s"
        )
        return (avg_time_s, alg_band)
    else:
        return NON_SENSE    # Just a placeholder


def profile_latency(device1d, it=3, ctype="a"):
    latency = []
    for i in range(it):
        nbytes = int(Byte << i)
        (t, _) = profile(device1d, nbytes, ctype)
        latency.append(t)
    return min(latency)


def profile_bandwidth(device1d, maxbytes, ctype="a"):
    (_, bandwidth) = profile(device1d, maxbytes, ctype)
    return bandwidth


def profile_ab(rank, *args):
    wsize = int(torch.cuda.device_count())
    device1d = args[0]
    return_dict = args[1]
    ctype = args[2]
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29020'
    dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://', world_size=wsize, rank=rank)

    device = torch.device("cuda", rank)
    max_nbytes = torch.tensor(torch.cuda.mem_get_info(device)[0]).to(device)
    max_nbytes = min(int(4 * GB), int(GB << int(math.log2(max_nbytes.item() / GB))))

    if rank == device1d[0]:
        print(f"max_nbytes: {max_nbytes} B")

    alpha = profile_latency(device1d, it=5, ctype=ctype)
    beta = 1 / profile_bandwidth(device1d, maxbytes=max_nbytes, ctype=ctype)

    if rank == device1d[0]:
        print(f"alpha(us): {round(alpha * 1e6, 2)}, beta(us/GB): {round(beta * 1e6 * GB, 2)}")
    return_dict[rank] = (alpha, beta)


def profile_alpha_beta(device1d):
    assert torch.cuda.is_available()
    assert len(device1d) > 0 and len(device1d) <= int(torch.cuda.device_count())

    manager = mp.Manager()
    return_dict = manager.dict()
    ctype = "a"
    mp.spawn(profile_ab, args=[device1d, return_dict, ctype], nprocs=int(torch.cuda.device_count()))
    return return_dict[device1d[0]]
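For comparison, a short usage sketch of the deleted entry point versus the new one; this is assumed from the code shown above rather than taken from any additional documentation.

# Old API (deleted file above): spawns one process per visible GPU itself and
# returns a single (alpha, beta) pair for the whole device list.
#     alpha, beta = profile_alpha_beta([0, 1, 2, 3])
#
# New API (colossalai/device/alpha_beta_profiler.py above): expects an already
# initialized torch.distributed / colossalai context and returns one
# (alpha, beta) pair per device pair.
#     profiler = AlphaBetaProfiler([0, 1, 2, 3])
#     ab_dict = profiler.profile_ab()    # {(0, 1): (alpha, beta), ...}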
@@ -1,13 +1,32 @@
-import pytest
+from functools import partial
 
-from colossalai.device import profile_alpha_beta
+import pytest
+import torch.multiprocessing as mp
+
+from colossalai.device import AlphaBetaProfiler
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+
+def check_alpha_beta(rank, physical_devices, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    profiler = AlphaBetaProfiler(physical_devices)
+    ab_dict = profiler.profile_ab()
+    for _, (alpha, beta) in ab_dict.items():
+        assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10
 
 
 @pytest.mark.skip(reason="Skip because assertion fails for CI devices")
-def test_profile_alpha_beta():
-    physical_devices = [0, 1, 2, 3]
-    (alpha, beta) = profile_alpha_beta(physical_devices)
-    assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10
+@pytest.mark.dist
+@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
+@rerun_if_address_is_in_use()
+def test_profile_alpha_beta(physical_devices):
+    world_size = 4
+    run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':