diff --git a/colossalai/utils/checkpointing.py b/colossalai/utils/checkpointing.py
index be13b18f4..e822a8bfd 100644
--- a/colossalai/utils/checkpointing.py
+++ b/colossalai/utils/checkpointing.py
@@ -175,7 +175,7 @@ def load_checkpoint(checkpoint_path: str,
         If strict is True, then the keys of state_dict must exactly match the keys
         returned by this module’s state_dict() function.
 
-    Args: 
+    Args:
         checkpoint_path (str): The exact and matched checkpoint_path directory to retrieve appropriate state_dict.
         model (:class:`torch.nn.Module`): Model to reload parameters and buffers.
         optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate.
diff --git a/colossalai/utils/memory_tracer/async_memtracer.py b/colossalai/utils/memory_tracer/async_memtracer.py
index 2c3ca0005..7ceaf9d80 100644
--- a/colossalai/utils/memory_tracer/async_memtracer.py
+++ b/colossalai/utils/memory_tracer/async_memtracer.py
@@ -11,32 +11,31 @@ from colossalai.utils import get_current_device
 class AsyncMemoryMonitor:
     """
     An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
-    at interval of 1/(10**power) sec.
+    at an interval of `1/(10**power)` sec.
 
     The idea comes from Runtime Memory Tracer of PatrickStar
-    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
-    https://arxiv.org/abs/2108.05818
-
-    :param power: the power of time interval, defaults to 10
-    :type power: int
+    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
 
-    Usage:
-    ::
+    Usage::
 
-        ```python
-        async_mem_monitor = AsyncMemoryMonitor()
-        input = torch.randn(2, 20).cuda()
-        OP1 = torch.nn.Linear(20, 30).cuda()
-        OP2 = torch.nn.Linear(30, 40).cuda()
+        async_mem_monitor = AsyncMemoryMonitor()
+        input = torch.randn(2, 20).cuda()
+        OP1 = torch.nn.Linear(20, 30).cuda()
+        OP2 = torch.nn.Linear(30, 40).cuda()
 
-        async_mem_monitor.start()
-        output = OP1(input)
-        async_mem_monitor.finish()
-        async_mem_monitor.start()
-        output = OP2(output)
-        async_mem_monitor.finish()
-        async_mem_monitor.save('log.pkl')
-        ```
+        async_mem_monitor.start()
+        output = OP1(input)
+        async_mem_monitor.finish()
+        async_mem_monitor.start()
+        output = OP2(output)
+        async_mem_monitor.finish()
+        async_mem_monitor.save('log.pkl')
+
+    Args:
+        power (int, optional): the power of the time interval. Defaults to 10.
+
+    .. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
+        https://arxiv.org/abs/2108.05818
     """
 
     def __init__(self, power: int = 10):
diff --git a/colossalai/utils/profiler/mem_profiler.py b/colossalai/utils/profiler/mem_profiler.py
index b60e714c4..662417dfd 100644
--- a/colossalai/utils/profiler/mem_profiler.py
+++ b/colossalai/utils/profiler/mem_profiler.py
@@ -8,10 +8,12 @@ from colossalai.utils.profiler import BaseProfiler
 
 class MemProfiler(BaseProfiler):
     """Wraper of MemOpHook, used to show GPU memory usage through each iteration
-
+    To use this profiler, you need to pass an `engine` instance. The usage is the same as for CommProfiler.
+
+    Usage::
+
         mm_prof = MemProfiler(engine)
         with ProfilerContext([mm_prof]) as prof:
             writer = SummaryWriter("mem")
@@ -36,15 +38,11 @@
     def to_tensorboard(self, writer: SummaryWriter) -> None:
         stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
         for info, i in enumerate(stats):
-            writer.add_scalar(
-                "memory_usage/GPU",
-                info,
-                i
-            )
+            writer.add_scalar("memory_usage/GPU", info, i)
 
     def to_file(self, data_file: Path) -> None:
         self._mem_tracer.save_results(data_file)
 
     def show(self) -> None:
-        stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats'] 
+        stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
         print(stats)
diff --git a/colossalai/utils/profiler/prof_utils.py b/colossalai/utils/profiler/prof_utils.py
index a7e35bc42..87ad644a7 100644
--- a/colossalai/utils/profiler/prof_utils.py
+++ b/colossalai/utils/profiler/prof_utils.py
@@ -70,29 +70,26 @@ class BaseProfiler(ABC):
 
 
 class ProfilerContext(object):
-    """
-    Profiler context manager
-    Usage:
-    ::
+    """Profiler context manager
 
-        ```python
-        world_size = 4
-        inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
-        outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
-        outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
+    Usage::
 
-        cc_prof = CommProfiler()
+        world_size = 4
+        inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
+        outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
+        outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
 
-        with ProfilerContext([cc_prof]) as prof:
-            op = dist.all_reduce(inputs, async_op=True)
-            dist.all_gather(outputs_list, inputs)
-            op.wait()
-            dist.reduce_scatter(inputs, outputs_list)
-            dist.broadcast(inputs, 0)
-            dist.reduce(inputs, 0)
+        cc_prof = CommProfiler()
 
-        prof.show()
-        ```
+        with ProfilerContext([cc_prof]) as prof:
+            op = dist.all_reduce(inputs, async_op=True)
+            dist.all_gather(outputs_list, inputs)
+            op.wait()
+            dist.reduce_scatter(inputs, outputs_list)
+            dist.broadcast(inputs, 0)
+            dist.reduce(inputs, 0)
+
+        prof.show()
     """
 
     def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):
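The AsyncMemoryMonitor docstring above describes sampling GPU memory at `1/(10**power)`-second intervals in a background worker while an operation runs. Below is a minimal sketch of that sampling idea, assuming PyTorch with CUDA available; it is not the ColossalAI implementation, and `SimpleAsyncMemoryMonitor` and its attributes are illustrative names only.

# Minimal sketch (illustrative, not ColossalAI's code) of the sampling idea behind
# AsyncMemoryMonitor: a background worker polls torch.cuda.max_memory_allocated() every
# 1/(10**power) seconds between start() and finish(), recording the peak it observed.
from concurrent.futures import ThreadPoolExecutor
from time import sleep

import torch
from torch.utils.tensorboard import SummaryWriter


class SimpleAsyncMemoryMonitor:
    def __init__(self, power: int = 10):
        self.interval = 1 / (10 ** power)      # sampling period in seconds
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.keep_measuring = False
        self.mem_stats = []                    # one peak (bytes) per start()/finish() window

    def start(self) -> None:
        self.keep_measuring = True
        torch.cuda.reset_peak_memory_stats()
        self._future = self.executor.submit(self._measure)

    def finish(self) -> int:
        self.keep_measuring = False
        peak = self._future.result()           # wait for the sampler thread to exit
        self.mem_stats.append(peak)
        return peak

    def _measure(self) -> int:
        max_usage = 0
        while self.keep_measuring:
            max_usage = max(max_usage, torch.cuda.max_memory_allocated())
            sleep(self.interval)
        return max_usage


if __name__ == "__main__":
    monitor = SimpleAsyncMemoryMonitor(power=3)    # sample every 1e-3 s
    x = torch.randn(2, 20).cuda()
    op = torch.nn.Linear(20, 30).cuda()

    monitor.start()
    y = op(x)
    peak_bytes = monitor.finish()

    # enumerate() yields (step, value); SummaryWriter.add_scalar expects (tag, value, step).
    writer = SummaryWriter("mem")
    for step, peak in enumerate(monitor.mem_stats):
        writer.add_scalar("memory_usage/GPU", peak, step)
    writer.close()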