fix typo with colossalai/trainer utils zero (#3908)

digger yu 2023-06-07 16:08:37 +08:00 committed by GitHub
parent b306cecf28
commit a9d1cadc49
15 changed files with 28 additions and 28 deletions

View File

@@ -31,9 +31,9 @@ class Trainer:
         >>> # Initialize your engine, train_dataloader, test_dataloader, lr_scheduler
         >>> engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion)
         >>> # Beginning training progress
-        >>> timier = ...
+        >>> timer = ...
         >>> logger = ...
-        >>> trainer = Trainer(engine=engine, logger=logger, timer=timier)
+        >>> trainer = Trainer(engine=engine, logger=logger, timer=timer)
         >>> # add hooks you would like to use here.
         >>> hook_list = []
         >>> trainer.fit(
@@ -56,7 +56,7 @@ class Trainer:
         timer: MultiTimer = None,
         logger: DistributedLogger = None,
     ):
-        # training-ralated params
+        # training-related params
         self._engine = engine
         self._max_epochs = 0
         self._cur_epoch = 0
@@ -118,7 +118,7 @@ class Trainer:
         self._cur_step = epoch * self._steps_per_epoch

     def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
-        """Call timer funciton with a given timer name.
+        """Call timer function with a given timer name.

         Args:
             action (str): Function to be called on timer.

View File

@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-# adpated from torch.utils.data.DistributedSampler
+# adapted from torch.utils.data.DistributedSampler

 import math
 import random

View File

@@ -70,7 +70,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
             cls.__init__ = preprocess_after(cls.__init__)

         # Replace .__init__() for all existing subclasses of torch.nn.Module
-        # Excution self._post_init_method after the default init function.
+        # Execution self._post_init_method after the default init function.
         substitute_init_recursively(torch.nn.modules.module.Module, _enable_class, set())

         # holding on to the current __init__subclass__ for exit

View File

@@ -111,7 +111,7 @@ class CommProfiler(BaseProfiler):
         res.append(sep)

         if self.warn_flag:
-            append("Warnning: there exists multiple communication operations in the same time. As a result, "
+            append("Warning: there exists multiple communication operations in the same time. As a result, "
                    "the profiling result is not accurate.")

         if self.total_cuda_time == 0:
@@ -123,12 +123,12 @@ class CommProfiler(BaseProfiler):
         append("total number of calls: {}".format(self.total_count))
         append("All events:")

-        seperation = '-' * 74
+        separation = '-' * 74
         row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2

-        append(seperation)
+        append(separation)
         append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
-        append(seperation)
+        append(separation)

         show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
         for location, event in show_list:

View File

@@ -130,12 +130,12 @@ class PcieProfiler(BaseProfiler):
         append("Possible data transmission events in PCIE:")

-        seperation = '-' * 62
+        separation = '-' * 62
         row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2

-        append(seperation)
+        append(separation)
         append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
-        append(seperation)
+        append(separation)

         show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
         for location, event in show_list:

View File

@@ -32,9 +32,9 @@ def _format_memory(nbytes):
     return str(nbytes) + ' B'


-def _format_bandwidth(volme: float or int, time_us: int):
+def _format_bandwidth(volume: float or int, time_us: int):
     sec_div_mb = (1000.0 / 1024.0)**2
-    mb_per_sec = volme / time_us * sec_div_mb
+    mb_per_sec = volume / time_us * sec_div_mb

     if mb_per_sec >= 1024.0:
         return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)

View File

@@ -1,5 +1,5 @@
 # Rank Recorder
-This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualise the json file easily.
+This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualize the json file easily.

 Before using the tool, you should ensure dist.is_initialized() return true before exit of program.
@@ -20,7 +20,7 @@ with recorder(record_name, current_rank) as r:
 ```

 ## Example
-This is a demo to display kernel select in cuda and visualise the cost of several procedures in each rank.
+This is a demo to display kernel select in cuda and visualize the cost of several procedures in each rank.

 ```python
 import time

View File

@@ -133,7 +133,7 @@ class Recorder:
         with open(self.export_name + '.json', 'w', encoding='utf-8') as f:
             json.dump(recoders, f, ensure_ascii=False)

-    def visualise_record(self):
+    def visualize_record(self):
         with open(self.export_name + '.json', 'r', encoding='utf-8') as f:
             records = json.load(f)
             records = dict(records)
@@ -171,7 +171,7 @@ class Recorder:
         if rank == 1:
             # take the base time of rank 0 as standard
             self.merge_recode()
-            self.visualise_record()
+            self.visualize_record()


 recorder = Recorder()

View File

@@ -416,7 +416,7 @@ class Chunk:
         Copy data slice to the memory space indexed by the input tensor in the chunk.

         Args:
-            tensor (torch.Tensor): the tensor used to retrive meta information
+            tensor (torch.Tensor): the tensor used to retrieve meta information
             data_slice (torch.Tensor): the tensor to be copied to the chunk
         """
         # sanity check

View File

@@ -157,7 +157,7 @@ class ChunkManager:
         Copy data to the chunk.

         Args:
-            tensor (torch.Tensor): the tensor used to retrive meta information
+            tensor (torch.Tensor): the tensor used to retrieve meta information
             data (torch.Tensor): the tensor to be copied to the chunk
         """
         chunk = self.tensor_chunk_map[tensor]

View File

@@ -25,7 +25,7 @@ class ChunkMemStatsCollector(MemStatsCollector):
     # override
     def record_model_data_volume(self) -> None:
         """
-        record model data volumn on cuda and cpu.
+        record model data volume on cuda and cpu.
         """
         if self._start_flag and not self.use_outside_memstats:
             cuda_mem = self._chunk_manager.total_mem['cuda']

View File

@@ -45,7 +45,7 @@ class MemoryMonitor:

 class AsyncMemoryMonitor(MemoryMonitor):
     """
-    An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
+    An Async Memory Monitor running during computing. Sampling memory usage of the current GPU
     at interval of `1/(10**power)` sec.

     The idea comes from Runtime Memory Tracer of PatrickStar
@@ -67,7 +67,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
         async_mem_monitor.save('log.pkl')

     Args:
-        power (int, optional): the power of time interva. Defaults to 10.
+        power (int, optional): the power of time interval. Defaults to 10.

     .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
         https://arxiv.org/abs/2108.05818

View File

@@ -73,7 +73,7 @@ def get_static_torch_model(zero_ddp_model,
         zero_ddp_model (ZeroDDP): a zero ddp model
         device (torch.device): the device of the final torch model
         dtype (torch.dtype): the dtype of the final torch model
-        only_rank_0 (bool): if True, only rank0 has the coverted torch model
+        only_rank_0 (bool): if True, only rank0 has the converted torch model

     Returns:
         torch.nn.Module: a static torch model used for saving checkpoints or numeric checks

View File

@@ -88,7 +88,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
                                  ophook_list: List[BaseOpHook],
                                  name: str = "",
                                  filter_fn: Optional[Callable] = None):
-    r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
+    r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
     assert isinstance(module, torch.nn.Module)
     assert isinstance(ophook_list, (list, tuple))
     assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
@@ -103,7 +103,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
     if len(list(module.parameters(recurse=False))) == 0:
         return

-    # return from flitered module
+    # return from filtered module
     if filter_fn is not None and filter_fn(module):
         return

View File

@@ -77,7 +77,7 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t
     move a tensor to the target_device
     Args:
         t (Union[StatefulTensor, torch.Tensor]): the tensor be moved
-        target_device: a traget device, if type is int, it the index of cuda card.
+        target_device: a target device, if type is int, it the index of cuda card.
     """
     if not isinstance(target_device, torch.device):
         target_device = torch.device(f'cuda:{target_device}')