diff --git a/colossalai/elixir/wrapper/module.py b/colossalai/elixir/wrapper/module.py
index cbc3a69c1..a6980e47f 100644
--- a/colossalai/elixir/wrapper/module.py
+++ b/colossalai/elixir/wrapper/module.py
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from copy import copy
 from functools import partial
-from typing import Any, Iterable, Mapping
+from typing import Any, Callable, Iterable, Mapping
 
 import torch
 import torch.distributed as dist
@@ -18,9 +18,19 @@
 from colossalai.elixir.tensor import OutplaceTensor
 from colossalai.utils.model.experimental import LazyTensor
 
-def is_leaf_module(m: nn.Module):
+def calc_module_buffer(m: nn.Module, fused_check_func: Callable) -> int:
     special_modules = [nn.MultiheadAttention]
-    return type(m) in special_modules
+    buffer_size = 0
+    if type(m) in special_modules:
+        for p in m.parameters():
+            if p.requires_grad:
+                buffer_size += p.numel()
+    else:
+        for p in m.parameters(recurse=False):
+            if p.requires_grad and not fused_check_func(p):
+                buffer_size += p.numel()
+
+    return buffer_size
 
 
 def get_param_optim_data(param_data: torch.Tensor, param_dtype: torch.dtype):
@@ -170,20 +180,18 @@ class ElixirModule(nn.Module):
     def __init_buffer_storage(self):
         buffer_size = 0
         for submodule in self.modules():
-            sum_param_size = 0
-            recurse_flag = is_leaf_module(submodule)
-            for param in submodule.parameters(recurse=recurse_flag):
-                if not param.requires_grad:
-                    continue
-                assert param.dtype == self.dtype
-                sum_param_size += param.numel()
-            buffer_size = max(buffer_size, sum_param_size)
+            sub_size = calc_module_buffer(submodule, self.fetcher.is_in_fused)
+            buffer_size = max(buffer_size, sub_size)
         self.buffer = BufferStore(buffer_size, self.dtype)
         print('module buffer', self.buffer)
 
     def _gradient_handler(self, grad: torch.Tensor, param: nn.Parameter):
         # create an empty tensor
-        fake_grad = self.buffer.empty_like(grad)
+        if param.numel() <= self.buffer.buffer_size:
+            fake_grad = self.buffer.empty_like(grad)
+        else:
+            fake_grad = torch.empty_like(grad)
+            fake_grad.storage().resize_(0)
 
         with torch._C.DisableTorchFunction():
             chunk = self.fetcher.get_one_chunk(param)
diff --git a/colossalai/elixir/wrapper/optimizer.py b/colossalai/elixir/wrapper/optimizer.py
index aba299611..d120dd414 100644
--- a/colossalai/elixir/wrapper/optimizer.py
+++ b/colossalai/elixir/wrapper/optimizer.py
@@ -70,6 +70,7 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
         if self.clipping_flag:
             assert norm_type == 2.0, 'ElixirOptimizer only supports L2 norm now'
 
+        self.max_fake_param_size = 0
         self.__init__optimizer()
 
         # Grad scaler
@@ -90,6 +91,7 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
         if init_step:
             # allocate memory before training
             self.__zero_step()
+            torch.cuda.empty_cache()
 
         if self.clipping_flag:
             for param_chunk in self.param_chunk_set:
@@ -98,10 +100,15 @@
 
     def __zero_step(self):
         torch.cuda.empty_cache()
-        cpu_buffer = BufferStore(self.module.buffer.buffer_size, self.module.buffer.buffer_dtype, 'cpu')
-        buffer_dict = dict(cpu=cpu_buffer, cuda=self.module.buffer)
-        for _, zero_buffer in buffer_dict.items():
-            zero_buffer.zeros()
+        compute_type = self.module.buffer.buffer_dtype
+        device_list = ['cpu', 'cuda']
+        buffer_dict = dict()
+
+        for device in device_list:
+            temp_buffer = BufferStore(self.max_fake_param_size, compute_type, device)
+            buffer_dict[device] = temp_buffer
+        for _, temp_buffer in buffer_dict.items():
+            temp_buffer.zeros()
 
         for group in self.param_groups:
             for fake_param in group['params']:
@@ -263,6 +270,7 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
                 fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
                 self.param_to_optim_chunk[fake_param] = param_chunk.paired_chunk
                 self.param_to_range[fake_param] = range_pair
+                self.max_fake_param_size = max(self.max_fake_param_size, range_pair[1] - range_pair[0])
 
                 fake_params_list.append(fake_param)
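
Reviewer note, not part of the patch: since `calc_module_buffer` now skips fused parameters when sizing the shared buffer, a gradient can be larger than the buffer, and `_gradient_handler` covers that case with a zero-storage placeholder. Below is a minimal standalone sketch of that fallback under stated assumptions; `buffer_size`, `flat_buffer` and `make_fake_grad` are invented names for illustration, and `BufferStore` is mimicked with a plain pre-allocated flat tensor.

```python
# Illustrative sketch only: mirrors the fallback added to ElixirModule._gradient_handler.
# `flat_buffer` stands in for the shared BufferStore; `buffer_size` plays the role of
# what calc_module_buffer would return for the largest eligible module.
import torch

buffer_size = 1024
flat_buffer = torch.empty(buffer_size)


def make_fake_grad(grad: torch.Tensor) -> torch.Tensor:
    if grad.numel() <= buffer_size:
        # gradients that fit reuse a view of the pre-allocated buffer (no new allocation)
        return flat_buffer[:grad.numel()].view_as(grad)
    # oversized gradients get a correctly-shaped tensor whose storage is released,
    # analogous to torch.empty_like(grad) followed by storage().resize_(0) in the patch
    fake = torch.empty_like(grad)
    fake.storage().resize_(0)
    return fake


small = make_fake_grad(torch.randn(16, 16))    # 256 elements -> view into the buffer
large = make_fake_grad(torch.randn(64, 64))    # 4096 elements -> zero-storage placeholder
print(small.shape, small.storage().size())     # torch.Size([16, 16]) 1024
print(large.shape, large.storage().size())     # torch.Size([64, 64]) 0
```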