Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 21:22:49 +00:00)
[Inference] Dynamic Batching Inference, online and offline (#4953)
Squash of the following commits:

* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)
  * finish batch manager; fix dynamic batching; llama infer; finish test; support generating different lengths; delete debug prints; assorted small fixes
  * Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [inference] Async dynamic batching (#4894)
  * finish input and output logic; add generate; test forward
* [inference] Re push async dynamic batching (#4901)
  * adapt to Ray server; finish async; finish test, then delete test
  * Co-authored-by: yuehuayingxueluo <867460659@qq.com>
* Revert "[inference]Re push async dynamic batching (#4901)" (#4905). This reverts commit fbf3c09e67.
* Revert "[inference] Async dynamic batching (#4894)". This reverts commit fced140250.
* Revert "[inference] Async dynamic batching (#4894)" (#4909). This reverts commit fced140250.
* Add Ray Distributed Environment Init Scripts
* support DynamicBatchManager base function
* revert _set_tokenizer version
* add driver async generate
* add async test
* fix bugs in test_ray_dist.py
* add get_tokenizer.py
* fix code style
* fix "No module named 'pydantic'" and other bugs in CI tests
* [infer] Add Ray Distributed Environment Init Scripts (#4911)
  * re-applies the Ray init scripts and DynamicBatchManager base function after the revert of #4894, and additionally supports dynamic batching for the bloom model and the is_running function
* [Inference] Test for new Async engine (#4935)
  * infer engine; test engine; new manager; change step; finish tests; add license
  * Co-authored-by: yuehuayingxueluo <867460659@qq.com>
* add assertion for config (#4947)
* [Inference] Finish dynamic batching offline test (#4948)
  * fix test; fix quant; add default; fix assorted bugs; reset param

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
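The offline path introduced here boils down to building a TPInferEngine and handing it, together with a waiting list of requests, to start_dynamic_batching. Below is a minimal sketch distilled from the tests added in this commit; the BatchingArgs dataclass is a hypothetical stand-in for the `args` container used in the offline test, the tiny randomly initialized Llama stands in for a real checkpoint, and the tests additionally wrap this in colossalai.launch/spawn for distributed setup.

from dataclasses import dataclass

from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig


@dataclass
class BatchingArgs:  # hypothetical container mirroring the `args` dataclass in the offline test
    max_total_token_num: int = 4096
    batch_max_tokens: int = 4096
    model: str = "llama"
    eos_id: int = 0
    disable_log_stats: bool = False
    log_stats_interval: int = 10


def offline_demo():
    # a tiny randomly initialized Llama stands in for a real checkpoint
    llama_config = LlamaConfig(num_hidden_layers=2, vocab_size=1200, hidden_size=1024)
    model = LlamaForCausalLM(llama_config).half()
    shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
    engine = TPInferEngine(model, shard_config, 2, 1024, 512)  # max batch size, max input len, max output len

    sampling_params = SamplingParams()
    waiting = [Req(0, [1, 2, 3], sampling_params)]  # a pre-queued token-level request
    manager = start_dynamic_batching(BatchingArgs(), tp_engine=engine, waiting_req_list=waiting)

    # generate() yields results as the batch manager steps through prefill/decode
    for result in manager.generate(request_id=1, prompts="hello", sampling_params=sampling_params):
        print(result)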
@@ -52,7 +52,6 @@ def run_chatglm2_test(test_config):
        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
    }
    outputs = infer_engine.generate(input_tokens, **generate_kwargs)

    assert outputs is not None
tests/test_infer/test_dynamic_batching/config.yaml (new file, 14 lines)
@@ -0,0 +1,14 @@
engine_config:
  model: MODEL_PATH
  tensor_parallel_size: 1
  max_batch_size: 2
  max_input_len: 1024
  max_output_len: 512
# config for app router deployment
# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig.
router_config:
  max_total_token_num: 4096
  batch_max_tokens: 4096
  disable_log_stats: False
  log_stats_interval: 10
  model: MODEL_PATH
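For context, a short sketch of how the tests in this commit consume this file; MODEL_PATH is a placeholder and must point to a real checkpoint directory before the tests do anything, and the attribute names follow the test code below.

from colossalai.inference.dynamic_batching.ray_dist_init import Driver
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig

# parse the YAML into two sub-configs
config = RayInitConfig.from_yaml_path("config.yaml")
engine_config = config.engine_config_data   # model path, tensor_parallel_size, batch/length limits
router_config = config.router_config_data   # token budget and logging options for the request router

# both sub-configs are handed straight to the online entry points
driver = Driver(router_config=router_config, engine_config=engine_config)
print(engine_config.model)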
tests/test_infer/test_dynamic_batching/test_async_engine.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import asyncio
import os
import uuid

import pytest

import colossalai
from colossalai.inference.async_engine import Async_Engine
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

PATH = "config.yaml"


def run_async_engine(path: str):
    if not os.path.exists(path):
        return

    config = RayInitConfig.from_yaml_path(path)
    engine_config = config.engine_config_data
    model = engine_config.model
    if model is None or not os.path.exists(model):
        return

    prompt = "Introduce some landmarks in London.\n The Tower of London is a historic castle on the north bank of the River Thames in central London. It was founded towards the end of 10"
    sampling_params = SamplingParams()
    asyncio.run(asy_for_loop_test(config, prompt, sampling_params))


async def get_result(engine, prompt, sampling_params):
    request_id = str(uuid.uuid4().hex)
    results = engine.generate(request_id, prompt, sampling_params)
    async for result in results:
        # print(result)
        assert result is not None


async def asy_for_loop_test(config, prompt, sampling_params):
    router_config = config.router_config_data
    engine_config = config.engine_config_data
    engine = Async_Engine(router_config=router_config, engine_config=engine_config)
    for i in range(10):
        print("in for loop", i)
        await get_result(engine, prompt, sampling_params)


def check_async_engine(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_async_engine(PATH)


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_async_engine():
    spawn(check_async_engine, 1)


if __name__ == "__main__":
    test_async_engine()
@@ -0,0 +1,95 @@
import pytest
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import DynamicBatchManager
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

TP_SIZE = 1
BATCH_SIZE = 2
MAX_INPUT_LEN = 48
MAX_OUTPUT_LEN = 256


def run():
    sampling_params = SamplingParams()

    req1 = Req(0, [1], sampling_params)
    req2 = Req(1, [2], sampling_params)
    req3 = Req(2, [3], sampling_params)
    # reqs 1-3 are initialized as token forward requests
    req4 = Req(3, [10, 10, 10, 9, 1], sampling_params)
    waiting_list = []
    waiting_list.append(req1)
    waiting_list.append(req2)
    waiting_list.append(req3)

    # init model and tp engine
    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
    model = LlamaForCausalLM(llama_config)
    model = model.half()

    shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    dynamic_batch_manager = DynamicBatchManager(
        tp_engine=infer_engine,
        max_total_token_num=640,
        batch_max_tokens=608,
        eos_id=0,
        log_stats=False,
        log_stats_interval=10,
        waiting_req_list=waiting_list,
        model="llama",
    )
    before_add = len(dynamic_batch_manager.req_queue)

    # test add_req function
    dynamic_batch_manager.add_req(req4.request_id, req4.prompt_ids, req4.sample_params)
    assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1

    # test abort function
    dynamic_batch_manager.abort(req4.request_id)
    assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted == True

    # test filter batch function; loop_for_fwd, _step, _init_batch and _prefill/_decode batch are exercised here
    batch = dynamic_batch_manager.req_queue.generate_new_batch()
    assert len(batch) == 2

    dynamic_batch_manager._init_batch(batch)
    assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None

    batch.reqs[0].has_generate_finished = True
    # filter one finished request
    batch.filter_finished()
    dynamic_batch_manager._filter_batch(batch)
    assert len(dynamic_batch_manager.engine.cache) == 1

    # test merge batch
    new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch)
    assert len(new_batch) == 1
    dynamic_batch_manager._init_batch(new_batch)
    dynamic_batch_manager._merge_batch(batch, new_batch)

    assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2


def check_dynamic_batching_manager(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching_manager():
    spawn(check_dynamic_batching_manager, 1)


if __name__ == "__main__":
    test_dynamic_batching_manager()
@@ -0,0 +1,84 @@
from dataclasses import dataclass

import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

TP_SIZE = 1
MAX_BATCH_SIZE = 2
MAX_INPUT_LEN = 5
MAX_OUTPUT_LEN = 16
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@dataclass
class args:
    max_total_token_num: int
    batch_max_tokens: int
    model: str
    eos_id: int
    disable_log_stats: bool
    log_stats_interval: int


def run():
    arg = args(
        max_total_token_num=42,
        model="llama",
        batch_max_tokens=42,
        eos_id=0,
        disable_log_stats=False,
        log_stats_interval=10,
    )
    sampling_params = SamplingParams()

    req1 = Req(0, [0, 0, 10, 6, 8], sampling_params)
    req2 = Req(1, [10, 10, 10, 10, 10], sampling_params)
    req3 = Req(2, [0, 0, 10, 10, 10], sampling_params)
    req4 = Req(3, [0, 0, 10, 10, 10], sampling_params)

    waiting_list = []
    waiting_list.append(req1)
    waiting_list.append(req2)
    waiting_list.append(req3)
    waiting_list.append(req4)

    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=30000, hidden_size=1024)
    model = LlamaForCausalLM(llama_config)
    model = model.half()

    shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)

    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
    batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)

    ans_gen = batch_manager.generate(request_id=5, prompts="hello", sampling_params=sampling_params)
    for result in ans_gen:
        assert result is not None


def check_dynamic_forward(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching():
    spawn(check_dynamic_forward, TP_SIZE)


if __name__ == "__main__":
    test_dynamic_batching()
tests/test_infer/test_dynamic_batching/test_ray_dist.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import asyncio
import os
import uuid

import pytest

import colossalai
from colossalai.inference.dynamic_batching.ray_dist_init import Driver
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

PATH = "config.yaml"


def run_ray_dist(path: str):
    if not os.path.exists(path):
        return
    config = RayInitConfig.from_yaml_path(path)
    router_config = config.router_config_data
    engine_config = config.engine_config_data
    model = engine_config.model
    if model is None or not os.path.exists(model):
        return
    driver = Driver(router_config=router_config, engine_config=engine_config)
    prompt = "Introduce some landmarks in Beijing"

    request_id = str(uuid.uuid4().hex)
    sampling_params = SamplingParams()
    print("sampling_params: ", sampling_params)

    async def get_result(request_id, prompt, sampling_params):
        return await driver.async_generate(request_id, prompt, sampling_params)

    for test_async in [True, False]:
        if test_async:
            print("test_async: ", test_async)
            result = asyncio.run(get_result(request_id, prompt, sampling_params))
            assert result is not None
            print("result: ", result)
        else:
            print("test_async: ", test_async)
            result = driver.generate(request_id, prompt, sampling_params)
            assert result is not None
            print("result: ", result)

    is_running = None
    is_running = driver.is_running()
    assert is_running is not None
    print("is_running: ", is_running)


def check_ray_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_ray_dist(PATH)


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_ray_dist():
    spawn(check_ray_dist, 1)


if __name__ == "__main__":
    test_ray_dist()