Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-24 14:33:20 +00:00)
update buffer size calculation (#3871)
commit 1ee247a51c (parent dbb9659099)
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from copy import copy
 from functools import partial
-from typing import Any, Iterable, Mapping
+from typing import Any, Callable, Iterable, Mapping
 
 import torch
 import torch.distributed as dist
@@ -18,9 +18,19 @@ from colossalai.elixir.tensor import OutplaceTensor
 from colossalai.utils.model.experimental import LazyTensor
 
 
-def is_leaf_module(m: nn.Module):
+def calc_module_buffer(m: nn.Module, fused_check_func: Callable) -> int:
     special_modules = [nn.MultiheadAttention]
-    return type(m) in special_modules
+    buffer_size = 0
+    if type(m) in special_modules:
+        for p in m.parameters():
+            if p.requires_grad:
+                buffer_size += p.numel()
+    else:
+        for p in m.parameters(recurse=False):
+            if p.requires_grad and not fused_check_func(p):
+                buffer_size += p.numel()
+
+    return buffer_size
 
 
 def get_param_optim_data(param_data: torch.Tensor, param_dtype: torch.dtype):
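
For reference, a minimal standalone sketch of the new sizing rule, mirroring the hunk above; the lambda is a dummy stand-in for the fetcher's fused check (self.fetcher.is_in_fused), which is not part of this diff:

    import torch.nn as nn

    def calc_module_buffer(m, fused_check_func) -> int:
        # special modules (e.g. nn.MultiheadAttention) are sized over all of their
        # trainable parameters, recursively
        special_modules = [nn.MultiheadAttention]
        buffer_size = 0
        if type(m) in special_modules:
            for p in m.parameters():
                if p.requires_grad:
                    buffer_size += p.numel()
        else:
            # ordinary modules only contribute their direct (non-recursive),
            # non-fused trainable parameters
            for p in m.parameters(recurse=False):
                if p.requires_grad and not fused_check_func(p):
                    buffer_size += p.numel()
        return buffer_size

    no_fused = lambda p: False    # dummy check: treat nothing as fused
    print(calc_module_buffer(nn.Linear(4, 4), no_fused))              # 20 (weight 16 + bias 4)
    print(calc_module_buffer(nn.MultiheadAttention(8, 2), no_fused))  # 288, counted recursively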
@@ -170,20 +180,18 @@ class ElixirModule(nn.Module):
     def __init_buffer_storage(self):
         buffer_size = 0
         for submodule in self.modules():
-            sum_param_size = 0
-            recurse_flag = is_leaf_module(submodule)
-            for param in submodule.parameters(recurse=recurse_flag):
-                if not param.requires_grad:
-                    continue
-                assert param.dtype == self.dtype
-                sum_param_size += param.numel()
-            buffer_size = max(buffer_size, sum_param_size)
+            sub_size = calc_module_buffer(submodule, self.fetcher.is_in_fused)
+            buffer_size = max(buffer_size, sub_size)
         self.buffer = BufferStore(buffer_size, self.dtype)
         print('module buffer', self.buffer)
 
     def _gradient_handler(self, grad: torch.Tensor, param: nn.Parameter):
         # create an empty tensor
-        fake_grad = self.buffer.empty_like(grad)
+        if param.numel() <= self.buffer.buffer_size:
+            fake_grad = self.buffer.empty_like(grad)
+        else:
+            fake_grad = torch.empty_like(grad)
+            fake_grad.storage().resize_(0)
 
         with torch._C.DisableTorchFunction():
             chunk = self.fetcher.get_one_chunk(param)
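
The new fallback branch in _gradient_handler covers the case where a single parameter is larger than the shared buffer: it hands back a tensor with the right shape and dtype whose backing storage has been released. A small sketch of that trick in isolation:

    import torch

    grad = torch.randn(1024, 1024)
    fake_grad = torch.empty_like(grad)   # same shape/dtype/device as the real gradient
    fake_grad.storage().resize_(0)       # release the backing storage
    print(fake_grad.shape)               # torch.Size([1024, 1024]) -- metadata is kept
    print(fake_grad.storage().size())    # 0 -- no elements remain allocated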
@@ -70,6 +70,7 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
         if self.clipping_flag:
             assert norm_type == 2.0, 'ElixirOptimizer only supports L2 norm now'
 
+        self.max_fake_param_size = 0
         self.__init__optimizer()
 
         # Grad scaler
@@ -90,6 +91,7 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
         if init_step:
             # allocate memory before training
             self.__zero_step()
+            torch.cuda.empty_cache()
 
         if self.clipping_flag:
             for param_chunk in self.param_chunk_set:
@@ -98,10 +100,15 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
     def __zero_step(self):
         torch.cuda.empty_cache()
 
-        cpu_buffer = BufferStore(self.module.buffer.buffer_size, self.module.buffer.buffer_dtype, 'cpu')
-        buffer_dict = dict(cpu=cpu_buffer, cuda=self.module.buffer)
-        for _, zero_buffer in buffer_dict.items():
-            zero_buffer.zeros()
+        compute_type = self.module.buffer.buffer_dtype
+        device_list = ['cpu', 'cuda']
+        buffer_dict = dict()
+
+        for device in device_list:
+            temp_buffer = BufferStore(self.max_fake_param_size, compute_type, device)
+            buffer_dict[device] = temp_buffer
+        for _, temp_buffer in buffer_dict.items():
+            temp_buffer.zeros()
 
         for group in self.param_groups:
             for fake_param in group['params']:
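
The effect of the __zero_step change is that the warm-up no longer allocates a CPU copy of the full module buffer; each device now gets a scratch buffer sized only to the largest fake parameter. A rough sketch of the intended allocation, assuming BufferStore behaves roughly like a flat tensor of the given size, dtype and device:

    import torch

    max_fake_param_size = 4096        # tracked in __init__optimizer from the range pairs
    compute_type = torch.float16      # assumed module-buffer dtype for the example

    buffers = {}
    for device in ['cpu', 'cuda']:
        if device == 'cuda' and not torch.cuda.is_available():
            continue                  # skip CUDA on CPU-only machines
        # one small scratch buffer per device instead of a full-size CPU mirror
        buffers[device] = torch.zeros(max_fake_param_size, dtype=compute_type, device=device)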
@@ -263,6 +270,7 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
                 fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
                 self.param_to_optim_chunk[fake_param] = param_chunk.paired_chunk
                 self.param_to_range[fake_param] = range_pair
+                self.max_fake_param_size = max(self.max_fake_param_size, range_pair[1] - range_pair[0])
 
                 fake_params_list.append(fake_param)
 
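
The bookkeeping added in the last hunk is what feeds __zero_step: each fake parameter appears to cover a range of its chunk, so the scratch buffers only need to hold the widest such range. A tiny illustration with made-up range pairs:

    # hypothetical (begin, end) offsets; the real values come from get_range_pair
    max_fake_param_size = 0
    for range_pair in [(0, 1024), (1024, 5120), (5120, 6144)]:
        max_fake_param_size = max(max_fake_param_size, range_pair[1] - range_pair[0])
    print(max_fake_param_size)   # 4096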