mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-04 13:19:43 +00:00
* add api server * fix * add * add completion service and fix bug * add generation config * revise shardformer * fix bugs * add docstrings and fix some bugs * fix bugs and add choices for prompt template
78 lines
2.2 KiB
Python
78 lines
2.2 KiB
Python
import pytest
|
|
|
|
from colossalai.inference.core.async_engine import RequestTracker
|
|
from colossalai.inference.struct import Sequence
|
|
|
|
|
|
class SampleEvent:
|
|
def __init__(self):
|
|
self.flag = False
|
|
|
|
def set(self):
|
|
self.flag = True
|
|
|
|
def clear(self):
|
|
self.flag = False
|
|
|
|
|
|
def test_request_tracker():
|
|
tracker = RequestTracker()
|
|
tracker.new_requests_event = SampleEvent()
|
|
stream_1 = tracker.add_request(1)
|
|
assert tracker.new_requests_event.flag
|
|
new, finished = tracker.get_new_and_finished_requests()
|
|
assert not tracker.new_requests_event.flag
|
|
assert len(new) == 1
|
|
assert new[0]["request_id"] == 1
|
|
assert not finished
|
|
assert not stream_1.finished
|
|
|
|
stream_2 = tracker.add_request(2)
|
|
stream_3 = tracker.add_request(3)
|
|
assert tracker.new_requests_event.flag
|
|
new, finished = tracker.get_new_and_finished_requests()
|
|
assert not tracker.new_requests_event.flag
|
|
assert len(new) == 2
|
|
assert new[0]["request_id"] == 2
|
|
assert new[1]["request_id"] == 3
|
|
assert not finished
|
|
assert not stream_2.finished
|
|
assert not stream_3.finished
|
|
|
|
# request_ids must be unique
|
|
with pytest.raises(KeyError):
|
|
tracker.add_request(1)
|
|
assert not tracker.new_requests_event.flag
|
|
|
|
tracker.abort_request(1)
|
|
new, finished = tracker.get_new_and_finished_requests()
|
|
assert len(finished) == 1
|
|
assert 1 in finished
|
|
assert not new
|
|
assert stream_1.finished
|
|
|
|
stream_4 = tracker.add_request(4)
|
|
tracker.abort_request(4)
|
|
assert tracker.new_requests_event.flag
|
|
new, finished = tracker.get_new_and_finished_requests()
|
|
assert len(finished) == 1
|
|
assert 4 in finished
|
|
assert not new
|
|
assert stream_4.finished
|
|
|
|
stream_5 = tracker.add_request(5)
|
|
assert tracker.new_requests_event.flag
|
|
tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
|
|
new, finished = tracker.get_new_and_finished_requests()
|
|
assert not tracker.new_requests_event.flag
|
|
assert len(finished) == 1
|
|
assert 2 in finished
|
|
assert len(new) == 1
|
|
assert new[0]["request_id"] == 5
|
|
assert stream_2.finished
|
|
assert not stream_5.finished
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test_request_tracker()
|