[Pipeline inference] Combine kvcache with pipeline inference (#4938)

* merge kvcache with pipeline inference and refactor the code structure

* support ppsize > 2

* refactor pipeline code

* do pre-commit

* modify benchmark

* fix bench mark

* polish code

* add docstring and update readme

* refactor the code

* fix some logic bug of ppinfer

* polish readme

* fix typo

* skip infer test
This commit is contained in:
Bin Jia
2023-10-27 16:19:54 +08:00
committed by GitHub
parent c6cd629e7a
commit 1db6727678
19 changed files with 922 additions and 745 deletions

View File

@@ -2,12 +2,15 @@ import pytest
import torch
import torch.distributed as dist
import transformers
from packaging import version
import colossalai
from colossalai.inference.pipeline.engine import PPInferEngine
from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy
from colossalai.inference.pipeline import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
@@ -24,20 +27,21 @@ for k, v in inputs.items():
def pipeline_inference_test(pp_size, new_length, micro_batch_size):
model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4))
engine = PPInferEngine(
pp_size=pp_size,
model=model,
model_policy=GPT2LMHeadModelPipelinePolicy(),
model_policy=LlamaModelInferPolicy(),
new_length=new_length,
micro_batch_size=micro_batch_size,
)
output = engine.inference([inputs])
output = engine.inference(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == new_length, f"{len(output)}, {new_length}"
@parameterize("pp_size", [4])
@parameterize("pp_size", [2])
@parameterize("new_length", [4, 8, 16])
@parameterize("micro_batch_size", [1, 4])
@clear_cache_before_run()
@@ -51,11 +55,12 @@ def check_pipeline_inference(rank, world_size, port):
run_pipeline_inference_test()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
spawn(check_pipeline_inference, nprocs=4)
spawn(check_pipeline_inference, nprocs=2)
if __name__ == "__main__":