Files
ColossalAI/colossalai/inference/dynamic_batching/req_queue.py
Jianghai cf579ff46d [Inference] Dynamic Batching Inference, online and offline (#4953)
* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)

* finish batch manager

* 1

* first

* fix

* fix dynamic batching

* llama infer

* finish test

* support different lengths generating

* del prints

* del prints

* fix

* fix bug

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>

* [inference] Async dynamic batching  (#4894)

* finish input and output logic

* add generate

* test forward

* 1

* [inference]Re push async dynamic batching (#4901)

* adapt to ray server

* finish async

* finish test

* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)

This reverts commit fbf3c09e67.

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Revert "[inference] Async dynamic batching  (#4894)" (#4909)

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* [infer]Add Ray Distributed Environment Init Scripts (#4911)

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* support dynamic batch for bloom model and is_running function

* [Inference]Test for new Async engine (#4935)

* infer engine

* infer engine

* test engine

* test engine

* new manager

* change step

* add

* test

* fix

* fix

* finish test

* finish test

* finish test

* finish test

* add license

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* add assertion for config (#4947)

* [Inference] Finish dynamic batching offline test (#4948)

* test

* fix test

* fix quant

* add default

* fix

* fix some bugs

* fix some bugs

* fix

* fix bug

* fix bugs

* reset param

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
2023-10-30 10:52:19 +08:00

74 lines
2.7 KiB
Python

# Adapted from https://github.com/ModelTC/lightllm
import uuid
from typing import List, Optional

import numpy as np

from .io_struct import Batch, Req
class ReqQueue:
    """FIFO queue of pending requests that packs them into new batches.

    Requests are appended at the tail and admitted from the head. Admission is
    limited by three budgets:

    * ``max_total_tokens``: bound on the *estimated* peak token usage of the
      running requests plus the newly admitted ones (see ``_can_add_new_req``);
    * ``batch_max_tokens``: bound on the summed ``input_len`` of the requests
      in one new batch;
    * ``running_max_req_size``: bound on the number of concurrently running
      requests.
    """

    def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=None) -> None:
        """
        Args:
            max_total_tokens: upper bound for the estimated peak token usage.
            batch_max_tokens: upper bound for the sum of ``input_len`` over one
                new batch. Must not be None.
            running_max_req_size: maximum number of requests running at once.
            waiting_req_list: optional initial list of pending requests; stored
                as given. Defaults to a new empty list.
        """
        self.max_total_tokens = max_total_tokens
        assert batch_max_tokens is not None
        self.batch_max_tokens = batch_max_tokens
        self.running_max_req_size = running_max_req_size
        # Create a fresh list per instance: a ``[]`` default would be shared by
        # every ReqQueue constructed without an explicit list.
        self.waiting_req_list: List["Req"] = [] if waiting_req_list is None else waiting_req_list
        # (tokens already held, tokens still to generate) per running/candidate request.
        self.cache_len_list = []

    def append(self, req):
        """Add ``req`` to the tail of the waiting queue."""
        self.waiting_req_list.append(req)

    def _init_cache_list(self, current_batch: Optional["Batch"]):
        """Seed ``cache_len_list`` with the token usage of already-running requests.

        Each entry is ``(tokens already held, tokens still to generate)``. For a
        running request that is ``input_len + len(output_ids)`` and
        ``max_output_len - len(output_ids) - 1``. With no current batch the list
        starts empty.
        """
        if current_batch is None:
            self.cache_len_list = []
            return
        self.cache_len_list = [
            (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1)
            for req in current_batch.reqs
        ]

    def _can_add_new_req(self, req) -> bool:
        """Tentatively add ``req`` to ``cache_len_list`` and check that all budgets still hold.

        The entry is appended unconditionally. ``generate_new_batch`` stops
        scanning at the first rejection, so a rejected entry is never used
        afterwards.

        Peak usage is estimated by sorting entries by remaining output length,
        longest first. For sorted position ``i`` the usage is
        ``sum(held[0..i]) + (i + 1) * remaining[i]``, i.e. the first ``i + 1``
        entries, each grown by ``remaining[i]`` tokens. The maximum over ``i``
        must not exceed ``max_total_tokens``. The number of entries must not
        exceed ``running_max_req_size``.
        """
        # Approximation: the new request is counted as already holding one
        # generated token (presumably the one produced at prefill).
        self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1))
        self.cache_len_list.sort(key=lambda entry: -entry[1])

        left_out_len_array = np.array([entry[1] for entry in self.cache_len_list])
        has_run_len_array = np.array([entry[0] for entry in self.cache_len_list])
        cum_run_len_array = np.cumsum(has_run_len_array)
        size_array = np.arange(1, len(self.cache_len_list) + 1)
        need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()

        # `<=` (not `<`) so a batch may use exactly max_total_tokens.
        return bool(
            need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size
        )

    def generate_new_batch(self, current_batch: Optional["Batch"] = None):
        """Pop the longest admissible prefix of the waiting queue as a new ``Batch``.

        Requests are scanned in FIFO order:
        * Aborted requests are skipped without being counted against any budget.
        * Scanning stops at the first live request that does not fit.
        * Every scanned request (admitted or aborted) is removed from the queue.

        Args:
            current_batch: the batch that is currently running, if any.

        Returns:
            A new ``Batch`` with a random hex id holding the admitted requests.
            Returns None if ``current_batch`` is already at
            ``running_max_req_size``, or if no live request could be admitted.
        """
        if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size:
            return None

        self._init_cache_list(current_batch)
        can_run_list = []
        new_batch_total_tokens = 0
        aborted_count = 0
        for req in self.waiting_req_list:
            # Check `aborted` before `_can_add_new_req`: that call records the
            # request in cache_len_list, and an aborted request must not inflate
            # the token estimate for the requests behind it.
            if req.aborted:
                aborted_count += 1
                continue
            if self._can_add_new_req(req) and new_batch_total_tokens + req.input_len <= self.batch_max_tokens:
                can_run_list.append(req)
                new_batch_total_tokens += req.input_len
            else:
                break

        # Everything scanned before the break was either admitted or aborted.
        # Drop it, including when nothing was admitted, so aborted requests do
        # not accumulate at the head of the queue.
        consumed = len(can_run_list) + aborted_count
        if consumed:
            self.waiting_req_list = self.waiting_req_list[consumed:]

        if not can_run_list:
            return None
        return Batch(uuid.uuid4().hex, can_run_list)

    def __len__(self):
        """Number of requests still waiting (aborted ones stay until they are scanned)."""
        return len(self.waiting_req_list)