Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
Added Profiler Context to manage all profilers (#340)
@@ -1,40 +0,0 @@
|
||||
from functools import partial
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
import colossalai
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.utils.profiler import enable_communication_prof, communication_prof_show
|
||||
|
||||
BATCH_SIZE = 1024
|
||||
D_MODEL = 1024
|
||||
CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))
|
||||
|
||||
|
||||
def run_test(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
inputs = torch.randn(BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
|
||||
outputs = torch.empty(world_size, BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
|
||||
outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
|
||||
|
||||
enable_communication_prof()
|
||||
|
||||
op = dist.all_reduce(inputs, async_op=True)
|
||||
dist.all_gather(outputs_list, inputs)
|
||||
op.wait()
|
||||
dist.reduce_scatter(inputs, outputs_list)
|
||||
dist.broadcast(inputs, 0)
|
||||
dist.reduce(inputs, 0)
|
||||
|
||||
if rank == 0:
|
||||
communication_prof_show()
|
||||
|
||||
|
||||
def test_cc_prof():
|
||||
world_size = 4
|
||||
run_func = partial(run_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cc_prof()
|
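The commit title says the free functions removed above (enable_communication_prof / communication_prof_show) were superseded by a profiler context that manages all profilers. Below is a minimal sketch of how the deleted test might be rewritten against such an API. The names ProfilerContext, CommProfiler, and prof.show() are assumptions inferred from the commit title, not verified against the new source; the launch/collective code is taken from the removed file.

# Minimal sketch, assuming PR #340 exposes a context object that wraps
# individual profilers (ProfilerContext and CommProfiler are hypothetical names).
from functools import partial

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.utils.profiler import ProfilerContext, CommProfiler  # assumed import path

BATCH_SIZE = 1024
D_MODEL = 1024
CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))


def run_test(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    inputs = torch.randn(BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
    outputs = torch.empty(world_size, BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
    outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))

    # One context manages all registered profilers: entering it enables
    # profiling, leaving it disables profiling, and reporting lives on the
    # context instead of per-profiler free functions.
    prof = ProfilerContext([CommProfiler()])
    with prof:
        op = dist.all_reduce(inputs, async_op=True)
        dist.all_gather(outputs_list, inputs)
        op.wait()
        dist.reduce_scatter(inputs, outputs_list)
        dist.broadcast(inputs, 0)
        dist.reduce(inputs, 0)

    if rank == 0:
        prof.show()  # assumed reporting method, analogous to communication_prof_show()


def test_cc_prof():
    world_size = 4
    run_func = partial(run_test, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_cc_prof()

Grouping profilers under one context keeps enable/disable state and reporting in a single place, which is presumably why this test built on the old free-function API was removed.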