[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
This commit is contained in:
Frank Lee
2024-01-25 17:01:48 +08:00
committed by GitHub
parent d7f8db8e21
commit 7cfed5f076
157 changed files with 1353 additions and 8966 deletions

View File

@@ -7,7 +7,6 @@ from transformers import LlamaForCausalLM
from utils import shared_tempdir
import colossalai
from colossalai.testing import skip_if_not_enough_gpus
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
@@ -17,6 +16,7 @@ from colossalai.testing import (
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
skip_if_not_enough_gpus,
spawn,
)
from tests.kit.model_zoo import model_zoo
@@ -52,7 +52,12 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
bert_model.config.save_pretrained(save_directory=pretrained_path)
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size)
plugin = GeminiPlugin(
**placement_config,
tp_size=tp_size,
enable_all_optimization=enable_all_optimization,
extra_dp_size=extra_dp_size,
)
booster = Booster(plugin=plugin)
bert_model, _, _, _, _ = booster.boost(bert_model)
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@@ -78,7 +83,14 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
criterion = lambda x: x.mean()
enable_all_optimization = True if tp_size > 1 else False
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
plugin = GeminiPlugin(
**placement_config,
precision="fp16",
initial_scale=(2**14),
tp_size=tp_size,
extra_dp_size=extra_dp_size,
enable_all_optimization=enable_all_optimization,
)
booster = Booster(plugin=plugin)
model = model_fn()
@@ -161,8 +173,13 @@ def run_dist(rank, world_size, port):
def test_gemini_ckpIO():
spawn(run_dist, 4)
@pytest.mark.largedist
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_gemini_ckpIO_3d():
spawn(run_dist, 8)
spawn(run_dist, 8)
if __name__ == "__main__":
test_gemini_ckpIO()

View File

@@ -65,9 +65,9 @@ class TorchAdamKernel(AdamKernel):
class FusedAdamKernel(AdamKernel):
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
fused_optim = FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
self.fused_adam = fused_optim.multi_tensor_adam
self.dummy_overflow_buf = torch.cuda.IntTensor([0])
@@ -91,7 +91,7 @@ class FusedAdamKernel(AdamKernel):
class CPUAdamKernel(AdamKernel):
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
from colossalai.kernel import CPUAdamLoader
from colossalai.kernel.kernel_loader import CPUAdamLoader
cpu_optim = CPUAdamLoader().load()

View File

@@ -8,7 +8,7 @@ from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM
from colossalai.testing import clear_cache_before_run, parameterize
if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
from colossalai.kernel import AttnMaskType, ColoAttention
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
DTYPE = [torch.float16, torch.bfloat16, torch.float32]