mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ProfilerExtension(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def prepare_trace(self):
|
||||
pass
|
||||
|
@@ -3,4 +3,4 @@ from .mem_profiler import MemProfiler
|
||||
from .pcie_profiler import PcieProfiler
|
||||
from .prof_utils import BaseProfiler, ProfilerContext
|
||||
|
||||
__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
|
||||
__all__ = ["BaseProfiler", "CommProfiler", "PcieProfiler", "MemProfiler", "ProfilerContext"]
|
||||
|
@@ -20,14 +20,14 @@ def _get_code_location(depth: int):
|
||||
upper_frame = inspect.stack()[i]
|
||||
function_name = inspect.stack()[i - 1].function
|
||||
ret.append(upper_frame.filename)
|
||||
ret.append('(')
|
||||
ret.append("(")
|
||||
ret.append(str(upper_frame.lineno))
|
||||
ret.append('): ')
|
||||
ret.append("): ")
|
||||
ret.append(function_name)
|
||||
if i != length - 1:
|
||||
ret.append('\n')
|
||||
ret.append("\n")
|
||||
|
||||
return ''.join(ret)
|
||||
return "".join(ret)
|
||||
|
||||
|
||||
torch_all_reduce = dist.all_reduce
|
||||
@@ -42,7 +42,7 @@ class CommEvent(object):
|
||||
volume recording.
|
||||
"""
|
||||
|
||||
def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0):
|
||||
def __init__(self, count: int = 0, comm_vol: float = 0.0, cuda_time: int = 0):
|
||||
self.self_count = count
|
||||
self.self_comm_vol = comm_vol
|
||||
self.self_cuda_time = cuda_time
|
||||
@@ -54,8 +54,7 @@ class CommEvent(object):
|
||||
|
||||
|
||||
class CommProfiler(BaseProfiler):
|
||||
"""Communication profiler. Records all communication events.
|
||||
"""
|
||||
"""Communication profiler. Records all communication events."""
|
||||
|
||||
def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0):
|
||||
super().__init__(profiler_name="Collective_Communication", priority=0)
|
||||
@@ -114,8 +113,10 @@ class CommProfiler(BaseProfiler):
|
||||
res.append(sep)
|
||||
|
||||
if self.warn_flag:
|
||||
append("Warning: there exists multiple communication operations in the same time. As a result, "
|
||||
"the profiling result is not accurate.")
|
||||
append(
|
||||
"Warning: there exists multiple communication operations in the same time. As a result, "
|
||||
"the profiling result is not accurate."
|
||||
)
|
||||
|
||||
if self.total_cuda_time == 0:
|
||||
return "No collective communication has been called yet!"
|
||||
@@ -126,24 +127,29 @@ class CommProfiler(BaseProfiler):
|
||||
append("total number of calls: {}".format(self.total_count))
|
||||
append("All events:")
|
||||
|
||||
separation = '-' * 74
|
||||
row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
|
||||
separation = "-" * 74
|
||||
row_format = "{:^10}" + "{:^12}" * 2 + "{:^16}" + "{:^12}" * 2
|
||||
|
||||
append(separation)
|
||||
append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
|
||||
append(row_format.format("Location", "GPU time", "Percentage", "Comm volume", "Bandwidth", "Num of calls"))
|
||||
append(separation)
|
||||
|
||||
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
|
||||
for location, event in show_list:
|
||||
append(location)
|
||||
append(
|
||||
row_format.format('', _format_time(event.self_cuda_time),
|
||||
'{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0),
|
||||
_format_memory(event.self_comm_vol),
|
||||
_format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count))
|
||||
row_format.format(
|
||||
"",
|
||||
_format_time(event.self_cuda_time),
|
||||
"{:.1f}%".format(event.self_cuda_time / self.total_cuda_time * 100.0),
|
||||
_format_memory(event.self_comm_vol),
|
||||
_format_bandwidth(event.self_comm_vol, event.self_cuda_time),
|
||||
event.self_count,
|
||||
)
|
||||
)
|
||||
append()
|
||||
|
||||
return ''.join(res)
|
||||
return "".join(res)
|
||||
|
||||
@property
|
||||
def has_aync_op(self):
|
||||
@@ -195,8 +201,7 @@ class CommProfiler(BaseProfiler):
|
||||
|
||||
|
||||
class CommHandler(object):
|
||||
"""Communication handler. A dummy handler to wait aync operations.
|
||||
"""
|
||||
"""Communication handler. A dummy handler to wait aync operations."""
|
||||
|
||||
def __init__(self, profiler: CommProfiler):
|
||||
super().__init__()
|
||||
@@ -212,11 +217,9 @@ def async_check(profiler: CommProfiler):
|
||||
profiler.wait_async_op()
|
||||
|
||||
|
||||
def all_reduce(tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
def all_reduce(
|
||||
tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, group=None, async_op: bool = False, profiler: CommProfiler = None
|
||||
) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
@@ -231,12 +234,14 @@ def all_reduce(tensor: torch.Tensor,
|
||||
profiler.close_profiler(group)
|
||||
|
||||
|
||||
def reduce_scatter(output: torch.Tensor,
|
||||
input_list: List[torch.Tensor],
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
def reduce_scatter(
|
||||
output: torch.Tensor,
|
||||
input_list: List[torch.Tensor],
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None,
|
||||
) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
@@ -254,11 +259,13 @@ def reduce_scatter(output: torch.Tensor,
|
||||
profiler.close_profiler(group)
|
||||
|
||||
|
||||
def all_gather(tensor_list: List[torch.Tensor],
|
||||
tensor: torch.Tensor,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
def all_gather(
|
||||
tensor_list: List[torch.Tensor],
|
||||
tensor: torch.Tensor,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None,
|
||||
) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
@@ -276,11 +283,9 @@ def all_gather(tensor_list: List[torch.Tensor],
|
||||
profiler.close_profiler(group)
|
||||
|
||||
|
||||
def broadcast(tensor: torch.Tensor,
|
||||
src: int,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
def broadcast(
|
||||
tensor: torch.Tensor, src: int, group=None, async_op: bool = False, profiler: CommProfiler = None
|
||||
) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
|
||||
@@ -293,12 +298,14 @@ def broadcast(tensor: torch.Tensor,
|
||||
profiler.close_profiler(group)
|
||||
|
||||
|
||||
def reduce(tensor: torch.Tensor,
|
||||
dst: int,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
def reduce(
|
||||
tensor: torch.Tensor,
|
||||
dst: int,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None,
|
||||
) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
|
||||
|
@@ -18,6 +18,7 @@ def _get_size(dtype: str):
|
||||
def _get_numel(my_list: List[int]) -> int:
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
|
||||
return reduce(mul, my_list)
|
||||
|
||||
|
||||
@@ -27,12 +28,11 @@ def _reduce_location(locations: List[str]) -> str:
|
||||
ret.append(lo)
|
||||
ret.append("\n")
|
||||
ret = ret[:-1]
|
||||
return ''.join(ret)
|
||||
return "".join(ret)
|
||||
|
||||
|
||||
class PcieEvent(object):
|
||||
"""Pcie Event.
|
||||
"""
|
||||
"""Pcie Event."""
|
||||
|
||||
def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0):
|
||||
self.count = count
|
||||
@@ -73,12 +73,9 @@ class PcieProfiler(BaseProfiler):
|
||||
self.profiler = None
|
||||
|
||||
def enable(self):
|
||||
self.profiler = profile(enabled=True,
|
||||
use_cuda=True,
|
||||
use_cpu=True,
|
||||
use_kineto=True,
|
||||
record_shapes=True,
|
||||
with_stack=True)
|
||||
self.profiler = profile(
|
||||
enabled=True, use_cuda=True, use_cpu=True, use_kineto=True, record_shapes=True, with_stack=True
|
||||
)
|
||||
self.profiler.__enter__()
|
||||
|
||||
def disable(self):
|
||||
@@ -92,15 +89,15 @@ class PcieProfiler(BaseProfiler):
|
||||
if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0:
|
||||
continue
|
||||
current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total)
|
||||
code_location = _reduce_location(event.stack[:self.depth])
|
||||
code_location = _reduce_location(event.stack[: self.depth])
|
||||
if code_location in self.ops_record:
|
||||
self.ops_record[code_location].add(current_comm_event)
|
||||
else:
|
||||
self.ops_record[code_location] = current_comm_event
|
||||
elif 'Memcpy HtoD' in event.name:
|
||||
elif "Memcpy HtoD" in event.name:
|
||||
self.h2d_count += 1
|
||||
self.h2d_time += event.cuda_time_total
|
||||
elif 'Memcpy DtoH' in event.name:
|
||||
elif "Memcpy DtoH" in event.name:
|
||||
self.d2h_count += 1
|
||||
self.d2h_time += event.cuda_time_total
|
||||
|
||||
@@ -132,19 +129,25 @@ class PcieProfiler(BaseProfiler):
|
||||
|
||||
append("Possible data transmission events in PCIE:")
|
||||
|
||||
separation = '-' * 62
|
||||
row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2
|
||||
separation = "-" * 62
|
||||
row_format = "{:^10}" + "{:^12}" + "{:^16}" + "{:^12}" * 2
|
||||
|
||||
append(separation)
|
||||
append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
|
||||
append(row_format.format("Location", "GPU time", "Trans volume", "Bandwidth", "Num of calls"))
|
||||
append(separation)
|
||||
|
||||
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
|
||||
for location, event in show_list:
|
||||
append(location)
|
||||
append(
|
||||
row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol),
|
||||
_format_bandwidth(event.pcie_vol, event.cuda_time), event.count))
|
||||
row_format.format(
|
||||
"",
|
||||
_format_time(event.cuda_time),
|
||||
_format_memory(event.pcie_vol),
|
||||
_format_bandwidth(event.pcie_vol, event.cuda_time),
|
||||
event.count,
|
||||
)
|
||||
)
|
||||
append()
|
||||
|
||||
return ''.join(res)
|
||||
return "".join(res)
|
||||
|
@@ -11,10 +11,10 @@ def _format_time(time_us):
|
||||
US_IN_SECOND = 1000.0 * 1000.0
|
||||
US_IN_MS = 1000.0
|
||||
if time_us >= US_IN_SECOND:
|
||||
return '{:.3f}s'.format(time_us / US_IN_SECOND)
|
||||
return "{:.3f}s".format(time_us / US_IN_SECOND)
|
||||
if time_us >= US_IN_MS:
|
||||
return '{:.3f}ms'.format(time_us / US_IN_MS)
|
||||
return '{:.3f}us'.format(time_us)
|
||||
return "{:.3f}ms".format(time_us / US_IN_MS)
|
||||
return "{:.3f}us".format(time_us)
|
||||
|
||||
|
||||
# Copied from a newer PyTorch release so that older PyTorch versions are supported.
|
||||
@@ -23,28 +23,27 @@ def _format_memory(nbytes):
|
||||
KB = 1024
|
||||
MB = 1024 * KB
|
||||
GB = 1024 * MB
|
||||
if (abs(nbytes) >= GB):
|
||||
return '{:.2f} GB'.format(nbytes * 1.0 / GB)
|
||||
elif (abs(nbytes) >= MB):
|
||||
return '{:.2f} MB'.format(nbytes * 1.0 / MB)
|
||||
elif (abs(nbytes) >= KB):
|
||||
return '{:.2f} KB'.format(nbytes * 1.0 / KB)
|
||||
if abs(nbytes) >= GB:
|
||||
return "{:.2f} GB".format(nbytes * 1.0 / GB)
|
||||
elif abs(nbytes) >= MB:
|
||||
return "{:.2f} MB".format(nbytes * 1.0 / MB)
|
||||
elif abs(nbytes) >= KB:
|
||||
return "{:.2f} KB".format(nbytes * 1.0 / KB)
|
||||
else:
|
||||
return str(nbytes) + ' B'
|
||||
return str(nbytes) + " B"
|
||||
|
||||
|
||||
def _format_bandwidth(volume: float or int, time_us: int):
|
||||
sec_div_mb = (1000.0 / 1024.0)**2
|
||||
sec_div_mb = (1000.0 / 1024.0) ** 2
|
||||
mb_per_sec = volume / time_us * sec_div_mb
|
||||
|
||||
if mb_per_sec >= 1024.0:
|
||||
return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
|
||||
return "{:.3f} GB/s".format(mb_per_sec / 1024.0)
|
||||
else:
|
||||
return '{:.3f} MB/s'.format(mb_per_sec)
|
||||
return "{:.3f} MB/s".format(mb_per_sec)
|
||||
|
||||
|
||||
class BaseProfiler(ABC):
|
||||
|
||||
def __init__(self, profiler_name: str, priority: int):
|
||||
self.name = profiler_name
|
||||
self.priority = priority
|
||||
@@ -111,8 +110,9 @@ class ProfilerContext(object):
|
||||
def to_tensorboard(self, writer):
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
assert isinstance(writer, SummaryWriter), \
|
||||
f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'
|
||||
assert isinstance(
|
||||
writer, SummaryWriter
|
||||
), f"torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}."
|
||||
|
||||
for prof in self.profilers:
|
||||
prof.to_tensorboard(writer)
|
||||
@@ -124,7 +124,7 @@ class ProfilerContext(object):
|
||||
if not log_dir.exists():
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
for prof in self.profilers:
|
||||
log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
|
||||
log_file = log_dir.joinpath(f"{prof.name}_rank_{gpc.get_global_rank()}.log")
|
||||
prof.to_file(log_file)
|
||||
|
||||
def show(self):
|
||||
|
@@ -120,26 +120,30 @@ class profile(torch_profile):
|
||||
p.step()
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
activities: Optional[Iterable[ProfilerActivity]] = None,
|
||||
schedule: Optional[Callable[[int], ProfilerAction]] = None,
|
||||
on_trace_ready: Optional[Callable[..., Any]] = None,
|
||||
engine: Optional[Engine] = None,
|
||||
record_shapes: bool = False,
|
||||
profile_memory: bool = False,
|
||||
with_stack: bool = False,
|
||||
with_flops: bool = False,
|
||||
with_modules: bool = False,
|
||||
profile_stateful_tensor_memory: bool = False) -> None:
|
||||
super().__init__(activities=activities,
|
||||
schedule=schedule,
|
||||
on_trace_ready=on_trace_ready,
|
||||
record_shapes=record_shapes,
|
||||
profile_memory=profile_memory,
|
||||
with_stack=with_stack,
|
||||
with_flops=with_flops,
|
||||
with_modules=with_modules)
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
activities: Optional[Iterable[ProfilerActivity]] = None,
|
||||
schedule: Optional[Callable[[int], ProfilerAction]] = None,
|
||||
on_trace_ready: Optional[Callable[..., Any]] = None,
|
||||
engine: Optional[Engine] = None,
|
||||
record_shapes: bool = False,
|
||||
profile_memory: bool = False,
|
||||
with_stack: bool = False,
|
||||
with_flops: bool = False,
|
||||
with_modules: bool = False,
|
||||
profile_stateful_tensor_memory: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
activities=activities,
|
||||
schedule=schedule,
|
||||
on_trace_ready=on_trace_ready,
|
||||
record_shapes=record_shapes,
|
||||
profile_memory=profile_memory,
|
||||
with_stack=with_stack,
|
||||
with_flops=with_flops,
|
||||
with_modules=with_modules,
|
||||
)
|
||||
self._logger = get_dist_logger()
|
||||
self.extentions: List[ProfilerExtension] = []
|
||||
if profile_stateful_tensor_memory:
|
||||
@@ -149,9 +153,9 @@ class profile(torch_profile):
|
||||
self.extentions.append(StatefulTensorMemoryProfilerExtention(engine))
|
||||
|
||||
def prepare_trace(self) -> None:
|
||||
if hasattr(super(), 'prepare_trace'):
|
||||
if hasattr(super(), "prepare_trace"):
|
||||
super().prepare_trace()
|
||||
elif hasattr(super(), '_start_warmup'):
|
||||
elif hasattr(super(), "_start_warmup"):
|
||||
super()._start_warmup()
|
||||
for ext in self.extentions:
|
||||
ext.prepare_trace()
|
||||
@@ -160,9 +164,9 @@ class profile(torch_profile):
|
||||
self.prepare_trace()
|
||||
|
||||
def start_trace(self):
|
||||
if hasattr(super(), '_start_trace'):
|
||||
if hasattr(super(), "_start_trace"):
|
||||
super()._start_trace()
|
||||
elif hasattr(super(), 'start_trace'):
|
||||
elif hasattr(super(), "start_trace"):
|
||||
super().start_trace()
|
||||
for ext in self.extentions:
|
||||
ext.start_trace()
|
||||
@@ -171,9 +175,9 @@ class profile(torch_profile):
|
||||
self.start_trace()
|
||||
|
||||
def stop_trace(self):
|
||||
if hasattr(super(), '_stop_trace'):
|
||||
if hasattr(super(), "_stop_trace"):
|
||||
super()._stop_trace()
|
||||
elif hasattr(super(), 'stop_trace'):
|
||||
elif hasattr(super(), "stop_trace"):
|
||||
super().stop_trace()
|
||||
for ext in self.extentions:
|
||||
ext.stop_trace()
|
||||
@@ -186,15 +190,15 @@ class profile(torch_profile):
|
||||
Exports the collected trace in Chrome JSON format.
|
||||
"""
|
||||
assert self.profiler
|
||||
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
|
||||
fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
|
||||
fp.close()
|
||||
retvalue = self.profiler.export_chrome_trace(fp.name)
|
||||
with open(fp.name) as fin:
|
||||
trace = json.load(fin)
|
||||
for ext in self.extentions:
|
||||
trace = ext.extend_chrome_trace(trace)
|
||||
open_func = gzip.open if path.endswith('.gz') else open
|
||||
with open_func(path, 'wt') as fout:
|
||||
open_func = gzip.open if path.endswith(".gz") else open
|
||||
with open_func(path, "wt") as fout:
|
||||
json.dump(trace, fout)
|
||||
|
||||
os.remove(fp.name)
|
||||
|
@@ -22,11 +22,11 @@ def get_timestamp_us():
|
||||
|
||||
|
||||
def generic_instant_event(name, pid, tid, timestamp, args):
|
||||
return {'ph': 'i', 's': 't', 'name': name, 'pid': pid, 'tid': tid, 'ts': timestamp, 'args': args}
|
||||
return {"ph": "i", "s": "t", "name": name, "pid": pid, "tid": tid, "ts": timestamp, "args": args}
|
||||
|
||||
|
||||
class StatefulTensorMemoryEvent:
|
||||
EVENT_NAME = '[statefulTensorMemory]'
|
||||
EVENT_NAME = "[statefulTensorMemory]"
|
||||
|
||||
def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None:
|
||||
self.pid = os.getpid()
|
||||
@@ -37,22 +37,23 @@ class StatefulTensorMemoryEvent:
|
||||
self.bytes = bytes_
|
||||
|
||||
def state_dict(self):
|
||||
return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp, {
|
||||
'Device Type': self.device_type.value,
|
||||
'Device Id': self.device_id,
|
||||
'Bytes': self.bytes
|
||||
})
|
||||
return generic_instant_event(
|
||||
StatefulTensorMemoryEvent.EVENT_NAME,
|
||||
self.pid,
|
||||
self.tid,
|
||||
self.timestamp,
|
||||
{"Device Type": self.device_type.value, "Device Id": self.device_id, "Bytes": self.bytes},
|
||||
)
|
||||
|
||||
|
||||
class StatefulTensorMemoryTracer:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.events: List[StatefulTensorMemoryEvent] = []
|
||||
self._tracing = False
|
||||
|
||||
def sample(self):
|
||||
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
|
||||
cuda_mem = StatefulTensor.GST_MGR.total_mem["cuda"]
|
||||
cpu_mem = StatefulTensor.GST_MGR.total_mem["cpu"]
|
||||
timestamp = get_timestamp_us()
|
||||
if self._tracing:
|
||||
self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem))
|
||||
@@ -70,7 +71,6 @@ class StatefulTensorMemoryTracer:
|
||||
|
||||
|
||||
class StatefulTensorMemoryTracerHook(BaseOpHook):
|
||||
|
||||
def __init__(self, tracer: StatefulTensorMemoryTracer):
|
||||
super().__init__()
|
||||
self.tracer = tracer
|
||||
@@ -104,7 +104,6 @@ class StatefulTensorMemoryTracerHook(BaseOpHook):
|
||||
|
||||
|
||||
class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
|
||||
|
||||
def __init__(self, engine: Engine) -> None:
|
||||
self.engine = engine
|
||||
self.tracer = StatefulTensorMemoryTracer()
|
||||
@@ -131,5 +130,5 @@ class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
|
||||
# self.hook_registered = False
|
||||
|
||||
def extend_chrome_trace(self, trace: dict) -> dict:
|
||||
trace['traceEvents'].extend(self.tracer.state_dict())
|
||||
trace["traceEvents"].extend(self.tracer.state_dict())
|
||||
return trace
|
||||
|
Reference in New Issue
Block a user