mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-01 11:55:23 +00:00
* [gemini] remove distributed-related part from colotensor (#4379) * [gemini] remove process group dependency * [gemini] remove tp part from colo tensor * [gemini] patch inplace op * [gemini] fix param op hook and update tests * [test] remove useless tests * [test] remove useless tests * [misc] fix requirements * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [misc] update requirements * [gemini] refactor gemini optimizer and gemini ddp (#4398) * [gemini] update optimizer interface * [gemini] renaming gemini optimizer * [gemini] refactor gemini ddp class * [example] update gemini related example * [example] update gemini related example * [plugin] fix gemini plugin args * [test] update gemini ckpt tests * [gemini] fix checkpoint io * [example] fix opt example requirements * [example] fix opt example * [example] fix opt example * [example] fix opt example * [gemini] add static placement policy (#4443) * [gemini] add static placement policy * [gemini] fix param offload * [test] update gemini tests * [plugin] update gemini plugin * [plugin] update gemini plugin docstr * [misc] fix flash attn requirement * [test] fix gemini checkpoint io test * [example] update resnet example result (#4457) * [example] update bert example result (#4458) * [doc] update gemini doc (#4468) * [example] update gemini related examples (#4473) * [example] update gpt example * [example] update dreambooth example * [example] update vit * [example] update opt * [example] update palm * [example] update vit and opt benchmark * [hotfix] fix bert in model zoo (#4480) * [hotfix] fix bert in model zoo * [test] remove chatglm gemini test * [test] remove sam gemini test * [test] remove vit gemini test * [hotfix] fix opt tutorial example (#4497) * [hotfix] fix opt tutorial example * [hotfix] fix opt tutorial example
190 lines
6.2 KiB
Python
190 lines
6.2 KiB
Python
import math
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
from torch.distributed import ProcessGroup
|
|
|
|
from colossalai.tensor import ColoParameter
|
|
from colossalai.utils import is_ddp_ignored
|
|
from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator
|
|
|
|
|
|
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
|
|
"""_filter_exlarge_params
|
|
|
|
Filter those parameters whose size is too large (more than 3x standard deviations) from others.
|
|
|
|
Args:
|
|
model (nn.Module): the model.
|
|
size_dict (Dict[int, List[int]]): the size dict of parameters.
|
|
"""
|
|
agg_size_list = []
|
|
for key in size_dict:
|
|
agg_size_list.extend(size_dict[key])
|
|
|
|
if len(agg_size_list) == 0:
|
|
return
|
|
|
|
params_size_arr = np.array(agg_size_list)
|
|
|
|
std = np.std(params_size_arr)
|
|
mean = np.mean(params_size_arr)
|
|
upper_limit = mean + 3 * std
|
|
|
|
for key in size_dict:
|
|
org_list = size_dict[key]
|
|
size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))
|
|
|
|
|
|
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
|
|
"""_get_unused_byte
|
|
|
|
Get unused byte for a certain chunk size.
|
|
|
|
Args:
|
|
size_list (List[int]): the size list of parameters.
|
|
chunk_size (int): the chunk size.
|
|
|
|
Returns:
|
|
int: the unused byte.
|
|
"""
|
|
acc = 0
|
|
left = 0
|
|
for s in size_list:
|
|
if s > left:
|
|
acc += left
|
|
left = chunk_size
|
|
left -= s
|
|
return left + acc
|
|
|
|
|
|
def _tensor_numel(local_param: ColoParameter) -> int:
    """Return the number of elements of a parameter tensor.

    Args:
        local_param (ColoParameter): the local parameter.

    Returns:
        int: the number of elements.
    """
    # TODO(ver217): support dtensor here
    return local_param.numel()
|
def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
                                 process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]:
    """classify_params_by_dp_degree

    Classify the parameters by their dp degree.

    Args:
        param_order (OrderedParamGenerator): the order in which parameters are visited.
        process_group (ProcessGroup): the process group whose world size is used
            as the dp degree of every non-ignored parameter.

    Returns:
        Dict[int, List[ColoParameter]]: a dict contains the classification results.
            The keys are dp_degrees and the values are parameters.
    """
    params_dict: Dict[int, List[ColoParameter]] = dict()
    for param in param_order.generate():
        # Parameters explicitly excluded from DDP take no part in chunking.
        if is_ddp_ignored(param):
            continue
        param_key = dist.get_world_size(process_group)
        params_dict.setdefault(param_key, []).append(param)

    return params_dict
|
def search_chunk_configuration(
        model: nn.Module,
        search_range_m: float,
        search_interval: int,    # hidden size is the best value for the interval
        min_chunk_size_m: float = 32,
        filter_exlarge_params: bool = True,
        strict_ddp_flag: bool = False,
        process_group: Optional[ProcessGroup] = None,
        memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
    """search_chunk_configuration

    Search the chunk configuration for a model.

    Args:
        model (nn.Module): torch module
        search_range_m (float): searching range divided by 2^20.
        search_interval (int): searching interval.
        min_chunk_size_m (float, optional): the minimum size of a distributed chunk, divided by 2^20.
            Defaults to 32.
        filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
        strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
            all parameters keep replicated in this mode. Defaults to False.
        process_group (ProcessGroup, optional): process group used to classify parameters by
            dp degree. Defaults to None (the default group).
        memstas (MemStats, optional): pre-collected memory statistics providing the parameter
            visiting order. Defaults to None, in which case the order is built from
            ``model.parameters()``. (Parameter name kept as-is for backward compatibility.)

    Returns:
        Tuple[Dict, int, int]: chunk config (a dict of dp_degree -> chunk init args),
            the total parameter size in elements, and the minimum chunk waste in bytes.
    """

    if memstas is not None:
        param_order = memstas.param_order()
    else:
        # build the param visited order right now
        param_order = OrderedParamGenerator()
        for p in model.parameters():
            param_order.append(p)

    search_range = round(search_range_m * 1024**2)
    min_chunk_size = round(min_chunk_size_m * 1024**2)
    assert search_range >= 0

    params_dict = classify_params_by_dp_degree(param_order, process_group)
    # LCM of all dp degrees; fall back to 1 for a model with no trainable
    # parameters so np.lcm.reduce is never called on an empty list.
    size_lcm = np.lcm.reduce(list(params_dict.keys())) if params_dict else 1
    config_dict: Dict[int, Dict] = dict()
    total_param_size = 0

    size_dict: Dict[int, List[int]] = dict()
    for dp_degree in params_dict:
        params_list = params_dict[dp_degree]
        size_list = [_tensor_numel(p) for p in params_list]
        group_acc_size = sum(size_list)
        total_param_size += group_acc_size

        # let small parameters keep gathered in CUDA all the time
        if group_acc_size < min_chunk_size:
            config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True)
        else:
            size_dict[dp_degree] = size_list

    if filter_exlarge_params:
        _filter_exlarge_params(model, size_dict)

    # The search starts at the largest surviving parameter (rounded up to the
    # interval) so that every parameter fits in a single chunk.
    max_size = min_chunk_size
    for key in size_dict:
        # guard against groups emptied by the outlier filter
        if size_dict[key]:
            max_size = max(max_size, max(size_dict[key]))
    start_size = int(math.ceil(max_size / search_interval) * search_interval)

    min_chunk_waste = float('+inf')
    best_chunk_size = start_size

    for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
        temp_waste = 0
        for key in size_dict:
            temp_waste += _get_unused_byte(size_dict[key], chunk_size)
        if temp_waste < min_chunk_waste:
            min_chunk_waste = temp_waste
            best_chunk_size = chunk_size

    # the chunk size needs to be divided by each group's dp degree (their LCM)
    best_chunk_size = best_chunk_size + (-best_chunk_size % size_lcm)
    for dp_degree in params_dict:
        if dp_degree in config_dict:
            continue
        config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False)

    return config_dict, total_param_size, min_chunk_waste