Refactored docstring to google style

This commit is contained in:
Liang Bowen
2022-03-25 13:02:39 +08:00
committed by アマデウス
parent 53b1b6e340
commit ec5086c49c
94 changed files with 3389 additions and 2982 deletions

View File

@@ -9,6 +9,7 @@ from collections import defaultdict
LINE_WIDTH = 108
LINE = '-' * LINE_WIDTH + '\n'
class TensorDetector():
def __init__(self,
show_info: bool = True,
@@ -16,17 +17,14 @@ class TensorDetector():
include_cpu: bool = False,
module: Optional[nn.Module] = None
):
"""This class is an detector to detect tensor on different devices.
:param show_info: whether to print the info on screen, default True
:type show_info: bool
:param log: the file name to save the log
:type log: str
:param include_cpu: whether to detect tensor on cpu, default False
:type include_cpu: bool
:param module: when sending an `nn.Module` it, the detector can name the tensors detected better
:type module: Optional[nn.Module]
"""This class is a detector to detect tensor on different devices.
Args:
show_info (bool, optional): whether to print the info on screen, default True.
log (str, optional): the file name to save the log. Defaults to None.
include_cpu (bool, optional): whether to detect tensor on cpu, default False.
module (Optional[:class:`nn.Module`]): when sending an ``nn.Module`` object,
the detector can name the tensors detected better.
"""
self.show_info = show_info
self.log = log
@@ -48,7 +46,6 @@ class TensorDetector():
self.tensor_info[id(param)].append(param.requires_grad)
self.tensor_info[id(param)].append(param.dtype)
self.tensor_info[id(param)].append(self.get_tensor_mem(param))
def get_tensor_mem(self, tensor):
# calculate the memory occupied by a tensor
@@ -58,7 +55,6 @@ class TensorDetector():
memory_size += grad_memory_size
return self.mem_format(memory_size)
def mem_format(self, real_memory_size):
# format the tensor memory into a reasonal magnitude
if real_memory_size >= 2 ** 30:
@@ -68,7 +64,6 @@ class TensorDetector():
if real_memory_size >= 2 ** 10:
return str(real_memory_size / (2 ** 10)) + ' KB'
return str(real_memory_size) + ' B'
def collect_tensors_state(self):
for obj in gc.get_objects():
@@ -116,7 +111,6 @@ class TensorDetector():
if obj.device not in self.devices:
self.devices.append(obj.device)
def print_tensors_state(self):
template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}'
self.info += LINE
@@ -173,7 +167,6 @@ class TensorDetector():
if self.log is not None:
with open(self.log + '.log', 'a') as f:
f.write(self.info)
def detect(self, include_cpu = False):
self.include_cpu = include_cpu