Mirror of https://github.com/hpcaitech/ColossalAI.git
[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379)
* [gemini] remove process group dependency
* [gemini] remove tp part from colo tensor
* [gemini] patch inplace op
* [gemini] fix param op hook and update tests
* [test] remove useless tests
* [test] remove useless tests
* [misc] fix requirements
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [misc] update requirements
* [gemini] refactor gemini optimizer and gemini ddp (#4398)
* [gemini] update optimizer interface
* [gemini] renaming gemini optimizer
* [gemini] refactor gemini ddp class
* [example] update gemini related example
* [example] update gemini related example
* [plugin] fix gemini plugin args
* [test] update gemini ckpt tests
* [gemini] fix checkpoint io
* [example] fix opt example requirements
* [example] fix opt example
* [example] fix opt example
* [example] fix opt example
* [gemini] add static placement policy (#4443)
* [gemini] add static placement policy
* [gemini] fix param offload
* [test] update gemini tests
* [plugin] update gemini plugin
* [plugin] update gemini plugin docstr
* [misc] fix flash attn requirement
* [test] fix gemini checkpoint io test
* [example] update resnet example result (#4457)
* [example] update bert example result (#4458)
* [doc] update gemini doc (#4468)
* [example] update gemini related examples (#4473)
* [example] update gpt example
* [example] update dreambooth example
* [example] update vit
* [example] update opt
* [example] update palm
* [example] update vit and opt benchmark
* [hotfix] fix bert in model zoo (#4480)
* [hotfix] fix bert in model zoo
* [test] remove chatglm gemini test
* [test] remove sam gemini test
* [test] remove vit gemini test
* [hotfix] fix opt tutorial example (#4497)
* [hotfix] fix opt tutorial example
* [hotfix] fix opt tutorial example
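Before the diff below, a rough usage sketch of the headline feature, the static placement policy, selected through the booster API. The `placement_policy="static"` keyword follows the commit titles above; the launch and boost calls around it are illustrative assumptions, not taken from this diff.

```python
# A minimal sketch, assuming GeminiPlugin accepts placement_policy="static"
# after this PR; the surrounding launch/boost code is illustrative only.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})           # one process per GPU

plugin = GeminiPlugin(placement_policy="static")  # keep shards at a fixed placement
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)
```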
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed import ProcessGroup

 from colossalai.tensor import ColoParameter
 from colossalai.utils import is_ddp_ignored
@@ -59,7 +60,7 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
     return left + acc


-def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
+def _tensor_numel(local_param: ColoParameter) -> int:
     """_tensor_numel

     Get the number of elements of a tensor.
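The hunk header above shows the cost function of the chunk search, `_get_unused_byte(size_list, chunk_size)`, whose `return left + acc` line appears as context. One plausible reading of what it measures, sketched below under that assumption rather than copied from the library, is the space wasted when parameters are packed greedily into equal-sized chunks.

```python
# Sketch only: greedy packing waste for a candidate chunk size. Assumes every
# parameter fits into one chunk (oversized params are filtered elsewhere when
# filter_exlarge_params is enabled).
from typing import List

def estimate_wasted_elements(size_list: List[int], chunk_size: int) -> int:
    wasted = 0   # elements discarded when a chunk is closed early
    free = 0     # free elements left in the currently open chunk
    for size in size_list:
        if size > free:      # parameter does not fit: open a fresh chunk
            wasted += free
            free = chunk_size
        free -= size
    return wasted + free     # leftover space in the last chunk also counts

print(estimate_wasted_elements([6, 3, 5], chunk_size=8))  # -> 2: [6] wastes 2, [3, 5] fills a chunk exactly
```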
@@ -71,15 +72,12 @@ def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
     Returns:
         int: the number of elements.
     """
-    if strict_ddp_flag and type(local_param) is ColoParameter:
-        return local_param.numel_global()
-    else:
-        # if local_param is not ColoParameter, we assume it's replicated
-        return local_param.numel()
+    # TODO(ver217): support dtensor here
+    return local_param.numel()


 def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
-                                 strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]:
+                                 process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]:
     """classify_params_by_dp_degree

     Classify the parameters by their dp degree
@@ -97,13 +95,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
         # assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
         if is_ddp_ignored(param):
             continue
-
-        if strict_ddp_flag or type(param) is not ColoParameter:
-            # if model is not initialized with ColoInitContext, we assume it's replicated
-            # TODO(ver217): integrate DTensor
-            param_key = dist.get_world_size()
-        else:
-            param_key = param.process_group.dp_world_size()
+        param_key = dist.get_world_size(process_group)

         if param_key not in params_dict:
             params_dict[param_key] = []
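After this hunk, the dp degree of every non-ignored parameter comes from the explicitly passed process group rather than from per-tensor ColoTensor metadata, so the chunk search no longer depends on models being initialized under ColoInitContext; passing `process_group=None` makes `dist.get_world_size` fall back to the default (WORLD) group. A standalone sketch of the new rule, mirroring the diff but not the library function itself:

```python
# Standalone sketch of the simplified grouping: with one explicit process
# group, every parameter lands in a single bucket keyed by that group's
# world size.
from typing import Dict, List, Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

def classify_by_dp_degree_sketch(params: List[torch.nn.Parameter],
                                 process_group: Optional[ProcessGroup] = None
                                 ) -> Dict[int, List[torch.nn.Parameter]]:
    params_dict: Dict[int, List[torch.nn.Parameter]] = {}
    for param in params:
        param_key = dist.get_world_size(process_group)  # same key for every param
        params_dict.setdefault(param_key, []).append(param)
    return params_dict
```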
@@ -119,6 +111,7 @@ def search_chunk_configuration(
         min_chunk_size_m: float = 32,
         filter_exlarge_params: bool = True,
         strict_ddp_flag: bool = False,
+        process_group: Optional[ProcessGroup] = None,
         memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
     """search_chunk_configuration

@@ -149,7 +142,7 @@ def search_chunk_configuration(
     min_chunk_size = round(min_chunk_size_m * 1024**2)
     assert search_range >= 0

-    params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
+    params_dict = classify_params_by_dp_degree(param_order, process_group)
     size_lcm = np.lcm.reduce(list(params_dict.keys()))
     config_dict: Dict[int, Dict] = dict()
     total_param_size = 0
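The context line `size_lcm = np.lcm.reduce(list(params_dict.keys()))` takes the least common multiple of all dp degrees, presumably so that a chunk size which is a multiple of that LCM can be sharded evenly across every group. A small numeric sketch of that alignment; the rounding below is illustrative, not the library code:

```python
# Illustrative only: align a candidate chunk size to the LCM of the dp degrees
# so that each group can shard the chunk without a remainder.
import numpy as np

dp_degrees = [2, 8]                          # hypothetical keys of params_dict
size_lcm = int(np.lcm.reduce(dp_degrees))    # -> 8

candidate = 33_554_433                       # 32 MiB of elements plus one stray element
aligned = candidate + (-candidate) % size_lcm
print(size_lcm, aligned)                     # 8 33554440, divisible by both 2 and 8
```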
@@ -157,7 +150,7 @@ def search_chunk_configuration(
     size_dict: Dict[int, List[int]] = dict()
     for dp_degree in params_dict:
         params_list = params_dict[dp_degree]
-        size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list]
+        size_list = [_tensor_numel(p) for p in params_list]
         group_acc_size = sum(size_list)
         total_param_size += group_acc_size
