mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
fix typo with colossalai/trainer utils zero (#3908)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
# adpated from torch.utils.data.DistributedSampler
|
||||
# adapted from torch.utils.data.DistributedSampler
|
||||
|
||||
import math
|
||||
import random
|
||||
|
@@ -70,7 +70,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
|
||||
cls.__init__ = preprocess_after(cls.__init__)
|
||||
|
||||
# Replace .__init__() for all existing subclasses of torch.nn.Module
|
||||
# Excution self._post_init_method after the default init function.
|
||||
# Execute self._post_init_method after the default init function.
|
||||
substitute_init_recursively(torch.nn.modules.module.Module, _enable_class, set())
|
||||
|
||||
# holding on to the current __init_subclass__ for exit
|
||||
|
@@ -111,7 +111,7 @@ class CommProfiler(BaseProfiler):
|
||||
res.append(sep)
|
||||
|
||||
if self.warn_flag:
|
||||
append("Warnning: there exists multiple communication operations in the same time. As a result, "
|
||||
append("Warning: there exists multiple communication operations in the same time. As a result, "
|
||||
"the profiling result is not accurate.")
|
||||
|
||||
if self.total_cuda_time == 0:
|
||||
@@ -123,12 +123,12 @@ class CommProfiler(BaseProfiler):
|
||||
append("total number of calls: {}".format(self.total_count))
|
||||
append("All events:")
|
||||
|
||||
seperation = '-' * 74
|
||||
separation = '-' * 74
|
||||
row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
|
||||
|
||||
append(seperation)
|
||||
append(separation)
|
||||
append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
|
||||
append(seperation)
|
||||
append(separation)
|
||||
|
||||
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
|
||||
for location, event in show_list:
|
||||
|
@@ -130,12 +130,12 @@ class PcieProfiler(BaseProfiler):
|
||||
|
||||
append("Possible data transmission events in PCIE:")
|
||||
|
||||
seperation = '-' * 62
|
||||
separation = '-' * 62
|
||||
row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2
|
||||
|
||||
append(seperation)
|
||||
append(separation)
|
||||
append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
|
||||
append(seperation)
|
||||
append(separation)
|
||||
|
||||
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
|
||||
for location, event in show_list:
|
||||
|
@@ -32,9 +32,9 @@ def _format_memory(nbytes):
|
||||
return str(nbytes) + ' B'
|
||||
|
||||
|
||||
def _format_bandwidth(volme: float or int, time_us: int):
|
||||
def _format_bandwidth(volume: float or int, time_us: int):
|
||||
sec_div_mb = (1000.0 / 1024.0)**2
|
||||
mb_per_sec = volme / time_us * sec_div_mb
|
||||
mb_per_sec = volume / time_us * sec_div_mb
|
||||
|
||||
if mb_per_sec >= 1024.0:
|
||||
return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
|
||||
|
@@ -1,5 +1,5 @@
|
||||
# Rank Recorder
|
||||
This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualise the json file easily.
|
||||
This is a useful tool to get the records of certain functions in each rank. The records of each rank are dumped into a JSON file after the multi-process program ends. You can parse and visualize the JSON file easily.
|
||||
|
||||
Before using the tool, you should ensure that dist.is_initialized() returns true before the program exits.
|
||||
|
||||
@@ -20,7 +20,7 @@ with recorder(record_name, current_rank) as r:
|
||||
```
|
||||
|
||||
## Example
|
||||
This is a demo to display kernel select in cuda and visualise the cost of several procedures in each rank.
|
||||
This is a demo that displays kernel selection in CUDA and visualizes the cost of several procedures in each rank.
|
||||
|
||||
```python
|
||||
import time
|
||||
|
@@ -133,7 +133,7 @@ class Recorder:
|
||||
with open(self.export_name + '.json', 'w', encoding='utf-8') as f:
|
||||
json.dump(recoders, f, ensure_ascii=False)
|
||||
|
||||
def visualise_record(self):
|
||||
def visualize_record(self):
|
||||
with open(self.export_name + '.json', 'r', encoding='utf-8') as f:
|
||||
records = json.load(f)
|
||||
records = dict(records)
|
||||
@@ -171,7 +171,7 @@ class Recorder:
|
||||
if rank == 1:
|
||||
# take the base time of rank 0 as standard
|
||||
self.merge_recode()
|
||||
self.visualise_record()
|
||||
self.visualize_record()
|
||||
|
||||
|
||||
recorder = Recorder()
|
||||
|
Reference in New Issue
Block a user