feat(core): More AWEL operators and new prompt manager API (#972)

Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Fangyin Cheng
2023-12-25 20:03:22 +08:00
committed by GitHub
parent 048fb6c402
commit 69fb97e508
46 changed files with 2556 additions and 294 deletions

View File

@@ -1,15 +1,15 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from fastapi import FastAPI
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import asystem_app, client
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..config import SERVE_CONFIG_KEY_PREFIX
from ..api.endpoints import router, init_endpoints
from ..api.schemas import ServeRequest, ServerResponse
from dbgpt.serve.core.tests.conftest import client, asystem_app
from ..api.endpoints import init_endpoints, router
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX
@pytest.fixture(autouse=True)

View File

@@ -1,9 +1,12 @@
from typing import List
import pytest
from dbgpt.storage.metadata import db
from ..config import ServeConfig
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity, ServeDao
from ..config import ServeConfig
from ..models.models import ServeDao, ServeEntity
@pytest.fixture(autouse=True)
@@ -34,6 +37,8 @@ def default_entity_dict():
"content": "Write a qsort function in python.",
"user_name": "zhangsan",
"sys_code": "dbgpt",
"prompt_language": "zh",
"model": "vicuna-13b-v1.5",
}
@@ -60,7 +65,14 @@ def test_entity_create(default_entity_dict):
def test_entity_unique_key(default_entity_dict):
ServeEntity.create(**default_entity_dict)
with pytest.raises(Exception):
ServeEntity.create(**{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
ServeEntity.create(
**{
"prompt_name": "my_prompt_1",
"sys_code": "dbgpt",
"prompt_language": "zh",
"model": "vicuna-13b-v1.5",
}
)
def test_entity_get(default_entity_dict):

View File

@@ -0,0 +1,144 @@
import pytest
from dbgpt.core.interface.prompt import PromptManager, PromptTemplate
from dbgpt.storage.metadata import db
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
from ..models.prompt_template_adapter import PromptTemplateAdapter, ServeEntity
@pytest.fixture
def serializer():
return JsonSerializer()
@pytest.fixture
def db_url():
"""Use in-memory SQLite database for testing"""
return "sqlite:///:memory:"
@pytest.fixture
def db_manager(db_url):
db.init_db(db_url)
db.create_all()
return db
@pytest.fixture
def storage_adapter():
return PromptTemplateAdapter()
@pytest.fixture
def storage(db_manager, serializer, storage_adapter):
storage = SQLAlchemyStorage(
db_manager,
ServeEntity,
storage_adapter,
serializer,
)
return storage
@pytest.fixture
def prompt_manager(storage):
return PromptManager(storage)
def test_save(prompt_manager: PromptManager):
prompt_template = PromptTemplate(
template="hello {input}",
input_variables=["input"],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name="hello",
)
with db.session() as session:
# Query from database
result = (
session.query(ServeEntity).filter(ServeEntity.prompt_name == "hello").all()
)
assert len(result) == 1
assert result[0].prompt_name == "hello"
assert result[0].content == "hello {input}"
assert result[0].input_variables == "input"
with db.session() as session:
assert session.query(ServeEntity).count() == 1
assert (
session.query(ServeEntity)
.filter(ServeEntity.prompt_name == "not exist prompt name")
.count()
== 0
)
def test_prefer_query_language(prompt_manager: PromptManager):
for language in ["en", "zh"]:
prompt_template = PromptTemplate(
template="test",
input_variables=[],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name="test_prompt",
prompt_language=language,
)
# Prefer zh, and zh exists, will return zh prompt template
result = prompt_manager.prefer_query("test_prompt", prefer_prompt_language="zh")
assert len(result) == 1
assert result[0].content == "test"
assert result[0].prompt_language == "zh"
# Prefer language not exists, will return all prompt templates of this name
result = prompt_manager.prefer_query(
"test_prompt", prefer_prompt_language="not_exist"
)
assert len(result) == 2
def test_prefer_query_model(prompt_manager: PromptManager):
for model in ["model1", "model2"]:
prompt_template = PromptTemplate(
template="test",
input_variables=[],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name="test_prompt",
model=model,
)
# Prefer model1, and model1 exists, will return model1 prompt template
result = prompt_manager.prefer_query("test_prompt", prefer_model="model1")
assert len(result) == 1
assert result[0].content == "test"
assert result[0].model == "model1"
# Prefer model not exists, will return all prompt templates of this name
result = prompt_manager.prefer_query("test_prompt", prefer_model="not_exist")
assert len(result) == 2
def test_list(prompt_manager: PromptManager):
for i in range(10):
prompt_template = PromptTemplate(
template="test",
input_variables=[],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name=f"test_prompt_{i}",
sys_code="dbgpt" if i % 2 == 0 else "not_dbgpt",
)
# Test list all
result = prompt_manager.list()
assert len(result) == 10
for i in range(10):
assert len(prompt_manager.list(prompt_name=f"test_prompt_{i}")) == 1
assert len(prompt_manager.list(sys_code="dbgpt")) == 5

View File

@@ -1,11 +1,13 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.serve.core.tests.conftest import system_app
from ..models.models import ServeEntity
import pytest
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import system_app
from dbgpt.storage.metadata import db
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity
from ..service.service import Service