Refactored docstring to Google style
@@ -9,6 +9,7 @@ from collections import defaultdict
 LINE_WIDTH = 108
 LINE = '-' * LINE_WIDTH + '\n'
 
+
 class TensorDetector():
     def __init__(self,
                  show_info: bool = True,
@@ -16,17 +17,14 @@ class TensorDetector():
                  include_cpu: bool = False,
                  module: Optional[nn.Module] = None
                  ):
-        """This class is an detector to detect tensor on different devices.
-
-        :param show_info: whether to print the info on screen, default True
-        :type show_info: bool
-        :param log: the file name to save the log
-        :type log: str
-        :param include_cpu: whether to detect tensor on cpu, default False
-        :type include_cpu: bool
-        :param module: when sending an `nn.Module` it, the detector can name the tensors detected better
-        :type module: Optional[nn.Module]
+        """This class is a detector to detect tensor on different devices.
+
+        Args:
+            show_info (bool, optional): whether to print the info on screen, default True.
+            log (str, optional): the file name to save the log. Defaults to None.
+            include_cpu (bool, optional): whether to detect tensor on cpu, default False.
+            module (Optional[:class:`nn.Module`]): when sending an ``nn.Module`` object,
+                the detector can name the tensors detected better.
         """
         self.show_info = show_info
         self.log = log
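For reference, a minimal usage sketch of the parameters described in the new docstring. The import path `colossalai.utils.TensorDetector`, the availability of a CUDA device, and the toy model are assumptions; `detect()` is the entry point that appears later in this file.

```python
import torch
import torch.nn as nn
from colossalai.utils import TensorDetector   # import path assumed

model = nn.Linear(1024, 1024).cuda()           # any nn.Module works

# log='tensor_log' appends the report to 'tensor_log.log';
# passing module=model lets the detector label parameter tensors by name
detector = TensorDetector(show_info=True, log='tensor_log',
                          include_cpu=False, module=model)
detector.detect()                              # snapshot of live GPU tensors

out = model(torch.randn(4, 1024, device='cuda'))
detector.detect()                              # shows tensors created by the forward pass
```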
@@ -48,7 +46,6 @@ class TensorDetector():
                 self.tensor_info[id(param)].append(param.requires_grad)
                 self.tensor_info[id(param)].append(param.dtype)
                 self.tensor_info[id(param)].append(self.get_tensor_mem(param))
-
 
     def get_tensor_mem(self, tensor):
         # calculate the memory occupied by a tensor
@@ -58,7 +55,6 @@ class TensorDetector():
             memory_size += grad_memory_size
         return self.mem_format(memory_size)
-
 
     def mem_format(self, real_memory_size):
         # format the tensor memory into a reasonable magnitude
         if real_memory_size >= 2 ** 30:
@@ -68,7 +64,6 @@ class TensorDetector():
         if real_memory_size >= 2 ** 10:
             return str(real_memory_size / (2 ** 10)) + ' KB'
         return str(real_memory_size) + ' B'
-
 
     def collect_tensors_state(self):
         for obj in gc.get_objects():
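As a side note, the two helpers above amount to: bytes = element size × storage size (plus the gradient's buffer when one exists), then rounded to the nearest binary unit. A standalone sketch of that accounting follows; the function name is illustrative, and the gradient check is a simplified stand-in for the guard used in `get_tensor_mem`.

```python
import torch

def tensor_bytes(t: torch.Tensor) -> int:
    # data buffer, plus the gradient's buffer if one has been materialized
    size = t.element_size() * t.storage().size()
    if t.grad is not None:
        size += t.grad.element_size() * t.grad.storage().size()
    return size

x = torch.zeros(256, 256)              # float32: 256 * 256 * 4 bytes
print(tensor_bytes(x))                 # 262144
print(tensor_bytes(x) / 2 ** 10)       # 256.0 -> mem_format would report '256.0 KB'
```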
@@ -116,7 +111,6 @@ class TensorDetector():
                 if obj.device not in self.devices:
                     self.devices.append(obj.device)
-
 
     def print_tensors_state(self):
         template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}'
         self.info += LINE
@@ -173,7 +167,6 @@ class TensorDetector():
         if self.log is not None:
             with open(self.log + '.log', 'a') as f:
                 f.write(self.info)
-
 
     def detect(self, include_cpu = False):
         self.include_cpu = include_cpu
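Finally, the `detect()` / `collect_tensors_state()` pass is built on walking the garbage collector's tracked objects. A minimal, standalone sketch of that idea is below; the function name and output format are illustrative, not part of the class.

```python
import gc
import torch

def live_tensors(include_cpu: bool = False):
    # walk every object the garbage collector knows about and keep the tensors,
    # mirroring the gc.get_objects() loop in collect_tensors_state
    found = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                if not include_cpu and obj.device == torch.device('cpu'):
                    continue   # skip CPU tensors unless explicitly requested
                found.append((tuple(obj.shape), obj.dtype, str(obj.device)))
        except Exception:
            continue           # some gc-tracked objects raise on attribute access
    return found

for shape, dtype, device in live_tensors(include_cpu=True):
    print(shape, dtype, device)
```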