[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
This commit is contained in:
Frank Lee
2024-01-25 17:01:48 +08:00
committed by GitHub
parent d7f8db8e21
commit 7cfed5f076
157 changed files with 1353 additions and 8966 deletions

View File

@@ -7,7 +7,6 @@ from transformers import LlamaForCausalLM
from utils import shared_tempdir
import colossalai
from colossalai.testing import skip_if_not_enough_gpus
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
@@ -17,6 +16,7 @@ from colossalai.testing import (
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
skip_if_not_enough_gpus,
spawn,
)
from tests.kit.model_zoo import model_zoo
@@ -52,7 +52,12 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
bert_model.config.save_pretrained(save_directory=pretrained_path)
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size)
plugin = GeminiPlugin(
**placement_config,
tp_size=tp_size,
enable_all_optimization=enable_all_optimization,
extra_dp_size=extra_dp_size,
)
booster = Booster(plugin=plugin)
bert_model, _, _, _, _ = booster.boost(bert_model)
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@@ -78,7 +83,14 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
criterion = lambda x: x.mean()
enable_all_optimization = True if tp_size > 1 else False
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
plugin = GeminiPlugin(
**placement_config,
precision="fp16",
initial_scale=(2**14),
tp_size=tp_size,
extra_dp_size=extra_dp_size,
enable_all_optimization=enable_all_optimization,
)
booster = Booster(plugin=plugin)
model = model_fn()
@@ -161,8 +173,13 @@ def run_dist(rank, world_size, port):
def test_gemini_ckpIO():
spawn(run_dist, 4)
@pytest.mark.largedist
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_gemini_ckpIO_3d():
spawn(run_dist, 8)
spawn(run_dist, 8)
if __name__ == "__main__":
test_gemini_ckpIO()