mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-29 04:05:35 +00:00
* [colossalai]fix typo * [inference] Add smmoothquant for llama (#4904) * [inference] add int8 rotary embedding kernel for smoothquant (#4843) * [inference] add smoothquant llama attention (#4850) * add smoothquant llama attention * remove uselss code * remove useless code * fix import error * rename file name * [inference] add silu linear fusion for smoothquant llama mlp (#4853) * add silu linear * update skip condition * catch smoothquant cuda lib exception * prcocess exception for tests * [inference] add llama mlp for smoothquant (#4854) * add llama mlp for smoothquant * fix down out scale * remove duplicate lines * add llama mlp check * delete useless code * [inference] add smoothquant llama (#4861) * add smoothquant llama * fix attention accuracy * fix accuracy * add kv cache and save pretrained * refactor example * delete smooth * refactor code * [inference] add smooth function and delete useless code for smoothquant (#4895) * add smooth function and delete useless code * update datasets * remove duplicate import * delete useless file * refactor codes (#4902) * rafactor code * add license * add torch-int and smoothquant license * Update flash_attention_patch.py To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to forward function of attention layer. 
https://github.com/huggingface/transformers/pull/25598 * [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921) * [kernel] support pure fp16 for cpu adam (#4896) * [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919) * [kernel] fix cpu adam * [test] update gemini optim test * [format] applied code formatting on changed files in pull request 4908 (#4918) Co-authored-by: github-actions <github-actions@github.com> * [gemini] support gradient accumulation (#4869) * add test * fix no_sync bug in low level zero plugin * fix test * add argument for grad accum * add grad accum in backward hook for gemini * finish implementation, rewrite tests * fix test * skip stuck model in low level zero test * update doc * optimize communication & fix gradient checkpoint * modify doc * cleaning codes * update cpu adam fp16 case * [hotfix] fix torch 2.0 compatibility (#4936) * [hotfix] fix launch * [test] fix test gemini optim * [shardformer] fix vit * [test] add no master test for low level zero plugin (#4934) * [format] applied code formatting on changed files in pull request 4820 (#4886) Co-authored-by: github-actions <github-actions@github.com> * [nfc] fix some typo with colossalai/ docs/ etc. 
(#4920) * [Refactor] Integrated some lightllm kernels into token-attention (#4946) * add some req for inference * clean codes * add codes * add some lightllm deps * clean codes * hello * delete rms files * add some comments * add comments * add doc * add lightllm deps * add lightllm cahtglm2 kernels * add lightllm cahtglm2 kernels * replace rotary embedding with lightllm kernel * add some commnets * add some comments * add some comments * add * replace fwd kernel att1 * fix a arg * add * add * fix token attention * add some comments * clean codes * modify comments * fix readme * fix bug * fix bug --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> * [test] merge old components to test to model zoo (#4945) * [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test * [inference] add reference and fix some bugs (#4937) * add reference and fix some bugs * update gptq init --------- Co-authored-by: Xu Kai <xukai16@foxamil.com> * [Inference]ADD Bench Chatglm2 script (#4963) * add bench chatglm * fix bug and make utils --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Pipeline inference] Combine kvcache with pipeline inference (#4938) * merge kvcache with pipeline inference and refactor the code structure * support ppsize > 2 * refactor pipeline code * do pre-commit * modify benchmark * fix bench mark * polish code * add docstring and update readme * refactor the code * fix some logic bug of ppinfer * polish readme * fix typo * skip infer test * updated c++17 compiler flags (#4983) * [Inference] Dynamic Batching Inference, online and offline (#4953) * [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- 
Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 * [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commit fbf3c09e67
. * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced140250
. * Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commit fced140250
. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced140250
. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * support dynamic batch for bloom model and is_running function * [Inference]Test for new Async engine (#4935) * infer engine * infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * add assertion for config (#4947) * [Inference] Finish dynamic batching offline test (#4948) * test * fix test * fix quant * add default * fix * fix some bugs * fix some bugs * fix * fix bug * fix bugs * reset param --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention (#4965) * adding flash-decoding * clean * adding kernel * adding flash-decoding * add integration * add * adding kernel * adding kernel * adding triton 2.1.0 features for inference * update bloom triton kernel * remove useless vllm kernels * clean codes * fix * adding files * fix readme * update llama flash-decoding --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * fix ColossalEval (#4992) Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> * [doc]Update doc for colossal-inference (#4989) * update doc * Update README.md --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * [hotfix] Fix the bug where process groups were not being properly released. (#4940) * Fix the bug where process groups were not being properly released. * test * Revert "test" This reverts commit479900c139
. * [hotfix] fix the bug of repeatedly storing param group (#4951) * [doc] add supported feature diagram for hybrid parallel plugin (#4996) * [Pipeline Inference] Merge pp with tp (#4993) * refactor pipeline into new CaiInferEngine * updata llama modeling forward * merge tp with pp * update docstring * optimize test workflow and example * fix typo * add assert and todo * [release] update version (#4995) * [release] update version * [hotfix] fix ci * [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp * fix fix fix * update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO * support fused layernorm support fused layernorm support fused layernorm * update fusedlayernorm update fusedlayernorm update fusedlayernorm * add sequence parallel to gemini add sequence parallel to gemini * fix * fix comments fix comments fix comments * fix * fix t5 * clear cache * fix * activate ci * activate ci * fix * fix * fix * fix * revert * modify tp gather method modify tp gather method modify tp gather method modify tp gather method * fix test --------- Co-authored-by: Xu Kai <xukai16@foxmail.com> Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> Co-authored-by: Xu Kai <xukai16@foxamil.com> Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> 
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> Co-authored-by: littsk <1214689160@qq.com> Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
161 lines
6.8 KiB
Python
161 lines
6.8 KiB
Python
import os
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.distributed as dist
|
|
from transformers import LlamaForCausalLM
|
|
from utils import shared_tempdir
|
|
|
|
import colossalai
|
|
from colossalai.booster import Booster
|
|
from colossalai.booster.plugin import GeminiPlugin
|
|
from colossalai.lazy import LazyInitContext
|
|
from colossalai.nn.optimizer import HybridAdam
|
|
from colossalai.testing import (
|
|
check_state_dict_equal,
|
|
clear_cache_before_run,
|
|
parameterize,
|
|
rerun_if_address_is_in_use,
|
|
spawn,
|
|
)
|
|
from tests.kit.model_zoo import model_zoo
|
|
|
|
# Gemini placement variants fed to `exam_state_dict_with_origin`. All use the "static"
# policy and differ only in shard_param_frac; the labels keep the original naming.
MODEL_PLACEMENT_CONFIGS = [
    dict(placement_policy="static", shard_param_frac=shard_frac)
    for shard_frac in (0.0, 1.0, 0.5)  # zero2, zero3, zero3-half
]
|
|
|
|
# Optimizer-offload variants fed to `exam_state_dict`. shard_param_frac is fixed at 0.0
# and offload_optim_frac varies: 0.0 -> zero2, 1.0 -> zero2-offload, 0.5 -> zero2-offload-half.
OPTIM_PLACEMENT_CONFIGS = [
    dict(placement_policy="static", shard_param_frac=0.0, offload_optim_frac=offload_frac)
    for offload_frac in (0.0, 1.0, 0.5)
]
|
|
|
|
|
|
@clear_cache_before_run()
@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_safetensors", [False, True])
@parameterize("enable_tensor_parallelism", [True, False])
@parameterize("tp_size", [2])
def exam_state_dict_with_origin(
    placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int
):
    """Check that a sharded Gemini checkpoint can be reloaded by HuggingFace ``from_pretrained``.

    The model is boosted with ``GeminiPlugin`` and saved as a sharded checkpoint (next to its
    config) in a shared temporary directory. It is then reloaded as a plain
    ``BertForSequenceClassification`` and compared with the boosted model's full state dict.

    Args:
        placement_config: keyword arguments forwarded to ``GeminiPlugin``.
        model_name: model-zoo sub-registry key used to build the model.
        use_safetensors: whether the checkpoint is written in safetensors format.
        enable_tensor_parallelism: forwarded to ``GeminiPlugin``. It also toggles
            ``enable_all_optimization``.
        tp_size: tensor-parallel size forwarded to ``GeminiPlugin``.
    """
    from transformers import BertForSequenceClassification

    # Only the model factory is needed; the data/output helpers are not used by this check.
    (model_fn, _, _, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    bert_model = model_fn()
    # `enable_all_optimization` is switched on exactly in the tensor-parallel runs.
    enable_all_optimization = enable_tensor_parallelism

    with shared_tempdir() as tempdir:
        pretrained_path = os.path.join(tempdir, "pretrained")
        # Save the config so `from_pretrained` can rebuild the architecture from this directory.
        bert_model.config.save_pretrained(save_directory=pretrained_path)

        plugin = GeminiPlugin(
            **placement_config,
            enable_tensor_parallelism=enable_tensor_parallelism,
            tp_size=tp_size,
            enable_all_optimization=enable_all_optimization,
        )
        booster = Booster(plugin=plugin)
        bert_model, _, _, _, _ = booster.boost(bert_model)
        # Parameter footprint in MiB. A third of it is used as the shard size so the
        # checkpoint is split across multiple files.
        model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2

        booster.save_model(
            bert_model,
            pretrained_path,
            shard=True,
            gather_dtensor=True,
            prefix="",
            size_per_shard=model_size / 3,
            use_safetensors=use_safetensors,
        )
        dist.barrier()

        new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
        check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False)
|
|
|
|
|
|
@clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_gpt"])
@parameterize("size_per_shard", [32])
@parameterize("enable_tensor_parallelism", [True, False])
@parameterize("tp_size", [2])
def exam_state_dict(
    placement_config,
    shard: bool,
    model_name: str,
    size_per_shard: int,
    enable_tensor_parallelism: bool,
    tp_size: int,
):
    """Round-trip a trained Gemini model and optimizer through checkpoint save/load.

    One training step is taken on ``model``, and its model and optimizer checkpoints are
    saved and loaded into a second boosted model/optimizer pair. Both state dicts are
    compared. The restored pair must then still train and save successfully.

    Args:
        placement_config: keyword arguments forwarded to ``GeminiPlugin``.
        shard: whether checkpoints are written as sharded checkpoints.
        model_name: model-zoo sub-registry key used to build the models.
        size_per_shard: shard size passed to the first save of model and optimizer.
        enable_tensor_parallelism: forwarded to ``GeminiPlugin``. It also toggles
            ``enable_all_optimization``.
        tp_size: tensor-parallel size forwarded to ``GeminiPlugin``.
    """
    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))

    def criterion(x):
        return x.mean()

    def _train_step(step_model, step_optimizer):
        """Run one forward/backward/optimizer step of ``step_model`` on a fresh batch."""
        data = data_gen_fn()
        # Move tensor-like inputs to the GPU and pass everything else through unchanged.
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v
            for k, v in data.items()
        }
        output = output_transform_fn(step_model(**data))
        output_key = list(output.keys())[0]
        # `criterion` is resolved at call time, i.e. the version returned by `booster.boost`.
        loss = criterion(output[output_key])
        booster.backward(loss, step_optimizer)
        step_optimizer.step()

    # `enable_all_optimization` is switched on exactly in the tensor-parallel runs.
    enable_all_optimization = enable_tensor_parallelism
    plugin = GeminiPlugin(
        **placement_config,
        precision="fp16",
        initial_scale=(2**14),
        enable_tensor_parallelism=enable_tensor_parallelism,
        tp_size=tp_size,
        enable_all_optimization=enable_all_optimization,
    )
    booster = Booster(plugin=plugin)

    model = model_fn()
    new_model = model_fn()
    optimizer = HybridAdam(model.parameters(), lr=0.001)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

    # One step so the optimizer state saved below has been populated.
    _train_step(model, optimizer)

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optimizer_ckpt_path = f"{tempdir}/optimizer"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        dist.barrier()

        booster.load_model(new_model, model_ckpt_path)
        check_state_dict_equal(
            model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True
        )

        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
        check_state_dict_equal(
            optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
        )

        # Check the new model/optimizer can successfully run and be saved again.
        _train_step(new_model, new_optimizer)
        booster.save_model(new_model, model_ckpt_path, shard=shard)
        booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
|
|
|
|
|
|
def exam_lazy_from_pretrained(llama_path=None):
    """Check that a lazily-initialised, Gemini-boosted LLaMA saves the same weights it loaded.

    A reference LLaMA is loaded eagerly and its weights are kept in fp16. A second copy is
    built under ``LazyInitContext``, boosted with a default ``GeminiPlugin`` and saved as a
    single (unsharded) file. The saved weights are then compared with the reference,
    ignoring dtype.

    Args:
        llama_path: directory of a HuggingFace LLaMA checkpoint. Defaults to the
            ``LLAMA_PATH`` environment variable. When neither is set, the check is skipped
            with a warning instead of failing the whole distributed run with a ``KeyError``.
    """
    if llama_path is None:
        llama_path = os.environ.get("LLAMA_PATH")
    if not llama_path:
        import warnings

        warnings.warn("LLAMA_PATH is not set; skipping exam_lazy_from_pretrained.")
        return

    plugin = GeminiPlugin()
    booster = Booster(plugin=plugin)
    orig_model = LlamaForCausalLM.from_pretrained(llama_path)
    orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()}
    # Only the fp16 reference tensors are needed from here on. Dropping the eager model
    # frees its weights whenever `.half()` produced copies (i.e. they were not already fp16).
    del orig_model
    with LazyInitContext():
        model = LlamaForCausalLM.from_pretrained(llama_path)
    model, *_ = booster.boost(model)
    with shared_tempdir() as tempdir:
        save_path = os.path.join(tempdir, "model.pt")
        booster.save_model(model, save_path, shard=False)
        dist.barrier()
        state_dict = torch.load(save_path, map_location="cpu")
        check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True)
|
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Per-rank entry point for ``spawn``.

    Initialises the distributed environment on localhost over NCCL, then runs each Gemini
    checkpoint check in turn.
    """
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    for check in (exam_state_dict, exam_state_dict_with_origin, exam_lazy_from_pretrained):
        check()
|
|
|
|
|
|
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
    """Spawn ``run_dist`` on ``world_size`` processes to exercise Gemini checkpoint IO."""
    num_procs = world_size
    spawn(run_dist, num_procs)
|