Mirror of https://github.com/hpcaitech/ColossalAI.git
[Pipeline inference] Combine kvcache with pipeline inference (#4938)
* Merge kv-cache with pipeline inference and refactor the code structure
* Support pp_size > 2
* Refactor pipeline code
* Run pre-commit
* Modify benchmark
* Fix benchmark
* Polish code
* Add docstring and update README
* Refactor the code
* Fix some logic bugs in pp inference
* Polish README
* Fix typo
* Skip inference test
@@ -2,12 +2,15 @@ import pytest
 import torch
 import torch.distributed as dist
 import transformers
 from packaging import version

 import colossalai
-from colossalai.inference.pipeline.engine import PPInferEngine
-from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy
+from colossalai.inference.pipeline import PPInferEngine
+from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


 def data_gen():
     input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
@@ -24,20 +27,21 @@ for k, v in inputs.items():


 def pipeline_inference_test(pp_size, new_length, micro_batch_size):
-    model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
+    model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4))

     engine = PPInferEngine(
         pp_size=pp_size,
         model=model,
-        model_policy=GPT2LMHeadModelPipelinePolicy(),
+        model_policy=LlamaModelInferPolicy(),
         new_length=new_length,
         micro_batch_size=micro_batch_size,
     )
-    output = engine.inference([inputs])
+    output = engine.inference(inputs)
     if dist.get_rank() == 0:
         assert len(output[0]) == new_length, f"{len(output)}, {new_length}"


-@parameterize("pp_size", [4])
+@parameterize("pp_size", [2])
 @parameterize("new_length", [4, 8, 16])
 @parameterize("micro_batch_size", [1, 4])
+@clear_cache_before_run()
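Pulled out of the diff, the hunk above amounts to the following standalone usage of the engine. This is a minimal sketch, not the committed test: the launch_from_torch initialization and the exact shape of the inputs dict are assumptions, while the diff itself confirms only the PPInferEngine constructor arguments and that inference() now takes the inputs dict directly instead of a single-element list.

# Minimal standalone sketch of the engine usage in the hunk above.
# Assumes a torchrun-style launch: colossalai.launch_from_torch reads
# rank and world size from the environment. The committed test instead
# spawns its own workers (see the next hunk).
import torch
import transformers

import colossalai
from colossalai.inference.pipeline import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy

colossalai.launch_from_torch(config={})

# Tiny 4-layer Llama, as in the updated test.
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4))

engine = PPInferEngine(
    pp_size=2,                             # number of pipeline stages
    model=model,
    model_policy=LlamaModelInferPolicy(),  # shards the model across stages
    new_length=8,                          # number of tokens to generate
    micro_batch_size=1,
)

# inference() now takes the inputs dict directly (no list wrapper);
# this particular dict is an assumed stand-in for the test's `inputs`.
inputs = {"input_ids": torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)}
output = engine.inference(inputs)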
@@ -51,11 +55,12 @@ def check_pipeline_inference(rank, world_size, port):
     run_pipeline_inference_test()


 @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_pipeline_inference():
-    spawn(check_pipeline_inference, nprocs=4)
+    spawn(check_pipeline_inference, nprocs=2)


 if __name__ == "__main__":
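The harness in the last hunk follows the usual ColossalAI multiprocess test pattern; a trimmed sketch is below. Only the decorators, the run_pipeline_inference_test() call, and spawn(..., nprocs=2) appear in the diff; the colossalai.launch call inside the worker is an assumption based on that pattern.

# Sketch of the spawn-based harness around the test. The worker body is
# assumed from the standard ColossalAI test pattern, not from the diff.
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_pipeline_inference(rank, world_size, port):
    # One process per pipeline stage; spawn() supplies the rank, the
    # world size, and a free port to each worker.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    # run_pipeline_inference_test is the @parameterize-wrapped test body
    # defined earlier in the test file.
    run_pipeline_inference_test()


@rerun_if_address_is_in_use()
def test_pipeline_inference():
    # Two workers now, matching pp_size=2 in the parameterization above.
    spawn(check_pipeline_inference, nprocs=2)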