polish utils docstring (#620)

ver217 2022-04-01 16:36:47 +08:00 committed by GitHub
parent e619a651fb
commit 369a288bf3
4 changed files with 42 additions and 48 deletions


@ -175,7 +175,7 @@ def load_checkpoint(checkpoint_path: str,
    If strict is True, then the keys of state_dict must exactly match the keys returned
    by this module's state_dict() function.

    Args:
        checkpoint_path (str): The exact and matched checkpoint_path directory to retrieve appropriate state_dict.
        model (:class:`torch.nn.Module`): Model to reload parameters and buffers.
        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to restore from the checkpoint.
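For readers skimming the hunk, a minimal sketch of a call matching this signature; the import path and checkpoint filename are assumptions, and any return value is ignored:

```python
# Hypothetical usage based on the Args above; only the documented
# parameters are exercised here.
import torch
from colossalai.utils import load_checkpoint  # assumed import path

model = torch.nn.Linear(20, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Reload parameters/buffers into `model` and recover `optimizer` state.
load_checkpoint('./checkpoints/epoch_10.pt', model, optimizer)
```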


@ -11,32 +11,31 @@ from colossalai.utils import get_current_device
class AsyncMemoryMonitor:
    """
    An async memory monitor running during computation, sampling the memory usage
    of the current GPU at an interval of `1/(10**power)` sec.

    The idea comes from the Runtime Memory Tracer of PatrickStar:
    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_

    Usage::

        async_mem_monitor = AsyncMemoryMonitor()
        input = torch.randn(2, 20).cuda()
        OP1 = torch.nn.Linear(20, 30).cuda()
        OP2 = torch.nn.Linear(30, 40).cuda()

        async_mem_monitor.start()
        output = OP1(input)
        async_mem_monitor.finish()
        async_mem_monitor.start()
        output = OP2(output)
        async_mem_monitor.finish()
        async_mem_monitor.save('log.pkl')

    Args:
        power (int, optional): the power of the time interval. Defaults to 10.

    .. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
        https://arxiv.org/abs/2108.05818
    """

    def __init__(self, power: int = 10):
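To make the `power` argument concrete: with the default `power=10` the monitor samples every `1e-10` sec, i.e. as fast as the sampling thread can spin. A minimal sketch of the loop such a monitor might run is below; the thread-pool design and the use of `torch.cuda.max_memory_allocated` are assumptions for illustration, with only `power` and the `start`/`finish` method names taken from the docstring above.

```python
import time
from concurrent.futures import ThreadPoolExecutor

import torch


class NaiveAsyncMemoryMonitor:
    """Illustrative stand-in, not the ColossalAI implementation."""

    def __init__(self, power: int = 10):
        self.interval = 1 / (10**power)  # sampling period in seconds
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.keep_measuring = False
        self.future = None

    def start(self) -> None:
        # Launch the background sampler; it runs until finish() is called.
        self.keep_measuring = True
        self.future = self.executor.submit(self._measure)

    def finish(self) -> int:
        # Stop sampling and return the peak GPU memory seen in the window.
        self.keep_measuring = False
        return self.future.result()

    def _measure(self) -> int:
        max_usage = 0
        while self.keep_measuring:
            max_usage = max(max_usage, torch.cuda.max_memory_allocated())
            time.sleep(self.interval)
        return max_usage
```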


@ -8,10 +8,12 @@ from colossalai.utils.profiler import BaseProfiler
class MemProfiler(BaseProfiler):
    """Wrapper of MemOpHook, used to show GPU memory usage through each iteration

    To use this profiler, you need to pass an `engine` instance, and the usage is
    the same as CommProfiler.

    Usage::

        mm_prof = MemProfiler(engine)
        with ProfilerContext([mm_prof]) as prof:
            writer = SummaryWriter("mem")
@ -36,15 +38,11 @@ class MemProfiler(BaseProfiler):
    def to_tensorboard(self, writer: SummaryWriter) -> None:
        stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
        for i, info in enumerate(stats):
            writer.add_scalar("memory_usage/GPU", info, i)

    def to_file(self, data_file: Path) -> None:
        self._mem_tracer.save_results(data_file)

    def show(self) -> None:
        stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
        print(stats)
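Taken together, the three methods give three sinks for the same `mem_stats` list. A hedged end-to-end sketch follows; `engine` and `train_one_epoch` are hypothetical placeholders, since their setup is outside this diff:

```python
from pathlib import Path

from torch.utils.tensorboard import SummaryWriter

# `engine` is a hypothetical placeholder for a colossalai engine instance;
# its construction is not shown in this commit.
mm_prof = MemProfiler(engine)

with ProfilerContext([mm_prof]) as prof:
    train_one_epoch(engine)  # hypothetical training loop being profiled

mm_prof.to_tensorboard(SummaryWriter("mem"))  # per-iteration scalar curve
mm_prof.to_file(Path("mem_stats.pkl"))        # persisted via save_results
mm_prof.show()                                # print the raw stats list
```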


@ -70,29 +70,26 @@ class BaseProfiler(ABC):
class ProfilerContext(object):
    """Profiler context manager

    Usage::

        world_size = 4
        inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
        outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
        outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))

        cc_prof = CommProfiler()

        with ProfilerContext([cc_prof]) as prof:
            op = dist.all_reduce(inputs, async_op=True)
            dist.all_gather(outputs_list, inputs)
            op.wait()
            dist.reduce_scatter(inputs, outputs_list)
            dist.broadcast(inputs, 0)
            dist.reduce(inputs, 0)

        prof.show()
    """

    def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):
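The signature suggests the context manager simply switches every registered profiler on at entry and off at exit. A minimal sketch of that pattern, assuming each profiler exposes `enable()`/`disable()` hooks (the actual hook names in colossalai may differ):

```python
from typing import List


class NaiveProfilerContext:
    """Illustrative stand-in for ProfilerContext, not the ColossalAI source."""

    def __init__(self, profilers: List = None, enable: bool = True):
        self.profilers = profilers if profilers is not None else []
        self.enable = enable

    def __enter__(self):
        if self.enable:
            for prof in self.profilers:
                prof.enable()  # assumed hook: start collecting events
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enable:
            for prof in self.profilers:
                prof.disable()  # assumed hook: stop collecting
        return False  # never swallow exceptions from the with-body

    def show(self) -> None:
        for prof in self.profilers:
            prof.show()
```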