Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-28 11:45:23 +00:00)
fix typo with colossalai/trainer utils zero (#3908)
Commit a9d1cadc49 (parent b306cecf28)
@@ -31,9 +31,9 @@ class Trainer:
 >>> # Initialize your engine, train_dataloader, test_dataloader, lr_scheduler
 >>> engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion)
 >>> # Beginning training progress
->>> timier = ...
+>>> timer = ...
 >>> logger = ...
->>> trainer = Trainer(engine=engine, logger=logger, timer=timier)
+>>> trainer = Trainer(engine=engine, logger=logger, timer=timer)
 >>> # add hooks you would like to use here.
 >>> hook_list = []
 >>> trainer.fit(
@@ -56,7 +56,7 @@ class Trainer:
 timer: MultiTimer = None,
 logger: DistributedLogger = None,
 ):
-# training-ralated params
+# training-related params
 self._engine = engine
 self._max_epochs = 0
 self._cur_epoch = 0
@@ -118,7 +118,7 @@ class Trainer:
 self._cur_step = epoch * self._steps_per_epoch

 def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
-"""Call timer funciton with a given timer name.
+"""Call timer function with a given timer name.

 Args:
 action (str): Function to be called on timer.
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-# adpated from torch.utils.data.DistributedSampler
+# adapted from torch.utils.data.DistributedSampler

 import math
 import random
@@ -70,7 +70,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
 cls.__init__ = preprocess_after(cls.__init__)

 # Replace .__init__() for all existing subclasses of torch.nn.Module
-# Excution self._post_init_method after the default init function.
+# Execution self._post_init_method after the default init function.
 substitute_init_recursively(torch.nn.modules.module.Module, _enable_class, set())

 # holding on to the current __init__subclass__ for exit
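The hunk above patches every existing `torch.nn.Module` subclass so that `self._post_init_method` runs right after each module's own `__init__`. A minimal sketch of that wrap-and-substitute idea, using illustrative helper names rather than ColossalAI's actual ones:

```python
import functools

import torch


def wrap_init(init_fn, post_init_fn):
    """Return an __init__ that calls post_init_fn after the original one."""

    @functools.wraps(init_fn)
    def wrapper(module, *args, **kwargs):
        init_fn(module, *args, **kwargs)
        post_init_fn(module)    # e.g. shard or relocate the fresh parameters

    return wrapper


def substitute_init_recursively(cls, enable_fn, visited):
    """Patch __init__ on every existing subclass exactly once."""
    for subcls in cls.__subclasses__():
        if subcls not in visited:
            visited.add(subcls)
            substitute_init_recursively(subcls, enable_fn, visited)
            enable_fn(subcls)


def _enable_class(cls):
    cls.__init__ = wrap_init(cls.__init__, lambda m: print(type(m).__name__))


# Note: in this simplified sketch an intermediate base class (e.g. _ConvNd)
# would fire the hook more than once; the real implementation guards that.
substitute_init_recursively(torch.nn.Module, _enable_class, set())
torch.nn.Linear(4, 4)    # prints "Linear" after the default init has run
```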
@@ -111,7 +111,7 @@ class CommProfiler(BaseProfiler):
 res.append(sep)

 if self.warn_flag:
-append("Warnning: there exists multiple communication operations in the same time. As a result, "
+append("Warning: there exists multiple communication operations in the same time. As a result, "
 "the profiling result is not accurate.")

 if self.total_cuda_time == 0:
@@ -123,12 +123,12 @@ class CommProfiler(BaseProfiler):
 append("total number of calls: {}".format(self.total_count))
 append("All events:")

-seperation = '-' * 74
+separation = '-' * 74
 row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2

-append(seperation)
+append(separation)
 append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
-append(seperation)
+append(separation)

 show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
 for location, event in show_list:
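For reference, the `row_format` in these profiler hunks is plain `str.format` alignment: `{:^N}` centers a field in N characters, and the column widths sum to the 74-character separator. A standalone rendering of the same table layout (the data row is made up):

```python
row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
separation = '-' * 74    # 10 + 12 + 12 + 16 + 12 + 12 == 74

print(separation)
print(row_format.format('Location', 'GPU time', 'Percentage',
                        'Comm volume', 'Bandwidth', 'Num of calls'))
print(separation)
# Illustrative values only; real rows come from self.ops_record.
print(row_format.format('all_reduce', '1.28 ms', '61.3 %',
                        '32.0 MB', '24.41 GB/s', 8))
```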
@@ -130,12 +130,12 @@ class PcieProfiler(BaseProfiler):

 append("Possible data transmission events in PCIE:")

-seperation = '-' * 62
+separation = '-' * 62
 row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2

-append(seperation)
+append(separation)
 append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
-append(seperation)
+append(separation)

 show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
 for location, event in show_list:
@@ -32,9 +32,9 @@ def _format_memory(nbytes):
 return str(nbytes) + ' B'


-def _format_bandwidth(volme: float or int, time_us: int):
+def _format_bandwidth(volume: float or int, time_us: int):
 sec_div_mb = (1000.0 / 1024.0)**2
-mb_per_sec = volme / time_us * sec_div_mb
+mb_per_sec = volume / time_us * sec_div_mb

 if mb_per_sec >= 1024.0:
 return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
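The names suggest `volume` is a byte count and `time_us` is in microseconds, so `sec_div_mb = (1000.0 / 1024.0)**2` is exactly the factor that turns bytes-per-microsecond into MB/s (10^6 µs per second, 1024^2 bytes per MB). A quick check of that conversion under those assumptions:

```python
# bytes/us * 10**6 us/s / 1024**2 bytes/MB == bytes/us * (1000/1024)**2
sec_div_mb = (1000.0 / 1024.0) ** 2

volume_bytes = 512 * 1024 * 1024    # 512 MB transferred
time_us = 20_000                    # in 20 ms

mb_per_sec = volume_bytes / time_us * sec_div_mb
print(f'{mb_per_sec / 1024.0:.3f} GB/s')    # 25.000 GB/s
```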
@@ -1,5 +1,5 @@
 # Rank Recorder
-This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualise the json file easily.
+This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualize the json file easily.

 Before using the tool, you should ensure dist.is_initialized() return true before exit of program.

@@ -20,7 +20,7 @@ with recorder(record_name, current_rank) as r:
 ```

 ## Example
-This is a demo to display kernel select in cuda and visualise the cost of several procedures in each rank.
+This is a demo to display kernel select in cuda and visualize the cost of several procedures in each rank.

 ```python
 import time
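Based on the README snippet above, a hedged end-to-end sketch of how the recorder is meant to be used; the import path and label argument here are assumptions, not the tool's confirmed API:

```python
import time

import torch.distributed as dist

from recorder import recorder    # hypothetical import path for the tool

dist.init_process_group(backend='nccl')    # must still be initialized at program exit
rank = dist.get_rank()

with recorder('select_kernel', rank):    # record this block under the given name
    time.sleep(0.05 * (rank + 1))        # stand-in for the real per-rank work

# On exit each rank dumps its records to JSON; per the Recorder code below,
# rank 1 then merges them (taking rank 0's base time as standard) and visualizes.
```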
@@ -133,7 +133,7 @@ class Recorder:
 with open(self.export_name + '.json', 'w', encoding='utf-8') as f:
 json.dump(recoders, f, ensure_ascii=False)

-def visualise_record(self):
+def visualize_record(self):
 with open(self.export_name + '.json', 'r', encoding='utf-8') as f:
 records = json.load(f)
 records = dict(records)
@@ -171,7 +171,7 @@ class Recorder:
 if rank == 1:
 # take the base time of rank 0 as standard
 self.merge_recode()
-self.visualise_record()
+self.visualize_record()


 recorder = Recorder()
@@ -416,7 +416,7 @@ class Chunk:
 Copy data slice to the memory space indexed by the input tensor in the chunk.

 Args:
-tensor (torch.Tensor): the tensor used to retrive meta information
+tensor (torch.Tensor): the tensor used to retrieve meta information
 data_slice (torch.Tensor): the tensor to be copied to the chunk
 """
 # sanity check
@@ -157,7 +157,7 @@ class ChunkManager:
 Copy data to the chunk.

 Args:
-tensor (torch.Tensor): the tensor used to retrive meta information
+tensor (torch.Tensor): the tensor used to retrieve meta information
 data (torch.Tensor): the tensor to be copied to the chunk
 """
 chunk = self.tensor_chunk_map[tensor]
@@ -25,7 +25,7 @@ class ChunkMemStatsCollector(MemStatsCollector):
 # override
 def record_model_data_volume(self) -> None:
 """
-record model data volumn on cuda and cpu.
+record model data volume on cuda and cpu.
 """
 if self._start_flag and not self.use_outside_memstats:
 cuda_mem = self._chunk_manager.total_mem['cuda']
@@ -45,7 +45,7 @@ class MemoryMonitor:

 class AsyncMemoryMonitor(MemoryMonitor):
 """
-An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
+An Async Memory Monitor running during computing. Sampling memory usage of the current GPU
 at interval of `1/(10**power)` sec.

 The idea comes from Runtime Memory Tracer of PatrickStar
@@ -67,7 +67,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
 async_mem_monitor.save('log.pkl')

 Args:
-power (int, optional): the power of time interva. Defaults to 10.
+power (int, optional): the power of time interval. Defaults to 10.

 .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
 https://arxiv.org/abs/2108.05818
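A minimal sketch of the sampling loop the docstring describes: a worker thread polls peak CUDA memory every `1 / (10**power)` seconds while the main thread computes. This mirrors the documented behavior, not ColossalAI's exact implementation, and it assumes a CUDA device is available:

```python
import time
from concurrent.futures import ThreadPoolExecutor

import torch


class TinyAsyncMonitor:
    """Toy stand-in for AsyncMemoryMonitor, for illustration only."""

    def __init__(self, power: int = 3):    # power=3 -> sample every 1 ms
        self.interval = 1 / (10 ** power)
        self.keep_measuring = False
        self.executor = ThreadPoolExecutor(max_workers=1)

    def start(self):
        self.keep_measuring = True
        self.future = self.executor.submit(self._measure)

    def finish(self) -> int:
        self.keep_measuring = False
        return self.future.result()    # peak bytes observed while running

    def _measure(self) -> int:
        peak = 0
        while self.keep_measuring:
            peak = max(peak, torch.cuda.max_memory_allocated())
            time.sleep(self.interval)
        return peak


monitor = TinyAsyncMonitor()
monitor.start()
out = torch.nn.Linear(20, 30).cuda()(torch.randn(2, 20, device='cuda'))
print(monitor.finish())
```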
@@ -73,7 +73,7 @@ def get_static_torch_model(zero_ddp_model,
 zero_ddp_model (ZeroDDP): a zero ddp model
 device (torch.device): the device of the final torch model
 dtype (torch.dtype): the dtype of the final torch model
-only_rank_0 (bool): if True, only rank0 has the coverted torch model
+only_rank_0 (bool): if True, only rank0 has the converted torch model

 Returns:
 torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
@@ -88,7 +88,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
 ophook_list: List[BaseOpHook],
 name: str = "",
 filter_fn: Optional[Callable] = None):
-r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
+r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
 assert isinstance(module, torch.nn.Module)
 assert isinstance(ophook_list, (list, tuple))
 assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
@@ -103,7 +103,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
 if len(list(module.parameters(recurse=False))) == 0:
 return

-# return from flitered module
+# return from filtered module
 if filter_fn is not None and filter_fn(module):
 return

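A condensed sketch of the recursion these two hunks belong to: descend through all submodules, then skip modules that own no parameters directly or that `filter_fn` rejects before attaching hooks. The hook wiring here is simplified to forward pre-hooks; the real ophooks also cover backward:

```python
from typing import Callable, List, Optional

import torch


def register_hooks_recursively(module: torch.nn.Module,
                               hooks: List[Callable],
                               filter_fn: Optional[Callable] = None) -> None:
    # Visit every submodule first.
    for child in module.children():
        register_hooks_recursively(child, hooks, filter_fn)

    # Guards mirrored from the code above: nothing to hook on modules
    # without their own parameters, and filtered modules are skipped.
    if len(list(module.parameters(recurse=False))) == 0:
        return
    if filter_fn is not None and filter_fn(module):
        return

    for hook in hooks:
        def pre_hook(mod, inp, h=hook):    # bind h eagerly, not at call time
            h(mod, *inp)    # return None so the inputs pass through unchanged
        module.register_forward_pre_hook(pre_hook)
```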
@@ -77,7 +77,7 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t
 move a tensor to the target_device
 Args:
 t (Union[StatefulTensor, torch.Tensor]): the tensor be moved
-target_device: a traget device, if type is int, it the index of cuda card.
+target_device: a target device, if type is int, it the index of cuda card.
 """
 if not isinstance(target_device, torch.device):
 target_device = torch.device(f'cuda:{target_device}')
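The `isinstance` branch at the end of this hunk is the whole int-to-device convention the docstring describes: a bare int is read as a CUDA card index. The same normalization in isolation:

```python
import torch


def normalize_device(target_device) -> torch.device:
    # Mirrors the branch above: ints become cuda:<index>, devices pass through.
    if not isinstance(target_device, torch.device):
        target_device = torch.device(f'cuda:{target_device}')
    return target_device


assert normalize_device(1) == torch.device('cuda:1')
assert normalize_device(torch.device('cpu')) == torch.device('cpu')
```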