mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-05-03 22:18:23 +00:00
ColossalAI/tests/test_infer/test_llama_infer.py
Cuiqing Li bce0f16702
[Feature] The first PR to Add TP inference engine, kv-cache manager and related kernels for our inference system ()

Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
2023-09-12 01:22:56 +08:00


import os
import warnings

import pytest
import torch
from packaging import version

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

TPSIZE = 2
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100

# torch.version.cuda is None on CPU-only builds, so guard before parsing the version string.
CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse('11.5')


def init_to_get_rotary(self, base=10000):
    """Precompute fp16 cos/sin rotary-embedding tables and cache them on the module, on the CUDA device."""
    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
    if not hasattr(self.config, "rope_scaling"):
        rope_scaling_factor = 1.0
    else:
        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
    if hasattr(self.config, "max_sequence_length"):
        max_seq_len = self.config.max_sequence_length
    elif hasattr(self.config, "max_position_embeddings"):
        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
    else:
        max_seq_len = 2048 * rope_scaling_factor
    base = float(base)

    # Inverse frequencies, one per pair of head dimensions.
    inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
                             self.config.head_dim_))
    # Positions cover max_seq_len plus 64K extra slots, rescaled by the RoPE scaling factor.
    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)

    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
    return


@parameterize('test_config', [{
    'tp_size': TPSIZE,
}])
def run_llama_test(test_config):

    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama_for_casual_lm')
    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
        orig_model = model_fn()
        init_to_get_rotary(orig_model.model, base=10000)
        orig_model = orig_model.half()
        data = data_gen_fn()

        shard_config = ShardConfig(enable_tensor_parallelism=test_config['tp_size'] > 1, inference_only=True)
        infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

        generate_kwargs = dict(do_sample=False)
        outputs = infer_engine.generate(data, **generate_kwargs)

        assert outputs is not None


def check_llama(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_llama_test()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
    spawn(check_llama, TPSIZE)


if __name__ == "__main__":
    test_llama()