diff --git a/colossalai/communication/__init__.py b/colossalai/communication/__init__.py
index 25e817f1f..b71744d9f 100644
--- a/colossalai/communication/__init__.py
+++ b/colossalai/communication/__init__.py
@@ -1,17 +1,26 @@
 from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce
-from .p2p import (send_forward, send_forward_recv_forward,
-                  send_backward_recv_forward, send_backward,
-                  send_backward_recv_backward, send_forward_recv_backward,
-                  send_forward_backward_recv_forward_backward, recv_forward,
-                  recv_backward)
+from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward,
+                  send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
+                  recv_forward, recv_backward)
 from .ring import ring_forward
 from .utils import send_tensor_meta, recv_tensor_meta
 
 __all__ = [
-    'all_gather', 'reduce_scatter', 'all_reduce', 'broadcast', 'reduce',
-    'send_forward', 'send_forward_recv_forward',
-    'send_forward_backward_recv_forward_backward', 'send_backward',
-    'send_backward_recv_backward', 'send_backward_recv_forward',
-    'send_forward_recv_backward', 'recv_backward', 'recv_forward',
-    'ring_forward', 'send_tensor_meta', 'recv_tensor_meta',
+    'all_gather',
+    'reduce_scatter',
+    'all_reduce',
+    'broadcast',
+    'reduce',
+    'send_forward',
+    'send_forward_recv_forward',
+    'send_forward_backward_recv_forward_backward',
+    'send_backward',
+    'send_backward_recv_backward',
+    'send_backward_recv_forward',
+    'send_forward_recv_backward',
+    'recv_backward',
+    'recv_forward',
+    'ring_forward',
+    'send_tensor_meta',
+    'recv_tensor_meta',
 ]
diff --git a/colossalai/trainer/hooks/_lr_scheduler_hook.py b/colossalai/trainer/hooks/_lr_scheduler_hook.py
index 726db3bf5..848d6e331 100644
--- a/colossalai/trainer/hooks/_lr_scheduler_hook.py
+++ b/colossalai/trainer/hooks/_lr_scheduler_hook.py
@@ -29,6 +29,7 @@ class LRSchedulerHook(MetricHook):
         self.store_lr_in_state = store_lr_in_state
 
     def after_hook_is_attached(self, trainer):
+        self._check_metric_states_initialization(trainer)
         trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch,
                                                                       initial_lr=self.lr_scheduler.get_last_lr()[0])
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index 645b58131..848937d7c 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -1,12 +1,9 @@
 from .activation_checkpoint import checkpoint
-from .common import (clip_grad_norm_fp32, conditional_context,
-                     copy_tensor_parallel_attributes, count_zeros_fp32,
-                     free_port, is_dp_rank_0, is_model_parallel_parameter,
-                     is_moe_parallel_parameter, is_no_pp_or_last_stage,
-                     is_tp_rank_0, is_using_ddp, is_using_pp,
-                     is_using_sequence, multi_tensor_applier,
-                     param_is_not_tensor_parallel_duplicate, print_rank_0,
+from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
+                     free_port, is_dp_rank_0, is_model_parallel_parameter, is_moe_parallel_parameter,
+                     is_no_pp_or_last_stage, is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence,
+                     multi_tensor_applier, param_is_not_tensor_parallel_duplicate, print_rank_0,
                      switch_virtual_pipeline_parallel_rank, sync_model_param)
 from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .data_sampler import DataParallelSampler, get_dataloader
diff --git a/colossalai/utils/profiler/__init__.py b/colossalai/utils/profiler/__init__.py
new file mode 100644
index 000000000..9b216c437
--- /dev/null
+++ b/colossalai/utils/profiler/__init__.py
@@ -0,0 +1 @@
+from .comm_profiler import enable_communication_prof, communication_prof_show
diff --git a/colossalai/utils/profiler/comm_profiler.py b/colossalai/utils/profiler/comm_profiler.py
new file mode 100644
index 000000000..8413597c3
--- /dev/null
+++ b/colossalai/utils/profiler/comm_profiler.py
@@ -0,0 +1,302 @@
+import inspect
+import torch
+from torch.autograd.profiler import profile
+import torch.distributed as dist
+from torch.distributed import ReduceOp
+from colossalai.utils import get_current_device
+from typing import List, Optional
+
+
+def _get_code_location(depth: int):
+    ret = ""
+    length = len(inspect.stack())
+    for i in range(3, min(length, depth + 1)):
+        upper_frame = inspect.stack()[i]
+        function_name = inspect.stack()[i - 1].function
+        info = upper_frame.filename + "(" + str(upper_frame.lineno) + "): " + function_name + "\n"
+        ret += info
+
+    return ret
+
+
+# copied from a newer version of PyTorch to support older versions
+def _format_time(time_us):
+    """Defines how to format time in FunctionEvent"""
+    US_IN_SECOND = 1000.0 * 1000.0
+    US_IN_MS = 1000.0
+    if time_us >= US_IN_SECOND:
+        return '{:.3f}s'.format(time_us / US_IN_SECOND)
+    if time_us >= US_IN_MS:
+        return '{:.3f}ms'.format(time_us / US_IN_MS)
+    return '{:.3f}us'.format(time_us)
+
+
+# copied from a newer version of PyTorch to support older versions
+def _format_memory(nbytes):
+    """Returns a formatted memory size string"""
+    KB = 1024
+    MB = 1024 * KB
+    GB = 1024 * MB
+    if (abs(nbytes) >= GB):
+        return '{:.2f} GB'.format(nbytes * 1.0 / GB)
+    elif (abs(nbytes) >= MB):
+        return '{:.2f} MB'.format(nbytes * 1.0 / MB)
+    elif (abs(nbytes) >= KB):
+        return '{:.2f} KB'.format(nbytes * 1.0 / KB)
+    else:
+        return str(nbytes) + ' B'
+
+
+def _format_bandwidth(volume: float, time_us: int):
+    sec_div_mb = (1000.0 / 1024.0)**2
+    mb_per_sec = volume / time_us * sec_div_mb
+
+    if mb_per_sec >= 1024.0:
+        return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
+    else:
+        return '{:.3f} MB/s'.format(mb_per_sec)
+
+
+class CommEvent(object):
+    """Communication Event. Used for communication time and communication
+    volume recording.
+    """
+
+    def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0):
+        self.self_count = count
+        self.self_comm_vol = comm_vol
+        self.self_cuda_time = cuda_time
+
+    def add(self, rhs):
+        self.self_count += rhs.self_count
+        self.self_comm_vol += rhs.self_comm_vol
+        self.self_cuda_time += rhs.self_cuda_time
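The helpers above use microseconds for time and bytes (binary multiples) for sizes. A quick illustrative check of those conventions, assuming the private helpers are imported from the new colossalai.utils.profiler.comm_profiler module (the patch itself never imports them this way):

```python
# Illustrative only: unit conventions of the formatting helpers.
from colossalai.utils.profiler.comm_profiler import _format_time, _format_memory, _format_bandwidth

print(_format_time(2_500_000))               # 2,500,000 us      -> '2.500s'
print(_format_memory(6 * 1024**2))           # 6 MiB              -> '6.00 MB'
print(_format_bandwidth(6 * 1024**2, 1000))  # 6 MiB in 1 ms      -> '5.859 GB/s'
```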
+
+
+class CommProfiler(object):
+    """Communication profiler. Records all communication events.
+    """
+
+    def __init__(self, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0, prof_depth: int = 3):
+        super().__init__()
+        self.total_count = total_count
+        self.total_comm_vol = total_comm_vol
+        self.total_cuda_time = total_cuda_time
+        self.depth = prof_depth
+
+        self.ops_record = dict()
+        self.profiler = None
+        self.pending_op = None
+        self.pending_metadata = None
+        self.warn_flag = False
+
+    def reset(self):
+        self.total_count = 0
+        self.total_comm_vol = 0
+        self.total_cuda_time = 0
+
+        self.ops_record = dict()
+        self.profiler = None
+        self.pending_op = None
+        self.pending_metadata = None
+        self.warn_flag = False
+
+    def show(self):
+        if self.warn_flag:
+            print("Warning: multiple communication operations were issued at the same time.\n"
+                  "As a result, the profiling result is not accurate.")
+        print("Collective communication profiling result:",
+              "total cuda time: {}".format(_format_time(self.total_cuda_time)),
+              "average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time)),
+              "total number of calls: {}".format(self.total_count),
+              "All events:",
+              sep='\n')
+
+        show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
+        for location, event in show_list:
+            print(location,
+                  "self cuda time: {}".format(_format_time(event.self_cuda_time)),
+                  "{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0),
+                  "self communication volume: {}".format(_format_memory(event.self_comm_vol)),
+                  "average bandwidth: {}".format(_format_bandwidth(event.self_comm_vol, event.self_cuda_time)),
+                  "number of calls: {}".format(event.self_count),
+                  "--------------------",
+                  sep='\n')
+
+    @property
+    def has_async_op(self):
+        return self.pending_op is not None
+
+    def activate_profiler(self, kn: str, vol: float):
+        self.pending_metadata = (kn, _get_code_location(self.depth), vol)
+        self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True)
+        self.profiler.__enter__()
+
+    def close_profiler(self, group=None):
+        assert self.profiler is not None, "There is no running dist op"
+        kernel_name, code_location, vol = self.pending_metadata
+        self.profiler.__exit__(None, None, None)
+
+        if self.profiler.enabled:
+            assert_flag = 0
+            current_comm_event = None
+            events = self.profiler.function_events
+            for event in events:
+                if kernel_name in event.name:
+                    assert assert_flag == 0, "Multiple dist ops have been called"
+                    current_comm_event = CommEvent(1, vol, event.self_cuda_time_total)
+                    assert_flag += 1
+
+            assert current_comm_event is not None, "dist op has not been found"
+
+            buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
+            torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
+            current_comm_event.self_cuda_time = buffer.item()
+
+            self.total_count += current_comm_event.self_count
+            self.total_comm_vol += current_comm_event.self_comm_vol
+            self.total_cuda_time += current_comm_event.self_cuda_time
+            if code_location in self.ops_record:
+                self.ops_record[code_location].add(current_comm_event)
+            else:
+                self.ops_record[code_location] = current_comm_event
+
+        self.profiler = None
+        self.pending_op = None
+        self.pending_metadata = None
+
+    def wait_async_op(self):
+        if self.pending_op is not None:
+            op = self.pending_op
+            op.wait()
+            self.close_profiler()
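CommProfiler brackets each collective with a torch.autograd profile run: activate_profiler records the expected NCCL kernel-name prefix, the call site, and the estimated volume, and close_profiler picks the matching CUDA kernel event and takes the minimum self CUDA time across ranks (ranks that arrive late at a collective also measure waiting time, so the minimum is presumably closest to the pure transfer time). A minimal sketch of that bracketing pattern, assuming torch.distributed is already initialized and the original dist.broadcast is still in place, i.e. before enable_communication_prof below has been called:

```python
# Sketch only: the collective wrappers defined below follow exactly this pattern.
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.utils.profiler.comm_profiler import CommProfiler

prof = CommProfiler()
tensor = torch.ones(1024, 1024, device=get_current_device())
comm_vol = 1.0 * tensor.element_size() * tensor.numel()    # broadcast moves the whole tensor
prof.activate_profiler("ncclKernel_Broadcast_", comm_vol)  # enter a torch.autograd profile
dist.broadcast(tensor, src=0)                              # the real collective being measured
prof.close_profiler()                                      # match the NCCL kernel event, accumulate stats
prof.show()
```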
+
+
+class CommHandler(object):
+    """Communication handler. A dummy handler used to wait for asynchronous operations.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.prof = COL_COMM_PROF
+
+    def wait(self):
+        self.prof.wait_async_op()
+
+
+COL_COMM_PROF = CommProfiler()
+torch_all_reduce = dist.all_reduce
+torch_all_gather = dist.all_gather
+torch_reduce_scatter = dist.reduce_scatter
+torch_broadcast = dist.broadcast
+torch_reduce = dist.reduce
+
+
+def enable_communication_prof(depth: int = 0):
+    COL_COMM_PROF.depth = 3 + depth
+    dist.all_reduce = all_reduce
+    dist.all_gather = all_gather
+    dist.reduce_scatter = reduce_scatter
+    dist.broadcast = broadcast
+    dist.reduce = reduce
+
+
+def communication_prof_show():
+    COL_COMM_PROF.show()
+
+
+def async_check():
+    if COL_COMM_PROF.pending_op is not None:
+        COL_COMM_PROF.warn_flag = True
+        COL_COMM_PROF.wait_async_op()
+
+
+def all_reduce(tensor: torch.Tensor,
+               op: ReduceOp = ReduceOp.SUM,
+               group=None,
+               async_op: bool = False) -> Optional[CommHandler]:
+    async_check()
+
+    comm_size = dist.get_world_size(group)
+    correction = 2 * (comm_size - 1) / comm_size
+    comm_vol = correction * tensor.element_size() * tensor.numel()
+    COL_COMM_PROF.activate_profiler("ncclKernel_AllReduce_", comm_vol)
+    COL_COMM_PROF.pending_op = torch_all_reduce(tensor, op, group, async_op)
+
+    if async_op:
+        return CommHandler()
+
+    COL_COMM_PROF.close_profiler(group)
+
+
+def reduce_scatter(output: torch.Tensor,
+                   input_list: List[torch.Tensor],
+                   op: ReduceOp = ReduceOp.SUM,
+                   group=None,
+                   async_op: bool = False) -> Optional[CommHandler]:
+    async_check()
+
+    comm_size = dist.get_world_size(group)
+    correction = (comm_size - 1) / comm_size
+    comm_vol = 0
+    for tensor in input_list:
+        comm_vol += tensor.element_size() * tensor.numel()
+    comm_vol *= correction
+    COL_COMM_PROF.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
+    COL_COMM_PROF.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
+
+    if async_op:
+        return CommHandler()
+
+    COL_COMM_PROF.close_profiler(group)
+
+
+def all_gather(tensor_list: List[torch.Tensor],
+               tensor: torch.Tensor,
+               group=None,
+               async_op: bool = False) -> Optional[CommHandler]:
+    async_check()
+
+    comm_size = dist.get_world_size(group)
+    correction = (comm_size - 1) / comm_size
+    comm_vol = 0
+    for ten in tensor_list:
+        comm_vol += ten.element_size() * ten.numel()
+    comm_vol *= correction
+    COL_COMM_PROF.activate_profiler("ncclKernel_AllGather_", comm_vol)
+    COL_COMM_PROF.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
+
+    if async_op:
+        return CommHandler()
+
+    COL_COMM_PROF.close_profiler(group)
+
+
+def broadcast(tensor: torch.Tensor, src: int, group=None, async_op: bool = False) -> Optional[CommHandler]:
+    async_check()
+
+    comm_vol = 1.0 * tensor.element_size() * tensor.numel()
+    COL_COMM_PROF.activate_profiler("ncclKernel_Broadcast_", comm_vol)
+    COL_COMM_PROF.pending_op = torch_broadcast(tensor, src, group, async_op)
+
+    if async_op:
+        return CommHandler()
+
+    COL_COMM_PROF.close_profiler(group)
+
+
+def reduce(tensor: torch.Tensor,
+           dst: int,
+           op: ReduceOp = ReduceOp.SUM,
+           group=None,
+           async_op: bool = False) -> Optional[CommHandler]:
+    async_check()
+
+    comm_vol = 1.0 * tensor.element_size() * tensor.numel()
+    COL_COMM_PROF.activate_profiler("ncclKernel_Reduce_", comm_vol)
+    COL_COMM_PROF.pending_op = torch_reduce(tensor, dst, op, group, async_op)
+
+    if async_op:
+        return CommHandler()
+
+    COL_COMM_PROF.close_profiler(group)
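The comm_vol estimates above follow the usual ring-algorithm cost model: all_reduce is charged 2(p-1)/p of the tensor size per rank, all_gather and reduce_scatter (p-1)/p of the total buffer, and broadcast and reduce the full tensor. A rough worked example with the same tensor shape the test below uses (the numbers are illustrative arithmetic, not measurements):

```python
# Illustrative volume arithmetic for a 1024 x 1024 fp32 tensor on 4 ranks.
world_size = 4
bytes_per_tensor = 4 * 1024 * 1024                                      # element_size * numel = 4 MiB

all_reduce_vol = 2 * (world_size - 1) / world_size * bytes_per_tensor   # 6 MiB per rank
broadcast_vol = 1.0 * bytes_per_tensor                                  # 4 MiB, counted as-is
```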
diff --git a/tests/test_profiler/test_comm_prof.py b/tests/test_profiler/test_comm_prof.py
new file mode 100644
index 000000000..133ea0262
--- /dev/null
+++ b/tests/test_profiler/test_comm_prof.py
@@ -0,0 +1,40 @@
+from functools import partial
+import torch
+import torch.multiprocessing as mp
+import torch.distributed as dist
+import colossalai
+from colossalai.utils import free_port, get_current_device
+from colossalai.utils.profiler import enable_communication_prof, communication_prof_show
+
+BATCH_SIZE = 1024
+D_MODEL = 1024
+CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))
+
+
+def run_test(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    inputs = torch.randn(BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
+    outputs = torch.empty(world_size, BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
+    outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
+
+    enable_communication_prof()
+
+    op = dist.all_reduce(inputs, async_op=True)
+    dist.all_gather(outputs_list, inputs)
+    op.wait()
+    dist.reduce_scatter(inputs, outputs_list)
+    dist.broadcast(inputs, 0)
+    dist.reduce(inputs, 0)
+
+    if rank == 0:
+        communication_prof_show()
+
+
+def test_cc_prof():
+    world_size = 4
+    run_func = partial(run_test, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_cc_prof()
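Outside the unit test, the workflow is the same: enable the profiler once after launch, run a few iterations, and print the table from a single rank. A hypothetical integration sketch (dataloader and train_step are placeholders, not part of this patch):

```python
# Hypothetical usage in a training script (sketch only).
import torch.distributed as dist
from colossalai.utils.profiler import enable_communication_prof, communication_prof_show

enable_communication_prof()                  # patch the torch.distributed collectives in place

for step, batch in enumerate(dataloader):    # placeholder data loader
    train_step(batch)                        # placeholder training step issuing collectives
    if step == 10:
        break

if dist.get_rank() == 0:
    communication_prof_show()                # per call-site cuda time, volume and bandwidth
```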