Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
[feat] refactored extension module (#5298)
* [feat] refactored extension module
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
@@ -7,7 +7,6 @@ from transformers import LlamaForCausalLM
 from utils import shared_tempdir

 import colossalai
-from colossalai.testing import skip_if_not_enough_gpus
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
 from colossalai.lazy import LazyInitContext
@@ -17,6 +16,7 @@ from colossalai.testing import (
     clear_cache_before_run,
     parameterize,
     rerun_if_address_is_in_use,
+    skip_if_not_enough_gpus,
     spawn,
 )
 from tests.kit.model_zoo import model_zoo
@@ -52,7 +52,12 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
     bert_model.config.save_pretrained(save_directory=pretrained_path)

     extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
-    plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size)
+    plugin = GeminiPlugin(
+        **placement_config,
+        tp_size=tp_size,
+        enable_all_optimization=enable_all_optimization,
+        extra_dp_size=extra_dp_size,
+    )
     booster = Booster(plugin=plugin)
     bert_model, _, _, _, _ = booster.boost(bert_model)
     model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
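The parallel layout here is purely arithmetic, so a tiny worked example of the same formula may help; the concrete sizes below are illustrative assumptions, not values taken from this test.

# Worked example of the extra_dp_size formula above; the sizes are assumed.
world_size = 8          # total ranks, i.e. dist.get_world_size()
zero_size = 2           # ZeRO (Gemini) group size
tp_size = 2             # tensor-parallel group size
extra_dp_size = world_size // (zero_size * tp_size)
assert extra_dp_size == 2
# the three group sizes tile the whole world
assert zero_size * tp_size * extra_dp_size == world_size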
@@ -78,7 +83,14 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     criterion = lambda x: x.mean()
     enable_all_optimization = True if tp_size > 1 else False
     extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
-    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
+    plugin = GeminiPlugin(
+        **placement_config,
+        precision="fp16",
+        initial_scale=(2**14),
+        tp_size=tp_size,
+        extra_dp_size=extra_dp_size,
+        enable_all_optimization=enable_all_optimization,
+    )
     booster = Booster(plugin=plugin)

     model = model_fn()
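For readers who have not used the API being reformatted above, here is a minimal sketch of the same GeminiPlugin/Booster pattern written as a spawn worker, the way these tests run it. The toy nn.Linear model, the single-process world size, and the colossalai.launch arguments are assumptions for illustration (the tests' own run_dist body is not shown in this diff), and actually running it still requires a CUDA device.

import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.testing import spawn


def run_dist(rank, world_size, port):
    # initialize the distributed backend for this worker (arguments assumed)
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port)
    plugin = GeminiPlugin(
        precision="fp16",
        initial_scale=2**14,
        tp_size=1,
        extra_dp_size=1,
        enable_all_optimization=False,
    )
    booster = Booster(plugin=plugin)
    # boost() returns (model, optimizer, criterion, dataloader, lr_scheduler);
    # only the wrapped model is kept here, mirroring the tests above
    model, _, _, _, _ = booster.boost(nn.Linear(8, 8))


if __name__ == "__main__":
    spawn(run_dist, 1)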
@@ -161,8 +173,13 @@ def run_dist(rank, world_size, port):
 def test_gemini_ckpIO():
     spawn(run_dist, 4)


+@pytest.mark.largedist
+@skip_if_not_enough_gpus(min_gpus=8)
+@rerun_if_address_is_in_use()
+def test_gemini_ckpIO_3d():
+    spawn(run_dist, 8)


 if __name__ == "__main__":
     test_gemini_ckpIO()
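The new 8-GPU test leans on the decorators whose import was consolidated in the first hunk. A short sketch of the same decorator stack on a placeholder test; pytest, the project's largedist marker, and an 8-GPU machine are assumed.

import pytest

from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn


def run_dist(rank, world_size, port):
    # placeholder worker; a real test launches colossalai here and runs its checks
    pass


@pytest.mark.largedist                  # only collected in the large-distributed suite
@skip_if_not_enough_gpus(min_gpus=8)    # skipped on machines with fewer than 8 GPUs
@rerun_if_address_is_in_use()           # retried if the rendezvous port is already taken
def test_example_3d():
    spawn(run_dist, 8)                  # launch 8 worker processes, as in test_gemini_ckpIO_3d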
@@ -65,9 +65,9 @@ class TorchAdamKernel(AdamKernel):
 class FusedAdamKernel(AdamKernel):
     def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
         super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
-        from colossalai.kernel.op_builder import FusedOptimBuilder
+        from colossalai.kernel.kernel_loader import FusedOptimizerLoader

-        fused_optim = FusedOptimBuilder().load()
+        fused_optim = FusedOptimizerLoader().load()
         self.fused_adam = fused_optim.multi_tensor_adam
         self.dummy_overflow_buf = torch.cuda.IntTensor([0])

@@ -91,7 +91,7 @@ class FusedAdamKernel(AdamKernel):
 class CPUAdamKernel(AdamKernel):
     def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
         super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
-        from colossalai.kernel import CPUAdamLoader
+        from colossalai.kernel.kernel_loader import CPUAdamLoader

         cpu_optim = CPUAdamLoader().load()

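Both kernel hunks above follow the same refactor: imports from colossalai.kernel / op_builder are replaced by loaders in colossalai.kernel.kernel_loader. A condensed sketch of the new pattern, using only names visible in the diff and assuming the extensions can be built locally (the fused optimizer additionally needs CUDA):

import torch

from colossalai.kernel.kernel_loader import CPUAdamLoader, FusedOptimizerLoader

# CPU Adam extension, previously imported via `from colossalai.kernel import CPUAdamLoader`
cpu_optim = CPUAdamLoader().load()

# fused CUDA Adam extension, previously built via FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
fused_adam = fused_optim.multi_tensor_adam          # the multi-tensor Adam kernel
dummy_overflow_buf = torch.cuda.IntTensor([0])      # overflow flag buffer, as in FusedAdamKernel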
@@ -8,7 +8,7 @@ from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM
 from colossalai.testing import clear_cache_before_run, parameterize

 if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
-    from colossalai.kernel import AttnMaskType, ColoAttention
+    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

 DTYPE = [torch.float16, torch.bfloat16, torch.float32]

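The attention test keeps its availability guard and only changes where ColoAttention is imported from. A small availability check assembled from the names in this hunk; the ColoAttention call signature is not shown here, so it is not exercised.

from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN

if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
    # new import location after the refactor
    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

    print("ColoAttention available:", ColoAttention, AttnMaskType)
else:
    print("Neither flash attention nor memory-efficient attention is installed.")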