Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] add chunk init function for users (#1729)
* add chunk manager init function
* fix unit tests
* add comment
* add flush=True
colossalai/gemini/chunk/__init__.py

@@ -1,3 +1,4 @@
-from .chunk import TensorState, TensorInfo, ChunkFullError, Chunk
+from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState
 from .manager import ChunkManager
 from .search_utils import clasify_params, search_chunk_configuration
+from .utils import init_chunk_manager
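With this hunk, the new helper is re-exported from the package root, so user code can import it directly; a trivial illustration (not part of the diff itself):

from colossalai.gemini.chunk import init_chunk_manager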
colossalai/gemini/chunk/search_utils.py

@@ -1,100 +1,108 @@
 import math
-from typing import Dict, List
+from typing import Dict, List, Tuple
+
 import numpy as np
 import torch.nn as nn
+
 from colossalai.tensor import ColoParameter
 
 
+def in_ddp(param: nn.Parameter) -> bool:
+    return not getattr(param, '_ddp_to_ignore', False)
+
+
 def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
     """Filter those parameters whose size is too large from others.
     """
-    params_size = [p.numel() for p in model.parameters() if not getattr(p, '_ddp_to_ignore', False)]
+    params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
     params_size_arr = np.array(params_size)
 
     std = np.std(params_size_arr)
     mean = np.mean(params_size_arr)
     upper_limit = mean + 3 * std
 
     for key in size_dict:
         org_list = size_dict[key]
         size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))
 
 
 def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
     """Get unused byte for a certain chunk size.
     """
     acc = 0
     left = 0
     for s in size_list:
         if s > left:
             acc += left
             left = chunk_size
         left -= s
     return left + acc
 
 
 def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
+    """Clasify each parameter by its size of DP group.
+    """
     params_dict: Dict[int, List[ColoParameter]] = dict()
     for param in model.parameters():
         assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
-        if getattr(param, '_ddp_to_ignore', False):
+        if not in_ddp(param):
             continue
 
         param_key = param.process_group.dp_world_size()
 
         if param_key not in params_dict:
             params_dict[param_key] = []
         params_dict[param_key].append(param)
 
     return params_dict
 
 
 def search_chunk_configuration(
         model: nn.Module,
         search_range_mb: float,
         search_interval_byte: int,    # hidden size is the best value for the interval
         min_chunk_size_mb: float = 32,
-        filter_exlarge_params: bool = True) -> Dict:
+        filter_exlarge_params: bool = True) -> Tuple[Dict, int]:
     search_range_byte = round(search_range_mb * 1024**2)
     min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
     assert search_range_byte >= 0
 
     params_dict = clasify_params(model)
     config_dict: Dict[int, Dict] = dict()
 
     size_dict: Dict[int, List[int]] = dict()
     for key in params_dict:
         params_list = params_dict[key]
         size_list = [p.numel() for p in params_list]
         # let small parameters keep gathered in CUDA all the time
         total_size = sum(size_list)
         if total_size < min_chunk_size_byte:
             config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
         else:
             size_dict[key] = size_list
 
     if filter_exlarge_params:
         _filter_exlarge_params(model, size_dict)
 
     max_size = min_chunk_size_byte
     for key in size_dict:
         max_size = max(max_size, max(size_dict[key]))
     start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
 
     min_chunk_waste = float('+inf')
     best_chunk_size = start_size
 
     for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
         temp_waste = 0
         for key in size_dict:
             temp_waste += _get_unused_byte(size_dict[key], chunk_size)
         if temp_waste < min_chunk_waste:
             min_chunk_waste = temp_waste
             best_chunk_size = chunk_size
 
     for key in params_dict:
         if key in config_dict:
             continue
         config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)
 
-    return config_dict
+    return config_dict, min_chunk_waste
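search_chunk_configuration now returns both the per-group chunk config and the minimum waste it found. The quantity it minimizes is the number of bytes left unused when parameter tensors are packed greedily, in order, into fixed-size chunks. Below is a minimal, self-contained sketch of that waste metric; it mirrors _get_unused_byte above, and the sizes and chunk size are made up for illustration:

from typing import List


def get_unused_byte(size_list: List[int], chunk_size: int) -> int:
    acc = 0     # bytes wasted in chunks that are already closed
    left = 0    # bytes still free in the currently open chunk
    for s in size_list:
        if s > left:        # parameter does not fit: close this chunk, open a new one
            acc += left
            left = chunk_size
        left -= s
    return left + acc       # leftover space in the last chunk also counts as waste


# Three parameters of 6, 3 and 5 bytes with chunk_size=8:
# chunk 1 takes 6 (2 free), 3 does not fit -> 2 bytes wasted,
# chunk 2 takes 3 then 5 (0 free), so total waste is 2.
assert get_unused_byte([6, 3, 5], 8) == 2

Sweeping chunk_size across the search range and keeping the size with the smallest total waste is exactly what the loop in search_chunk_configuration does.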
colossalai/gemini/chunk/utils.py (new file, +58)

@@ -0,0 +1,58 @@
+from time import time
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from colossalai.gemini.chunk import ChunkManager
+from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
+
+
+def init_chunk_manager(model: nn.Module,
+                       init_device: Optional[torch.device] = None,
+                       hidden_dim: Optional[int] = None,
+                       search_range_mb: Optional[float] = None,
+                       min_chunk_size_mb: Optional[float] = None,
+                       filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
+
+    kwargs_dict = dict()
+
+    if hidden_dim:
+        search_interval_byte = hidden_dim
+    else:
+        search_interval_byte = 1024    # 1kb
+    kwargs_dict["search_interval_byte"] = search_interval_byte
+
+    if search_range_mb:
+        kwargs_dict["search_range_mb"] = search_range_mb
+
+    if min_chunk_size_mb:
+        kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb
+
+    if filter_exlarge_params:
+        kwargs_dict["filter_exlarge_params"] = filter_exlarge_params
+
+    params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)]
+    total_size = sum(params_sizes) / 1024**2
+
+    dist.barrier()
+    begine = time()
+
+    config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)
+
+    dist.barrier()
+    end = time()
+    span_s = end - begine
+    wasted_size /= 1024**2
+
+    if dist.get_rank() == 0:
+        print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
+              "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
+              "total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
+              sep='',
+              flush=True)
+    dist.barrier()
+
+    chunk_manager = ChunkManager(config_dict, init_device)
+    return chunk_manager
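For context, a hedged usage sketch of the new entry point. It assumes a distributed process group is already initialized and that `model` was built under ColoInitContext so its parameters are ColoParameter instances; those setup steps and the example dimensions are assumptions, not part of the commit:

import torch

from colossalai.gemini.chunk import init_chunk_manager

# `model` is assumed to be an nn.Module created under ColoInitContext.
# Note: search_range_mb should be supplied explicitly here, since
# init_chunk_manager only forwards it when given and
# search_chunk_configuration has no default for it.
chunk_manager = init_chunk_manager(model,
                                   init_device=torch.device('cuda'),
                                   hidden_dim=1024,      # search interval = model hidden size
                                   search_range_mb=32)   # sweep 32 MB above the start size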