polish utils docstring (#620)

This commit is contained in:
ver217
2022-04-01 16:36:47 +08:00
committed by GitHub
parent e619a651fb
commit 369a288bf3
4 changed files with 42 additions and 48 deletions

View File

@@ -8,10 +8,12 @@ from colossalai.utils.profiler import BaseProfiler
class MemProfiler(BaseProfiler):
"""Wrapper of MemOpHook, used to show GPU memory usage through each iteration
To use this profiler, you need to pass an `engine` instance. The usage is the same as
CommProfiler.
Usage::
mm_prof = MemProfiler(engine)
with ProfilerContext([mm_prof]) as prof:
writer = SummaryWriter("mem")
@@ -36,15 +38,11 @@ class MemProfiler(BaseProfiler):
def to_tensorboard(self, writer: SummaryWriter) -> None:
stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
for info, i in enumerate(stats):
writer.add_scalar(
"memory_usage/GPU",
info,
i
)
writer.add_scalar("memory_usage/GPU", info, i)
def to_file(self, data_file: Path) -> None:
self._mem_tracer.save_results(data_file)
def show(self) -> None:
stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
print(stats)

View File

@@ -70,29 +70,26 @@ class BaseProfiler(ABC):
class ProfilerContext(object):
"""
Profiler context manager
Usage:
::
"""Profiler context manager
```python
world_size = 4
inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
Usage::
cc_prof = CommProfiler()
world_size = 4
inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
with ProfilerContext([cc_prof]) as prof:
op = dist.all_reduce(inputs, async_op=True)
dist.all_gather(outputs_list, inputs)
op.wait()
dist.reduce_scatter(inputs, outputs_list)
dist.broadcast(inputs, 0)
dist.reduce(inputs, 0)
cc_prof = CommProfiler()
prof.show()
```
with ProfilerContext([cc_prof]) as prof:
op = dist.all_reduce(inputs, async_op=True)
dist.all_gather(outputs_list, inputs)
op.wait()
dist.reduce_scatter(inputs, outputs_list)
dist.broadcast(inputs, 0)
dist.reduce(inputs, 0)
prof.show()
"""
def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):