[Inference] Add async and sync API server using FastAPI (#5396)

* add api server

* fix

* add

* add completion service and fix bug

* add generation config

* revise shardformer

* fix bugs

* add docstrings and fix some bugs

* fix bugs and add choices for prompt template
This commit is contained in:
Jianghai
2024-03-01 14:47:36 +08:00
committed by CjhHa1
parent d482922035
commit 69cd7e069d
13 changed files with 789 additions and 25 deletions

View File

@@ -0,0 +1,80 @@
import asyncio
from dataclasses import dataclass
import pytest
from colossalai.inference.core.async_engine import AsyncInferenceEngine
@dataclass
class SequenceTpye:
    """Record carrying nothing but a request id.

    ``MockEngine.async_step`` returns instances of this class as its output
    items.
    """

    # NOTE: the class name keeps its existing spelling ("Tpye") because other
    # code in this file refers to it by that name.
    request_id: int
class MockEngine:
    """Test double for an inference engine that only records how it is driven.

    Attributes:
        step_calls: Number of times ``async_step`` has been awaited.
        add_request_calls: Number of times ``add_request`` has been called.
        abort_request_calls: Number of times ``abort_request`` has been called.
        request_id: Id of the request currently reported as generating, or
            ``None`` when nothing is generating.
    """

    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        self.request_id = None

    def _has_active_request(self):
        """Return True while a request id has been set through ``generate``."""
        # Compare against the None sentinel explicitly: a plain truthiness
        # check would treat the valid request id 0 as "nothing generating".
        return self.request_id is not None

    async def async_step(self):
        """Count one step and return the output for the active request.

        Returns:
            A one-element list holding a ``SequenceTpye`` for the active
            request, or an empty list when no request is active.
        """
        self.step_calls += 1
        if not self._has_active_request():
            return []
        return [SequenceTpye(request_id=self.request_id)]

    def generate(self, request_id):
        """Mark ``request_id`` as the request that subsequent steps report."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the active request so subsequent steps return no output."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Count the call; the keyword arguments themselves are ignored."""
        del kwargs  # Unused
        self.add_request_calls += 1

    def abort_request(self, request_id):
        """Count the call; the request id itself is ignored."""
        del request_id  # Unused
        self.abort_request_calls += 1
class MockAsyncLLMEngine(AsyncInferenceEngine):
    """AsyncInferenceEngine subclass whose engine factory yields a MockEngine."""

    def _init_engine(self, *args, **kwargs):
        """Disregard every argument and hand back a freshly built MockEngine."""
        mock_engine = MockEngine()
        return mock_engine
@pytest.mark.asyncio
async def test_new_requests_event():
    """Check how often the mocked engine is stepped as requests come and go.

    The counters asserted on are read from ``engine.engine``.
    NOTE(review): ``engine.engine`` is presumably the MockEngine returned by
    ``MockAsyncLLMEngine._init_engine`` -- confirm against AsyncInferenceEngine.
    The test is sensitive to the exact number of event-loop yields, so the
    order of the sleeps and assertions below matters.
    """
    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
    engine.start_background_loop()
    # Let the background loop run briefly; with no request added yet, no step
    # is expected.
    await asyncio.sleep(0.01)
    assert engine.engine.step_calls == 0
    # One added request is expected to yield exactly one add_request call and
    # exactly one step.
    await engine.add_request(1, "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1
    # Add a second request and make the mock report it as generating, so
    # async_step returns a non-empty list from now on.
    await engine.add_request(2, "", None)
    engine.engine.generate(2)
    # Each zero-length sleep yields to the event loop once; one more step is
    # expected per yield while the mock keeps producing output.
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls == 2
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 3
    # After the mock stops producing output, one further step is expected and
    # then the count must stay unchanged.
    engine.engine.stop_generating()
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 4
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 4
    # A third request is expected to cause exactly one more step ...
    await engine.add_request(3, "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
    # ... and nothing further once that step has happened.
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
if __name__ == "__main__":
    # test_new_requests_event is a coroutine function: calling it on its own
    # only creates a coroutine object that is never awaited, so the test body
    # would not execute. Drive it with an event loop instead.
    asyncio.run(test_new_requests_event())

View File

@@ -0,0 +1,77 @@
import pytest
from colossalai.inference.core.async_engine import RequestTracker
from colossalai.inference.struct import Sequence
class SampleEvent:
    """Bare-bones event object: a boolean ``flag`` toggled by ``set``/``clear``."""

    def __init__(self):
        # A new event starts out unset.
        self.flag = False

    def set(self):
        """Turn the flag on."""
        self.flag = True

    def clear(self):
        """Turn the flag off."""
        self.flag = False
def test_request_tracker():
    """Walk a RequestTracker through add, duplicate-add, abort and finish."""
    request_tracker = RequestTracker()
    # Swap in the simple flag-based event so its state can be inspected.
    request_tracker.new_requests_event = SampleEvent()

    # --- a single new request ---
    first_stream = request_tracker.add_request(1)
    assert request_tracker.new_requests_event.flag
    new_requests, finished_ids = request_tracker.get_new_and_finished_requests()
    assert not request_tracker.new_requests_event.flag
    assert len(new_requests) == 1
    assert new_requests[0]["request_id"] == 1
    assert not finished_ids
    assert not first_stream.finished

    # --- two new requests collected in one call ---
    second_stream = request_tracker.add_request(2)
    third_stream = request_tracker.add_request(3)
    assert request_tracker.new_requests_event.flag
    new_requests, finished_ids = request_tracker.get_new_and_finished_requests()
    assert not request_tracker.new_requests_event.flag
    assert len(new_requests) == 2
    assert new_requests[0]["request_id"] == 2
    assert new_requests[1]["request_id"] == 3
    assert not finished_ids
    assert not second_stream.finished
    assert not third_stream.finished

    # request_ids must be unique
    with pytest.raises(KeyError):
        request_tracker.add_request(1)
    assert not request_tracker.new_requests_event.flag

    # --- aborting a request that was already collected ---
    request_tracker.abort_request(1)
    new_requests, finished_ids = request_tracker.get_new_and_finished_requests()
    assert len(finished_ids) == 1
    assert 1 in finished_ids
    assert not new_requests
    assert first_stream.finished

    # --- aborting a request before it was ever collected ---
    fourth_stream = request_tracker.add_request(4)
    request_tracker.abort_request(4)
    assert request_tracker.new_requests_event.flag
    new_requests, finished_ids = request_tracker.get_new_and_finished_requests()
    assert len(finished_ids) == 1
    assert 4 in finished_ids
    assert not new_requests
    assert fourth_stream.finished

    # --- one request finishes while another one is newly added ---
    fifth_stream = request_tracker.add_request(5)
    assert request_tracker.new_requests_event.flag
    request_tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
    new_requests, finished_ids = request_tracker.get_new_and_finished_requests()
    assert not request_tracker.new_requests_event.flag
    assert len(finished_ids) == 1
    assert 2 in finished_ids
    assert len(new_requests) == 1
    assert new_requests[0]["request_id"] == 5
    assert second_stream.finished
    assert not fifth_stream.finished
# Allow running this test file directly as a script, without pytest.
if __name__ == "__main__":
    test_request_tracker()