Mirror of https://github.com/hpcaitech/ColossalAI.git
[feat] refactored extension module (#5298)
* [feat] refactored extension module * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish
@@ -7,7 +7,6 @@ from transformers import LlamaForCausalLM
 from utils import shared_tempdir

 import colossalai
-from colossalai.testing import skip_if_not_enough_gpus
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
 from colossalai.lazy import LazyInitContext
@@ -17,6 +16,7 @@ from colossalai.testing import (
     clear_cache_before_run,
     parameterize,
     rerun_if_address_is_in_use,
+    skip_if_not_enough_gpus,
     spawn,
 )
 from tests.kit.model_zoo import model_zoo
@@ -52,7 +52,12 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
         bert_model.config.save_pretrained(save_directory=pretrained_path)

         extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
-        plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size)
+        plugin = GeminiPlugin(
+            **placement_config,
+            tp_size=tp_size,
+            enable_all_optimization=enable_all_optimization,
+            extra_dp_size=extra_dp_size,
+        )
         booster = Booster(plugin=plugin)
         bert_model, _, _, _, _ = booster.boost(bert_model)
         model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
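The hunk above only reflows the GeminiPlugin call; the sizing logic is unchanged. As a rough illustration of how the extra data-parallel dimension falls out of the world size (the world_size, tp_size and zero_size values below are assumptions for the sketch, not values from the test):

world_size = 8   # assumed: in the test this comes from dist.get_world_size()
tp_size = 2      # assumed tensor-parallel group size
zero_size = 2    # assumed ZeRO/Gemini data-parallel group size

# Ranks left over after tensor parallelism and ZeRO form the extra DP dimension.
extra_dp_size = world_size // (zero_size * tp_size)
assert extra_dp_size == 2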
@@ -78,7 +83,14 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     criterion = lambda x: x.mean()
     enable_all_optimization = True if tp_size > 1 else False
     extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
-    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
+    plugin = GeminiPlugin(
+        **placement_config,
+        precision="fp16",
+        initial_scale=(2**14),
+        tp_size=tp_size,
+        extra_dp_size=extra_dp_size,
+        enable_all_optimization=enable_all_optimization,
+    )
     booster = Booster(plugin=plugin)

     model = model_fn()
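For context, the call being reformatted here is the usual plugin-plus-Booster setup. Below is a minimal sketch of that pattern, assuming the distributed environment has already been initialized on each rank via colossalai.launch; the Linear model and HybridAdam optimizer are illustrative stand-ins for the model_zoo fixtures, not part of the test.

import torch

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Assumes colossalai.launch(...) has already set up the process group on this rank.
plugin = GeminiPlugin(precision="fp16", initial_scale=2**14)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(8, 8)                        # stand-in for the model_zoo model
optimizer = HybridAdam(model.parameters(), lr=1e-3)  # stand-in optimizer
criterion = lambda x: x.mean()

# boost() returns the wrapped (model, optimizer, criterion, dataloader, lr_scheduler).
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)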
@@ -161,8 +173,13 @@ def run_dist(rank, world_size, port):
 def test_gemini_ckpIO():
     spawn(run_dist, 4)


+@pytest.mark.largedist
+@skip_if_not_enough_gpus(min_gpus=8)
+@rerun_if_address_is_in_use()
+def test_gemini_ckpIO_3d():
+    spawn(run_dist, 8)


 if __name__ == "__main__":
     test_gemini_ckpIO()
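The last hunk adds an 8-GPU variant of the checkpoint test, gated so it only runs on machines with enough devices. A minimal sketch of that gating, with the body of run_dist omitted (the real one exercises the Gemini checkpoint save/load paths):

import pytest

from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn


def run_dist(rank, world_size, port):
    # Each spawned process would initialize its own distributed context here
    # (via colossalai.launch) and then run the checkpoint save/load checks.
    ...


@pytest.mark.largedist                 # marker assumed to tag large multi-GPU jobs
@skip_if_not_enough_gpus(min_gpus=8)   # skip on machines with fewer than 8 GPUs
@rerun_if_address_is_in_use()          # retry if the rendezvous address is still occupied
def test_gemini_ckpIO_3d():
    spawn(run_dist, 8)                 # launch run_dist on 8 ranks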