ColossalAI/tests/test_infer/test_async_engine/test_request_tracker.py
Jianghai 69cd7e069d [Inference] ADD async and sync Api server using FastAPI (#5396)
* add api server

* fix

* add

* add completion service and fix bug

* add generation config

* revise shardformer

* fix bugs

* add docstrings and fix some bugs

* fix bugs and add choices for prompt template
2024-05-08 15:18:28 +00:00

78 lines
2.2 KiB
Python

import pytest
from colossalai.inference.core.async_engine import RequestTracker
from colossalai.inference.struct import Sequence
class SampleEvent:
def __init__(self):
self.flag = False
def set(self):
self.flag = True
def clear(self):
self.flag = False
def test_request_tracker():
tracker = RequestTracker()
tracker.new_requests_event = SampleEvent()
stream_1 = tracker.add_request(1)
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert len(new) == 1
assert new[0]["request_id"] == 1
assert not finished
assert not stream_1.finished
stream_2 = tracker.add_request(2)
stream_3 = tracker.add_request(3)
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert len(new) == 2
assert new[0]["request_id"] == 2
assert new[1]["request_id"] == 3
assert not finished
assert not stream_2.finished
assert not stream_3.finished
# request_ids must be unique
with pytest.raises(KeyError):
tracker.add_request(1)
assert not tracker.new_requests_event.flag
tracker.abort_request(1)
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert 1 in finished
assert not new
assert stream_1.finished
stream_4 = tracker.add_request(4)
tracker.abort_request(4)
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert 4 in finished
assert not new
assert stream_4.finished
stream_5 = tracker.add_request(5)
assert tracker.new_requests_event.flag
tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert len(finished) == 1
assert 2 in finished
assert len(new) == 1
assert new[0]["request_id"] == 5
assert stream_2.finished
assert not stream_5.finished
if __name__ == "__main__":
test_request_tracker()