diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index b8536a1d5..fbea2c2c0 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -10,7 +10,7 @@ from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
 from .memory import report_memory_usage
 from .timer import MultiTimer, Timer
-#from .tensor_detector import TensorDetector
+from .tensor_detector import TensorDetector

 __all__ = [
     'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
diff --git a/colossalai/utils/tensor_detector/__init__.py b/colossalai/utils/tensor_detector/__init__.py
new file mode 100644
index 000000000..0d35e6467
--- /dev/null
+++ b/colossalai/utils/tensor_detector/__init__.py
@@ -0,0 +1 @@
+from .tensor_detector import TensorDetector
\ No newline at end of file
diff --git a/colossalai/utils/tensor_detector/readme.md b/colossalai/utils/tensor_detector/readme.md
new file mode 100644
index 000000000..840dc8f4e
--- /dev/null
+++ b/colossalai/utils/tensor_detector/readme.md
@@ -0,0 +1,128 @@
+# Tensor Detector
+
+This tool detects tensors on both CPU and GPU. Note that it will always pick up a few unexpected tensors on the CPU, such as PyTorch's RNG state.
+
+## Example
+
+An example is worth a thousand words.
+
+The code below defines a simple MLP module that we will use to demonstrate the tool.
+
+```python
+import torch
+import torch.nn as nn
+
+class MLP(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.mlp = nn.Sequential(nn.Linear(64, 8),
+                                 nn.ReLU(),
+                                 nn.Linear(8, 32))
+
+    def forward(self, x):
+        return self.mlp(x)
+```
+
+And here is how to use the tool.
+
+```python
+from colossalai.utils import TensorDetector
+
+# create random data
+data = torch.rand(64, requires_grad=True).cuda()
+data.retain_grad()
+# create the module
+model = MLP().cuda()
+# create the detector
+# passing the model lets the detector distinguish module parameters from ordinary tensors
+detector = TensorDetector(include_cpu=False, module=model)
+detector.detect()
+
+out = model(data)
+
+detector.detect()
+
+loss = out.sum()
+loss.backward()
+
+detector.detect()
+```
+
+Comments on the right of the output below explain what each entry is.
+
+Note that the total `Mem` of all tensors and parameters does not equal `Total GPU Memory Allocated`. PyTorch's memory management is complicated, because its caching allocator reserves more memory than the live tensors occupy, so for large models an exact accounting is impractical.
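+
+If you want to cross-check the detector's numbers against PyTorch's own counters, here is a minimal sketch (the exact figures depend on your PyTorch version and GPU; `memory_reserved` also counts the caching allocator's spare blocks):
+
+```python
+import torch
+
+# bytes currently occupied by live tensors on the default CUDA device
+print(torch.cuda.memory_allocated())
+# bytes reserved by PyTorch's caching allocator, usually larger
+print(torch.cuda.memory_reserved())
+```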
+
+**The printing order does not exactly match the order in which the tensors were created, but it is close.**
+
+```bash
+------------------------------------------------------------------------------------------------------------
+   Tensor                            device               shape      grad               dtype            Mem
+------------------------------------------------------------------------------------------------------------
++  Tensor                            cuda:0               (64,)      True       torch.float32          256 B    # data
++  mlp.0.weight                      cuda:0             (8, 64)      True       torch.float32         2.0 KB
++  mlp.0.bias                        cuda:0                (8,)      True       torch.float32           32 B
++  mlp.2.weight                      cuda:0             (32, 8)      True       torch.float32         1.0 KB
++  mlp.2.bias                        cuda:0               (32,)      True       torch.float32          128 B
+------------------------------------------------------------------------------------------------------------
+Detect Location: "test_tensor_detector.py" line 27
+Total GPU Memory Allocated on cuda:0 is 4.5 KB
+------------------------------------------------------------------------------------------------------------
+
+
+------------------------------------------------------------------------------------------------------------
+   Tensor                            device               shape      grad               dtype            Mem
+------------------------------------------------------------------------------------------------------------
++  Tensor                            cuda:0                (8,)      True       torch.float32           32 B    # activation
++  Tensor                            cuda:0               (32,)      True       torch.float32          128 B    # output
+------------------------------------------------------------------------------------------------------------
+Detect Location: "test_tensor_detector.py" line 30
+Total GPU Memory Allocated on cuda:0 is 5.5 KB
+------------------------------------------------------------------------------------------------------------
+
+
+------------------------------------------------------------------------------------------------------------
+   Tensor                            device               shape      grad               dtype            Mem
+------------------------------------------------------------------------------------------------------------
++  Tensor                            cuda:0                  ()      True       torch.float32            4 B    # loss
+------------------------------------------------------------------------------------------------------------
+Detect Location: "test_tensor_detector.py" line 32
+Total GPU Memory Allocated on cuda:0 is 6.0 KB
+------------------------------------------------------------------------------------------------------------
+
+
+------------------------------------------------------------------------------------------------------------
+   Tensor                            device               shape      grad               dtype            Mem
+------------------------------------------------------------------------------------------------------------
++  Tensor (with grad)                cuda:0               (64,)      True       torch.float32          512 B    # data, grad kept by data.retain_grad()
++  mlp.0.weight (with grad)          cuda:0             (8, 64)      True       torch.float32         4.0 KB    # Mem now includes the .grad tensor
++  mlp.0.bias (with grad)            cuda:0                (8,)      True       torch.float32           64 B
++  mlp.2.weight (with grad)          cuda:0             (32, 8)      True       torch.float32         2.0 KB
++  mlp.2.bias (with grad)            cuda:0               (32,)      True       torch.float32          256 B
+
+-  mlp.0.weight                      cuda:0             (8, 64)      True       torch.float32         2.0 KB
+-  mlp.0.bias                        cuda:0                (8,)      True       torch.float32           32 B
+-  mlp.2.weight                      cuda:0             (32, 8)      True       torch.float32         1.0 KB
+-  mlp.2.bias                        cuda:0               (32,)      True       torch.float32          128 B
+-  Tensor                            cuda:0               (64,)      True       torch.float32          256 B
+-  Tensor                            cuda:0                (8,)      True       torch.float32           32 B    # freed activation
+------------------------------------------------------------------------------------------------------------
+Detect Location: "test_tensor_detector.py" line 34
+Total GPU Memory Allocated on cuda:0 is 10.0 KB
+------------------------------------------------------------------------------------------------------------
+
+
+------------------------------------------------------------------------------------------------------------
+   Tensor                            device               shape      grad               dtype            Mem
+------------------------------------------------------------------------------------------------------------
++  Tensor                            cuda:0               (64,)     False       torch.float32          256 B
++  Tensor                            cuda:0             (8, 64)     False       torch.float32         2.0 KB
++  Tensor                            cuda:0                (8,)     False       torch.float32           32 B
++  Tensor                            cuda:0             (32, 8)     False       torch.float32         1.0 KB
++  Tensor                            cuda:0               (32,)     False       torch.float32          128 B
+------------------------------------------------------------------------------------------------------------
+Detect Location: "test_tensor_detector.py" line 36
+Total GPU Memory Allocated on cuda:0 is 14.0 KB
+------------------------------------------------------------------------------------------------------------
+```
+
+## Reference
+
+This tool was inspired by https://github.com/Stonesjtu/pytorch_memlab/blob/master/pytorch_memlab/mem_reporter.py and https://github.com/Oldpan/Pytorch-Memory-Utils
+
diff --git a/colossalai/utils/tensor_detector/tensor_detector.py b/colossalai/utils/tensor_detector/tensor_detector.py
new file mode 100644
index 000000000..56b7e1d7f
--- /dev/null
+++ b/colossalai/utils/tensor_detector/tensor_detector.py
@@ -0,0 +1,190 @@
+import gc
+import inspect
+import torch
+import torch.nn as nn
+from typing import Optional
+from collections import defaultdict
+
+
+LINE_WIDTH = 108
+LINE = '-' * LINE_WIDTH + '\n'
+
+
+class TensorDetector:
+
+    def __init__(self,
+                 show_info: bool = True,
+                 log: Optional[str] = None,
+                 include_cpu: bool = False,
+                 module: Optional[nn.Module] = None
+                 ):
+        """A detector that reports the tensors resident on each device.
+
+        :param show_info: whether to print the info to stdout, defaults to True
+        :type show_info: bool
+        :param log: the file name to save the log to; no log file is written if None
+        :type log: Optional[str]
+        :param include_cpu: whether to detect tensors on the CPU, defaults to False
+        :type include_cpu: bool
+        :param module: if an `nn.Module` is passed, the detector can label its parameters with their real names
+        :type module: Optional[nn.Module]
+        """
+        self.show_info = show_info
+        self.log = log
+        self.include_cpu = include_cpu
+        self.tensor_info = defaultdict(list)
+        self.saved_tensor_info = defaultdict(list)
+        self.order = []
+        self.detected = []
+        self.devices = []
+        self.info = ""
+
+        self.module = module
+        if isinstance(module, nn.Module):
+            # if module is an instance of nn.Module, record each parameter under its real name
+            for name, param in module.named_parameters():
+                self.tensor_info[id(param)].append(name)
+                self.tensor_info[id(param)].append(param.device)
+                self.tensor_info[id(param)].append(param.shape)
+                self.tensor_info[id(param)].append(param.requires_grad)
+                self.tensor_info[id(param)].append(param.dtype)
+                self.tensor_info[id(param)].append(self.get_tensor_mem(param))
+
+    def get_tensor_mem(self, tensor):
+        # calculate the memory occupied by a tensor, including its gradient if one is attached
+        memory_size = tensor.element_size() * tensor.storage().size()
+        if (tensor.is_leaf or tensor.retains_grad) and tensor.grad is not None:
+            grad_memory_size = tensor.grad.element_size() * tensor.grad.storage().size()
+            memory_size += grad_memory_size
+        return self.mem_format(memory_size)
+
+    def mem_format(self, real_memory_size):
+        # format the memory size with a unit of reasonable magnitude
+        if real_memory_size >= 2 ** 30:
+            return str(real_memory_size / (2 ** 30)) + ' GB'
+        if real_memory_size >= 2 ** 20:
+            return str(real_memory_size / (2 ** 20)) + ' MB'
+        if real_memory_size >= 2 ** 10:
+            return str(real_memory_size / (2 ** 10)) + ' KB'
+        return str(real_memory_size) + ' B'
+
+    def collect_tensors_state(self):
+        for obj in gc.get_objects():
+            if torch.is_tensor(obj):
+                # skip CPU tensors when include_cpu is False
+                if (not self.include_cpu) and obj.device == torch.device('cpu'):
+                    continue
+                self.detected.append(id(obj))
+                # skip parameters we already recorded in __init__ (when module is an nn.Module)
+                if id(obj) not in self.tensor_info:
+                    name = type(obj).__name__
+                    # after backward, update the records to show the change
+                    if isinstance(self.module, nn.Module) and name == 'Parameter':
+                        if obj.grad is not None:
+                            # a grad is attached: recover the parameter's real name
+                            for par_name, param in self.module.named_parameters():
+                                if param.requires_grad and param.grad is not None \
+                                        and param.grad.equal(obj.grad):
+                                    name = par_name + ' (with grad)'
+                        else:
+                            # no grad attached; no new parameters are created while running,
+                            # so this one must already be in saved_tensor_info
+                            continue
+                    # ordinary tensors can also be marked as 'Tensor (with grad)';
+                    # they only carry a grad if retain_grad() was called on them
+                    if name == 'Tensor' and (obj.is_leaf or obj.retains_grad):
+                        if obj.grad is not None:
+                            name = name + ' (with grad)'
+                    # skip tensors that are unchanged since the last detect()
+                    if id(obj) in self.saved_tensor_info.keys() and name == self.saved_tensor_info[id(obj)][0]:
+                        continue
+
+                    self.tensor_info[id(obj)].append(name)
+                    self.tensor_info[id(obj)].append(obj.device)
+                    self.tensor_info[id(obj)].append(obj.shape)
+                    self.tensor_info[id(obj)].append(obj.requires_grad)
+                    self.tensor_info[id(obj)].append(obj.dtype)
+                    self.tensor_info[id(obj)].append(self.get_tensor_mem(obj))
+                    # record the order in which we encountered the tensors;
+                    # this makes it easy to tell which tensor is which,
+                    # and it covers every tensor updated in this round
+                    self.order.append(id(obj))
+                # record every distinct device we have seen
+                if obj.device not in self.devices:
+                    self.devices.append(obj.device)
+
+    def print_tensors_state(self):
+        template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}'
+        self.info += LINE
+        self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem')
+        self.info += '\n'
+        self.info += LINE
+
+        # if a tensor was updated this round and had been recorded before,
+        # its old record in saved_tensor_info is outdated and must be retired
+        outdated = [x for x in self.saved_tensor_info.keys() if x in self.order]
+        minus = [x for x in self.saved_tensor_info.keys() if x not in self.detected]
+        minus = outdated + minus
+        if len(self.order) > 0:
+            for tensor_id in self.order:
+                self.info += template_format.format('+',
+                                                    str(self.tensor_info[tensor_id][0]),
+                                                    str(self.tensor_info[tensor_id][1]),
+                                                    str(tuple(self.tensor_info[tensor_id][2])),
+                                                    str(self.tensor_info[tensor_id][3]),
+                                                    str(self.tensor_info[tensor_id][4]),
+                                                    str(self.tensor_info[tensor_id][5]))
+                self.info += '\n'
+        if len(self.order) > 0 and len(minus) > 0:
+            self.info += '\n'
+        if len(minus) > 0:
+            for tensor_id in minus:
+                self.info += template_format.format('-',
+                                                    str(self.saved_tensor_info[tensor_id][0]),
+                                                    str(self.saved_tensor_info[tensor_id][1]),
+                                                    str(tuple(self.saved_tensor_info[tensor_id][2])),
+                                                    str(self.saved_tensor_info[tensor_id][3]),
+                                                    str(self.saved_tensor_info[tensor_id][4]),
+                                                    str(self.saved_tensor_info[tensor_id][5]))
+                self.info += '\n'
+                # drop the stale record
+                self.saved_tensor_info.pop(tensor_id)
+
+        # find out where detect() was called from
+        locate_info = inspect.stack()[2]
+        locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno)
+
+        self.info += LINE
+        self.info += f"Detect Location: {locate_msg}\n"
+        for device in self.devices:
+            if device == torch.device('cpu'):
+                continue
+            gpu_mem_alloc = self.mem_format(torch.cuda.memory_allocated(device))
+            self.info += f"Total GPU Memory Allocated on {device} is {gpu_mem_alloc}\n"
+        self.info += LINE
+        self.info += '\n\n'
+        if self.show_info:
+            print(self.info)
+        if self.log is not None:
+            with open(self.log + '.log', 'a') as f:
+                f.write(self.info)
+
+    def detect(self, include_cpu: bool = False):
+        self.include_cpu = include_cpu
+        self.collect_tensors_state()
+        self.print_tensors_state()
+        # keep this round's records so the next detect() can report changes
+        self.saved_tensor_info.update(self.tensor_info)
+        self.tensor_info.clear()
+        self.order = []
+        self.detected = []
+        self.info = ""
+
+    def close(self):
+        self.saved_tensor_info.clear()
+        self.module = None
\ No newline at end of file
diff --git a/tests/test_utils/test_tensor_detector.py b/tests/test_utils/test_tensor_detector.py
new file mode 100644
index 000000000..06b1c1846
--- /dev/null
+++ b/tests/test_utils/test_tensor_detector.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+
+from colossalai.utils import TensorDetector
+
+
+class MLP(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.mlp = nn.Sequential(nn.Linear(64, 8),
+                                 nn.ReLU(),
+                                 nn.Linear(8, 32))
+
+    def forward(self, x):
+        return self.mlp(x)
+
+
+def test_tensor_detect():
+    data = torch.rand(64, requires_grad=True).cuda()
+    data.retain_grad()
+    model = MLP().cuda()
+
+    detector = TensorDetector(log='test', include_cpu=False, module=model)
+
+    detector.detect()
+    out = model(data)
+
+    detector.detect()
+    loss = out.sum()
+    detector.detect()
+    loss.backward()
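+    # after backward(), the parameters carry .grad, so this detect() reports
+    # them again as '(with grad)' entries and retires their old records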
+    detector.detect()
+    # recreate the model so the next detect() reports the newly created tensors
+    model = MLP().cuda()
+    detector.detect()
+    detector.close()
+    torch.cuda.empty_cache()
+
+
+if __name__ == '__main__':
+    test_tensor_detect()
\ No newline at end of file