Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-04-28 03:43:01 +00:00
* [infer] Infer/llama demo (#4503)
* add
* add infer example
* finish
* finish
* stash
* fix
* [Kernels] add inference token attention kernel (#4505)
* add token forward
* fix tests
* fix comments
* add try import triton
* add adapted license
* add tests check
* [Kernels] add necessary kernels (llama & bloom) for attention forward and kv-cache manager (#4485)
* added _vllm_rms_norm
* change place
* added tests
* added tests
* modify
* adding kernels
* added tests:
* adding kernels
* modify
* added
* updating kernels
* adding tests
* added tests
* kernel change
* submit
* modify
* added
* edit comments
* change name
* change comments and fix import
* add
* added
* combine codes (#4509)
* [feature] add KV cache manager for llama & bloom inference (#4495)
* add kv cache memory manager
* add stateinfo during inference
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* file dir change
* [Bug Fix] import llama context ops fix (#4524)
* added _vllm_rms_norm
* change place
* added tests
* added tests
* modify
* adding kernels
* added tests:
* adding kernels
* modify
* added
* updating kernels
* adding tests
* added tests
* kernel change
* submit
* modify
* added
* edit comments
* change name
* change comments and fix import
* add
* added
* fix
* add ops into init.py
* add
* [Infer] Add TPInferEngine and fix file path (#4532)
* add engine for TP inference
* move file path
* update path
* fix TPInferEngine
* remove unused file
* add engine test demo
* revise TPInferEngine
* fix TPInferEngine, add test
* fix
* Add Inference test for llama (#4508)
* add kv cache memory manager
* add stateinfo during inference
* add
* add infer example
* finish
* finish
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* add inference test for llama
* fix conflict
* feature: add some new features for llama engine
* adapt colossalai triton interface
* Change the parent class of llama policy
* add nvtx
* move llama inference code to tensor_parallel
* fix __init__.py
* rm tensor_parallel
* fix: fix bugs in auto_policy.py
* fix:rm some unused codes
* mv colossalai/tpinference to colossalai/inference/tensor_parallel
* change __init__.py
* save change
* fix engine
* Bug fix: Fix hang
* remove llama_infer_engine.py
---------
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [infer] Add Bloom inference policy and replaced methods (#4512)
* add bloom inference methods and policy
* enable passing BatchInferState from model forward
* revise bloom infer layers/policies
* add engine for inference (draft)
* add test for bloom infer
* fix bloom infer policy and flow
* revise bloom test
* fix bloom file path
* remove unused codes
* fix bloom modeling
* fix dir typo
* fix trivial
* fix policy
* clean pr
* trivial fix
* Revert "[infer] Add Bloom inference policy and replaced methods (#4512)" (#4552)
This reverts commit 17cfa57140.
* [Doc] Add colossal inference doc (#4549)
* create readme
* add readme.md
* fix typos
* [infer] Add Bloom inference policy and replaced methods (#4553)
* add bloom inference methods and policy
* enable passing BatchInferState from model forward
* revise bloom infer layers/policies
* add engine for inference (draft)
* add test for bloom infer
* fix bloom infer policy and flow
* revise bloom test
* fix bloom file path
* remove unused codes
* fix bloom modeling
* fix dir typo
* fix trivial
* fix policy
* clean pr
* trivial fix
* trivial
* Fix Bugs In Llama Model Forward (#4550)
* add kv cache memory manager
* add stateinfo during inference
* add
* add infer example
* finish
* finish
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* add inference test for llama
* fix conflict
* feature: add some new features for llama engine
* adapt colossalai triton interface
* Change the parent class of llama policy
* add nvtx
* move llama inference code to tensor_parallel
* fix __init__.py
* rm tensor_parallel
* fix: fix bugs in auto_policy.py
* fix:rm some unused codes
* mv colossalai/tpinference to colossalai/inference/tensor_parallel
* change __init__.py
* save change
* fix engine
* Bug fix: Fix hang
* remove llama_infer_engine.py
* bug fix: fix bugs about infer_state.is_context_stage
* remove policies
* fix: delete unused code
* fix: delete unused code
* remove unused code
* fix conflict
---------
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [doc] add colossal inference fig (#4554)
* create readme
* add readme.md
* fix typos
* upload fig
* [NFC] fix docstring for colossal inference (#4555)
Fix docstring and comments in kv cache manager and bloom modeling
* fix docstring in llama modeling (#4557)
* [Infer] check import vllm (#4559)
* change import vllm
* import apply_rotary_pos_emb
* change import location
* [DOC] add installation req (#4561)
* add installation req
* fix
* slight change
* remove empty
* [Feature] rms-norm transfer into inference llama.py (#4563)
* add installation req
* fix
* slight change
* remove empty
* add rmsnorm policy
* add
* clean codes
* [infer] Fix tp inference engine (#4564)
* fix engine prepare data
* add engine test
* use bloom for testing
* revise on test
* revise on test
* reset shardformer llama (#4569)
* [infer] Fix engine - tensors on different devices (#4570)
* fix diff device in engine
* [codefactor] Feature/colossal inference (#4579)
* code refactors
* remove
* change coding (#4581)
* [doc] complete README of colossal inference (#4585)
* complete fig
* Update README.md
* [doc]update readme (#4586)
* update readme
* Update README.md
* bug fix: fix bugs in llama and bloom (#4588)
* [BUG FIX]Fix test engine in CI and non-vllm kernels llama forward (#4592)
* fix tests
* clean
* clean
* fix bugs
* add
* fix llama non-vllm kernels bug
* modify
* clean codes
* [Kernel]Rmsnorm fix (#4598)
* fix tests
* clean
* clean
* fix bugs
* add
* fix llama non-vllm kernels bug
* modify
* clean codes
* add triton rmsnorm
* delete vllm kernel flag
* [Bug Fix]Fix bugs in llama (#4601)
* fix tests
* clean
* clean
* fix bugs
* add
* fix llama non-vllm kernels bug
* modify
* clean codes
* bug fix: remove rotary_positions_ids
---------
Co-authored-by: cuiqing.li <lixx3527@gmail.com>
* [kernel] Add triton layer norm & replace norm for bloom (#4609)
* add layernorm for inference
* add test for layernorm kernel
* add bloom layernorm replacement policy
* trivial: path
* [Infer] Bug fix rotary embedding in llama (#4608)
* fix rotary embedding
* delete print
* fix init seq len bug
* rename pytest
* add benchmark for llama
* refactor codes
* delete useless code
* [bench] Add bloom inference benchmark (#4621)
* add bloom benchmark
* readme - update benchmark res
* trivial - uncomment for testing (#4622)
* [Infer] add check triton and cuda version for tests (#4627)
* fix rotary embedding
* delete print
* fix init seq len bug
* rename pytest
* add benchmark for llama
* refactor codes
* delete useless code
* add check triton and cuda
* Update sharder.py (#4629)
* [Inference] Hot fix some bugs and typos (#4632)
* fix
* fix test
* fix conflicts
* [typo]Comments fix (#4633)
* fallback
* fix comments
* bug fix: fix some bugs in test_llama and test_bloom (#4635)
* [Infer] delete benchmark in tests and fix bug for llama and bloom (#4636)
* fix rotary embedding
* delete print
* fix init seq len bug
* rename pytest
* add benchmark for llama
* refactor codes
* delete useless code
* add check triton and cuda
* delete benchmark and fix infer bugs
* delete benchmark for tests
* delete useless code
* delete bechmark function in utils
* [Fix] Revise TPInferEngine, inference tests and benchmarks (#4642)
* [Fix] revise TPInferEngine methods and inference tests
* fix llama/bloom infer benchmarks
* fix infer tests
* trivial fix: benchmarks
* trivial
* trivial: rm print
* modify utils filename for infer ops test (#4657)
* [Infer] Fix TPInferEngine init & inference tests, benchmarks (#4670)
* fix engine funcs
* TPInferEngine: receive shard config in init
* benchmarks: revise TPInferEngine init
* benchmarks: remove pytest decorator
* trivial fix
* use small model for tests
* [NFC] use args for infer benchmarks (#4674)
* revise infer default (#4683)
* [Fix] optimize/shard model in TPInferEngine init (#4684)
* remove using orig model in engine
* revise inference tests
* trivial: rename
---------
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
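Several entries in the log above ("add try import triton", "[Infer] check import vllm (#4559)") refer to guarding optional kernel dependencies so that inference can fall back when triton or vllm is not installed. A minimal sketch of that guard pattern, assuming illustrative flag names (HAS_TRITON and HAS_VLLM are placeholders, not necessarily the actual ColossalAI symbols):

# Minimal sketch of an optional-dependency guard, as referenced by the
# "add try import triton" / "check import vllm" entries above.
# HAS_TRITON and HAS_VLLM are illustrative names, not ColossalAI's actual flags.
try:
    import triton  # noqa: F401
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("triton is not installed; triton-based kernels will be disabled")

try:
    import vllm  # noqa: F401
    HAS_VLLM = True
except ImportError:
    HAS_VLLM = False
    print("vllm is not installed; vllm kernels will be disabled")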
95 lines · 3.9 KiB · Python
from itertools import accumulate

import pytest
import torch
import torch.nn as nn
from packaging import version
from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM
from transformers.tokenization_utils_base import BatchEncoding

import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

TP_SIZE = 2
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 8

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')


@parameterize('test_config', [{
    'tp_size': TP_SIZE,
}])
def run(test_config):
    model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
    model = BloomForCausalLM(model_config)
    model = model.half()
    model.to(torch.cuda.current_device())

    # 1. check TPInferEngine init and model optimization
    shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
                               inference_only=True)
    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    assert infer_engine.cache_manager is not None
    assert infer_engine.tp_size == TP_SIZE
    assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE

    # 2. check data preparation
    input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970],
                      [80540, 15473, 3331, 11970], [80540, 15473]]
    batch_size = len(input_ids_list)
    max_seq_len = max(len(li) for li in input_ids_list)
    attention_mask = [[0] * max_seq_len for _ in range(batch_size)]
    for i, li in enumerate(input_ids_list):
        attention_mask[i][max_seq_len - len(li):] = [1 for _ in range(len(li))]
    data = dict(input_ids=input_ids_list, attention_mask=attention_mask)
    inputs_batch_encoding = BatchEncoding(data=data)
    seq_lengths = [len(li) for li in input_ids_list]
    start_loc = list(accumulate([0] + seq_lengths[:-1]))
    seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32)
    start_loc = torch.tensor(start_loc, dtype=torch.int32)
    # BatchEncoding as inputs
    batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding)
    # input token id list as inputs
    batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list)

    assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size
    assert torch.equal(batch_state_out1.seq_len, batch_state_out2.seq_len)

    # The following tests are discarded for now, and will be reused after all features are added
    # assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
    # assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
    # assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
    # assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)

    # 3. check optimized model generate
    input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))
    generate_kwargs = dict(do_sample=False)
    infer_engine.generate(input_ids, **generate_kwargs)

    torch.cuda.empty_cache()


def check_engine(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_engine():
    spawn(check_engine, TP_SIZE)


if __name__ == '__main__':
    test_engine()
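For reference, the left-padded attention mask and cumulative start locations built in step 2 of the test above work out as follows for the sample prompts (sequence lengths 7, 4, 4, 2). This is a minimal standalone sketch of the same arithmetic, with no engine involved:

from itertools import accumulate

# sequence lengths of the four sample prompts in input_ids_list
seq_lengths = [7, 4, 4, 2]
max_seq_len = max(seq_lengths)    # 7

# left padding: zeros over the pad positions, ones over the real tokens
attention_mask = [[0] * (max_seq_len - n) + [1] * n for n in seq_lengths]
# -> [[1, 1, 1, 1, 1, 1, 1],
#     [0, 0, 0, 1, 1, 1, 1],
#     [0, 0, 0, 1, 1, 1, 1],
#     [0, 0, 0, 0, 0, 1, 1]]

# cumulative start offset of each sequence in a packed (concatenated) layout
start_loc = list(accumulate([0] + seq_lengths[:-1]))    # [0, 7, 11, 15]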