[gemini] improve compatibility and add static placement policy (#4479)

* [gemini] remove distributed-related part from colotensor (#4379)

* [gemini] remove process group dependency

* [gemini] remove tp part from colo tensor

* [gemini] patch inplace op

* [gemini] fix param op hook and update tests

* [test] remove useless tests

* [test] remove useless tests

* [misc] fix requirements

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [misc] update requirements

* [gemini] refactor gemini optimizer and gemini ddp (#4398)

* [gemini] update optimizer interface

* [gemini] renaming gemini optimizer

* [gemini] refactor gemini ddp class

* [example] update gemini related example

* [example] update gemini related example

* [plugin] fix gemini plugin args

* [test] update gemini ckpt tests

* [gemini] fix checkpoint io

* [example] fix opt example requirements

* [example] fix opt example

* [example] fix opt example

* [example] fix opt example

* [gemini] add static placement policy (#4443)

* [gemini] add static placement policy

* [gemini] fix param offload

* [test] update gemini tests

* [plugin] update gemini plugin

* [plugin] update gemini plugin docstr

* [misc] fix flash attn requirement

* [test] fix gemini checkpoint io test

* [example] update resnet example result (#4457)

* [example] update bert example result (#4458)

* [doc] update gemini doc (#4468)

* [example] update gemini related examples (#4473)

* [example] update gpt example

* [example] update dreambooth example

* [example] update vit

* [example] update opt

* [example] update palm

* [example] update vit and opt benchmark

* [hotfix] fix bert in model zoo (#4480)

* [hotfix] fix bert in model zoo

* [test] remove chatglm gemini test

* [test] remove sam gemini test

* [test] remove vit gemini test

* [hotfix] fix opt tutorial example (#4497)

* [hotfix] fix opt tutorial example

* [hotfix] fix opt tutorial example
This commit is contained in:
Hongxin Liu
2023-08-24 09:29:25 +08:00
committed by GitHub
parent 285fe7ba71
commit 27061426f7
82 changed files with 1008 additions and 4036 deletions

View File

@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.utils import is_ddp_ignored
@@ -59,7 +60,7 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
return left + acc
def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
def _tensor_numel(local_param: ColoParameter) -> int:
"""_tensor_numel
Get the number of elements of a tensor.
@@ -71,15 +72,12 @@ def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
Returns:
int: the number of elements.
"""
if strict_ddp_flag and type(local_param) is ColoParameter:
return local_param.numel_global()
else:
# if local_param is not ColoParameter, we assume it's replicated
return local_param.numel()
# TODO(ver217): support dtensor here
return local_param.numel()
def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]:
process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]:
"""classify_params_by_dp_degree
Classify the parameters by their dp degree
@@ -97,13 +95,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
# assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
if is_ddp_ignored(param):
continue
if strict_ddp_flag or type(param) is not ColoParameter:
# if model is not initialized with ColoInitContext, we assume it's replicated
# TODO(ver217): integrate DTensor
param_key = dist.get_world_size()
else:
param_key = param.process_group.dp_world_size()
param_key = dist.get_world_size(process_group)
if param_key not in params_dict:
params_dict[param_key] = []
@@ -119,6 +111,7 @@ def search_chunk_configuration(
min_chunk_size_m: float = 32,
filter_exlarge_params: bool = True,
strict_ddp_flag: bool = False,
process_group: Optional[ProcessGroup] = None,
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
"""search_chunk_configuration
@@ -149,7 +142,7 @@ def search_chunk_configuration(
min_chunk_size = round(min_chunk_size_m * 1024**2)
assert search_range >= 0
params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
params_dict = classify_params_by_dp_degree(param_order, process_group)
size_lcm = np.lcm.reduce(list(params_dict.keys()))
config_dict: Dict[int, Dict] = dict()
total_param_size = 0
@@ -157,7 +150,7 @@ def search_chunk_configuration(
size_dict: Dict[int, List[int]] = dict()
for dp_degree in params_dict:
params_list = params_dict[dp_degree]
size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list]
size_list = [_tensor_numel(p) for p in params_list]
group_acc_size = sum(size_list)
total_param_size += group_acc_size