mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-27 22:07:48 +00:00
144 lines
4.7 KiB
Python
144 lines
4.7 KiB
Python
import asyncio
|
|
from datetime import datetime, timedelta
|
|
|
|
import pytest
|
|
|
|
from dbgpt.model.base import ModelInstance
|
|
from dbgpt.model.cluster.registry import EmbeddedModelRegistry
|
|
|
|
|
|
@pytest.fixture
def model_registry():
    """Provide an ``EmbeddedModelRegistry`` built with its default arguments."""
    registry = EmbeddedModelRegistry()
    return registry
|
|
|
|
|
|
@pytest.fixture
def model_instance():
    """Provide a ``ModelInstance`` named ``test_model`` at 192.168.1.1:5000."""
    instance = ModelInstance(model_name="test_model", host="192.168.1.1", port=5000)
    return instance
|
|
|
|
|
|
# Async function to test the registry
|
|
@pytest.mark.asyncio
async def test_register_instance(model_registry, model_instance):
    """Check that an instance can be registered correctly.

    Registration must report success, and afterwards the registry must hold
    exactly one entry under the instance's model name.
    """
    # Assert on truthiness instead of comparing with ``== True`` (PEP 8, E712).
    assert await model_registry.register_instance(model_instance)
    assert len(model_registry.registry[model_instance.model_name]) == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_deregister_instance(model_registry, model_instance):
    """Check that an instance can be deregistered correctly.

    Deregistration must report success, and the first entry stored under the
    instance's model name must no longer be flagged as healthy.
    """
    await model_registry.register_instance(model_instance)
    # Assert on truthiness instead of comparing with ``== True`` (PEP 8, E712).
    assert await model_registry.deregister_instance(model_instance)
    assert not model_registry.registry[model_instance.model_name][0].healthy
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_all_instances(model_registry, model_instance):
    """Check instance listing with and without the ``healthy_only`` filter."""
    await model_registry.register_instance(model_instance)
    model_name = model_instance.model_name

    # Unfiltered listing returns the single registered instance.
    everything = await model_registry.get_all_instances(model_name)
    assert len(everything) == 1

    # The instance is also returned while it is still healthy.
    healthy_before = await model_registry.get_all_instances(
        model_name, healthy_only=True
    )
    assert len(healthy_before) == 1

    # Flip the flag on the fixture object; presumably the registry holds a
    # reference to this same object, so the filtered listing becomes empty.
    model_instance.healthy = False
    healthy_after = await model_registry.get_all_instances(
        model_name, healthy_only=True
    )
    assert len(healthy_after) == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_select_one_health_instance(model_registry, model_instance):
    """Check that one healthy, enabled instance can be selected by model name."""
    await model_registry.register_instance(model_instance)

    picked = await model_registry.select_one_health_instance(
        model_instance.model_name
    )

    assert picked is not None
    assert picked.healthy
    assert picked.enabled
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_heartbeat(model_registry, model_instance):
    """Check that a heartbeat succeeds and refreshes ``last_heartbeat``.

    After the heartbeat, the first entry stored under the instance's model
    name must carry a newer timestamp and be flagged as healthy.
    """
    await model_registry.register_instance(model_instance)

    # Backdate the timestamp so a refreshed heartbeat is strictly newer.
    last_heartbeat = datetime.now() - timedelta(seconds=10)
    model_instance.last_heartbeat = last_heartbeat

    # Assert on truthiness instead of comparing with ``== True`` (PEP 8, E712).
    assert await model_registry.send_heartbeat(model_instance)

    registered = model_registry.registry[model_instance.model_name][0]
    assert registered.last_heartbeat > last_heartbeat
    assert registered.healthy
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_heartbeat_timeout(model_registry, model_instance):
    """Check that an instance turns unhealthy once its heartbeat times out."""
    # A dedicated registry is used instead of the fixture one; the two
    # positional arguments are presumably the heartbeat interval and timeout
    # in seconds -- confirm against ``EmbeddedModelRegistry.__init__``.
    short_lived_registry = EmbeddedModelRegistry(1, 1)
    await short_lived_registry.register_instance(model_instance)

    # Backdate the stored heartbeat to just past the timeout window.
    expired_at = datetime.now() - timedelta(
        seconds=short_lived_registry.heartbeat_timeout_secs + 1
    )
    short_lived_registry.registry[model_instance.model_name][
        0
    ].last_heartbeat = expired_at

    # Wait longer than one heartbeat interval so the registry can notice.
    await asyncio.sleep(short_lived_registry.heartbeat_interval_secs + 1)

    assert not short_lived_registry.registry[model_instance.model_name][0].healthy
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_multiple_instances(model_registry, model_instance):
    """Check that two instances of the same model are both listed."""
    # Same model name and port as the fixture instance, but a different host.
    sibling = ModelInstance(model_name="test_model", host="192.168.1.2", port=5000)

    for instance in (model_instance, sibling):
        await model_registry.register_instance(instance)

    registered = await model_registry.get_all_instances(model_instance.model_name)
    assert len(registered) == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_same_model_name_different_ip_port(model_registry):
    """Check that same-named instances on different hosts and ports stay distinct."""
    first = ModelInstance(model_name="test_model", host="192.168.1.1", port=5000)
    second = ModelInstance(model_name="test_model", host="192.168.1.2", port=6000)

    await model_registry.register_instance(first)
    await model_registry.register_instance(second)

    found = await model_registry.get_all_instances("test_model")
    assert len(found) == 2

    # Both entries must be kept, differing in host as well as in port.
    left, right = found[0], found[1]
    assert left.host != right.host
    assert left.port != right.port
|