mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -120,26 +120,30 @@ class profile(torch_profile):
|
||||
p.step()
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
activities: Optional[Iterable[ProfilerActivity]] = None,
|
||||
schedule: Optional[Callable[[int], ProfilerAction]] = None,
|
||||
on_trace_ready: Optional[Callable[..., Any]] = None,
|
||||
engine: Optional[Engine] = None,
|
||||
record_shapes: bool = False,
|
||||
profile_memory: bool = False,
|
||||
with_stack: bool = False,
|
||||
with_flops: bool = False,
|
||||
with_modules: bool = False,
|
||||
profile_stateful_tensor_memory: bool = False) -> None:
|
||||
super().__init__(activities=activities,
|
||||
schedule=schedule,
|
||||
on_trace_ready=on_trace_ready,
|
||||
record_shapes=record_shapes,
|
||||
profile_memory=profile_memory,
|
||||
with_stack=with_stack,
|
||||
with_flops=with_flops,
|
||||
with_modules=with_modules)
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
activities: Optional[Iterable[ProfilerActivity]] = None,
|
||||
schedule: Optional[Callable[[int], ProfilerAction]] = None,
|
||||
on_trace_ready: Optional[Callable[..., Any]] = None,
|
||||
engine: Optional[Engine] = None,
|
||||
record_shapes: bool = False,
|
||||
profile_memory: bool = False,
|
||||
with_stack: bool = False,
|
||||
with_flops: bool = False,
|
||||
with_modules: bool = False,
|
||||
profile_stateful_tensor_memory: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
activities=activities,
|
||||
schedule=schedule,
|
||||
on_trace_ready=on_trace_ready,
|
||||
record_shapes=record_shapes,
|
||||
profile_memory=profile_memory,
|
||||
with_stack=with_stack,
|
||||
with_flops=with_flops,
|
||||
with_modules=with_modules,
|
||||
)
|
||||
self._logger = get_dist_logger()
|
||||
self.extentions: List[ProfilerExtension] = []
|
||||
if profile_stateful_tensor_memory:
|
||||
@@ -149,9 +153,9 @@ class profile(torch_profile):
|
||||
self.extentions.append(StatefulTensorMemoryProfilerExtention(engine))
|
||||
|
||||
def prepare_trace(self) -> None:
|
||||
if hasattr(super(), 'prepare_trace'):
|
||||
if hasattr(super(), "prepare_trace"):
|
||||
super().prepare_trace()
|
||||
elif hasattr(super(), '_start_warmup'):
|
||||
elif hasattr(super(), "_start_warmup"):
|
||||
super()._start_warmup()
|
||||
for ext in self.extentions:
|
||||
ext.prepare_trace()
|
||||
@@ -160,9 +164,9 @@ class profile(torch_profile):
|
||||
self.prepare_trace()
|
||||
|
||||
def start_trace(self):
|
||||
if hasattr(super(), '_start_trace'):
|
||||
if hasattr(super(), "_start_trace"):
|
||||
super()._start_trace()
|
||||
elif hasattr(super(), 'start_trace'):
|
||||
elif hasattr(super(), "start_trace"):
|
||||
super().start_trace()
|
||||
for ext in self.extentions:
|
||||
ext.start_trace()
|
||||
@@ -171,9 +175,9 @@ class profile(torch_profile):
|
||||
self.start_trace()
|
||||
|
||||
def stop_trace(self):
|
||||
if hasattr(super(), '_stop_trace'):
|
||||
if hasattr(super(), "_stop_trace"):
|
||||
super()._stop_trace()
|
||||
elif hasattr(super(), 'stop_trace'):
|
||||
elif hasattr(super(), "stop_trace"):
|
||||
super().stop_trace()
|
||||
for ext in self.extentions:
|
||||
ext.stop_trace()
|
||||
@@ -186,15 +190,15 @@ class profile(torch_profile):
|
||||
Exports the collected trace in Chrome JSON format.
|
||||
"""
|
||||
assert self.profiler
|
||||
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
|
||||
fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
|
||||
fp.close()
|
||||
retvalue = self.profiler.export_chrome_trace(fp.name)
|
||||
with open(fp.name) as fin:
|
||||
trace = json.load(fin)
|
||||
for ext in self.extentions:
|
||||
trace = ext.extend_chrome_trace(trace)
|
||||
open_func = gzip.open if path.endswith('.gz') else open
|
||||
with open_func(path, 'wt') as fout:
|
||||
open_func = gzip.open if path.endswith(".gz") else open
|
||||
with open_func(path, "wt") as fout:
|
||||
json.dump(trace, fout)
|
||||
|
||||
os.remove(fp.name)
|
||||
|
Reference in New Issue
Block a user