feat(core): More AWEL operators and new prompt manager API (#972)

Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Fangyin Cheng
2023-12-25 20:03:22 +08:00
committed by GitHub
parent 048fb6c402
commit 69fb97e508
46 changed files with 2556 additions and 294 deletions

View File

@@ -1,15 +1,16 @@
from typing import Optional, List
from functools import cache
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
from dbgpt.util import PaginationResult
from .schemas import ServeRequest, ServerResponse
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from ..config import APP_NAME, SERVE_APP_NAME, ServeConfig, SERVE_SERVICE_COMPONENT_NAME
from .schemas import ServeRequest, ServerResponse
router = APIRouter()

View File

@@ -1,6 +1,8 @@
# Define your Pydantic schemas here
from typing import Optional
from dbgpt._private.pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP

View File

@@ -1,9 +1,8 @@
from typing import Optional
from dataclasses import dataclass, field
from typing import Optional
from dbgpt.serve.core import BaseServeConfig
APP_NAME = "prompt"
SERVE_APP_NAME = "dbgpt_serve_prompt"
SERVE_APP_NAME_HUMP = "dbgpt_serve_Prompt"

View File

@@ -1,33 +1,64 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
from typing import Union, Any, Dict
from datetime import datetime
from sqlalchemy import Column, Integer, String, Index, Text, DateTime, UniqueConstraint
from dbgpt.storage.metadata import Model, BaseDao, db
from typing import Any, Dict, Union
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model, db
from ..api.schemas import ServeRequest, ServerResponse
from ..config import ServeConfig, SERVER_APP_TABLE_NAME
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
class ServeEntity(Model):
__tablename__ = "prompt_manage"
__table_args__ = (
UniqueConstraint("prompt_name", "sys_code", name="uk_prompt_name_sys_code"),
UniqueConstraint(
"prompt_name",
"sys_code",
"prompt_language",
"model",
name="uk_prompt_name_sys_code",
),
)
id = Column(Integer, primary_key=True, comment="Auto increment id")
chat_scene = Column(String(100))
sub_chat_scene = Column(String(100))
prompt_type = Column(String(100))
prompt_name = Column(String(512))
content = Column(Text)
user_name = Column(String(128))
chat_scene = Column(String(100), comment="Chat scene")
sub_chat_scene = Column(String(100), comment="Sub chat scene")
prompt_type = Column(String(100), comment="Prompt type(eg: common, private)")
prompt_name = Column(String(256), comment="Prompt name")
content = Column(Text, comment="Prompt content")
input_variables = Column(
String(1024), nullable=True, comment="Prompt input variables(split by comma))"
)
model = Column(
String(128),
nullable=True,
comment="Prompt model name(we can use different models for different prompt",
)
prompt_language = Column(
String(32), index=True, nullable=True, comment="Prompt language(eg:en, zh-cn)"
)
prompt_format = Column(
String(32),
index=True,
nullable=True,
default="f-string",
comment="Prompt format(eg: f-string, jinja2)",
)
user_name = Column(String(128), index=True, nullable=True, comment="User name")
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
def __repr__(self):
return f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
return (
f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', "
f"prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',"
f"user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
)
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):

View File

@@ -0,0 +1,56 @@
from typing import Type
from sqlalchemy.orm import Session
from dbgpt.core.interface.prompt import PromptTemplateIdentifier, StoragePromptTemplate
from dbgpt.core.interface.storage import StorageItemAdapter
from .models import ServeEntity
class PromptTemplateAdapter(StorageItemAdapter[StoragePromptTemplate, ServeEntity]):
def to_storage_format(self, item: StoragePromptTemplate) -> ServeEntity:
return ServeEntity(
chat_scene=item.chat_scene,
sub_chat_scene=item.sub_chat_scene,
prompt_type=item.prompt_type,
prompt_name=item.prompt_name,
content=item.content,
input_variables=item.input_variables,
model=item.model,
prompt_language=item.prompt_language,
prompt_format=item.prompt_format,
user_name=item.user_name,
sys_code=item.sys_code,
)
def from_storage_format(self, model: ServeEntity) -> StoragePromptTemplate:
return StoragePromptTemplate(
chat_scene=model.chat_scene,
sub_chat_scene=model.sub_chat_scene,
prompt_type=model.prompt_type,
prompt_name=model.prompt_name,
content=model.content,
input_variables=model.input_variables,
model=model.model,
prompt_language=model.prompt_language,
prompt_format=model.prompt_format,
user_name=model.user_name,
sys_code=model.sys_code,
)
def get_query_for_identifier(
self,
storage_format: Type[ServeEntity],
resource_id: PromptTemplateIdentifier,
**kwargs,
):
session: Session = kwargs.get("session")
if session is None:
raise Exception("session is None")
query_obj = session.query(ServeEntity)
for key, value in resource_id.to_dict().items():
if value is None:
continue
query_obj = query_obj.filter(getattr(ServeEntity, key) == value)
return query_obj

View File

@@ -1,17 +1,80 @@
from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
import logging
from typing import List, Optional, Union
from .api.endpoints import router, init_endpoints
from sqlalchemy import URL
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.core import PromptManager
from ...storage.metadata import DatabaseManager
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
APP_NAME,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
from .models.prompt_template_adapter import PromptTemplateAdapter
logger = logging.getLogger(__name__)
class Serve(BaseComponent):
"""Serve component
Examples:
Register the serve component to the system app
.. code-block:: python
from fastapi import FastAPI
from dbgpt import SystemApp
from dbgpt.core import PromptTemplate
from dbgpt.serve.prompt.serve import Serve, SERVE_APP_NAME
app = FastAPI()
system_app = SystemApp(app)
system_app.register(Serve, api_prefix="/api/v1/prompt")
# Run before start hook
system_app.before_start()
prompt_serve = system_app.get_component(SERVE_APP_NAME, Serve)
# Get the prompt manager
prompt_manager = prompt_serve.prompt_manager
prompt_manager.save(
PromptTemplate(template="Hello {name}", input_variables=["name"]),
prompt_name="prompt_name",
)
With your database url
.. code-block:: python
from fastapi import FastAPI
from dbgpt import SystemApp
from dbgpt.core import PromptTemplate
from dbgpt.serve.prompt.serve import Serve, SERVE_APP_NAME
app = FastAPI()
system_app = SystemApp(app)
system_app.register(Serve, api_prefix="/api/v1/prompt", db_url_or_db="sqlite:///:memory:", try_create_tables=True)
# Run before start hook
system_app.before_start()
prompt_serve = system_app.get_component(SERVE_APP_NAME, Serve)
# Get the prompt manager
prompt_manager = prompt_serve.prompt_manager
prompt_manager.save(
PromptTemplate(template="Hello {name}", input_variables=["name"]),
prompt_name="prompt_name",
)
"""
name = SERVE_APP_NAME
def __init__(
@@ -19,12 +82,17 @@ class Serve(BaseComponent):
system_app: SystemApp,
api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}",
tags: Optional[List[str]] = None,
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):
if tags is None:
tags = [SERVE_APP_NAME_HUMP]
self._system_app = None
self._api_prefix = api_prefix
self._tags = tags
self._prompt_manager = None
self._db_url_or_db = db_url_or_db
self._try_create_tables = try_create_tables
def init_app(self, system_app: SystemApp):
self._system_app = system_app
@@ -33,10 +101,37 @@ class Serve(BaseComponent):
)
init_endpoints(self._system_app)
@property
def prompt_manager(self) -> PromptManager:
"""Get the prompt manager of the serve app with db storage"""
return self._prompt_manager
def before_start(self):
"""Called before the start of the application.
You can do some initialization here.
"""
# import your own module here to ensure the module is loaded before the application starts
from dbgpt.core.interface.prompt import PromptManager
from dbgpt.storage.metadata import Model, db
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
from .models.models import ServeEntity
init_db = self._db_url_or_db or db
init_db = DatabaseManager.build_from(init_db, base=Model)
if self._try_create_tables:
try:
init_db.create_all()
except Exception as e:
logger.warning(f"Failed to create tables: {e}")
storage_adapter = PromptTemplateAdapter()
serializer = JsonSerializer()
storage = SQLAlchemyStorage(
init_db,
ServeEntity,
storage_adapter,
serializer,
)
self._prompt_manager = PromptManager(storage)

View File

@@ -1,11 +1,13 @@
from typing import Optional, List
from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.serve.core import BaseService
from ..models.models import ServeDao, ServeEntity
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_SERVICE_COMPONENT_NAME, SERVE_CONFIG_KEY_PREFIX, ServeConfig
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):

View File

@@ -1,15 +1,15 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from fastapi import FastAPI
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import asystem_app, client
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..config import SERVE_CONFIG_KEY_PREFIX
from ..api.endpoints import router, init_endpoints
from ..api.schemas import ServeRequest, ServerResponse
from dbgpt.serve.core.tests.conftest import client, asystem_app
from ..api.endpoints import init_endpoints, router
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX
@pytest.fixture(autouse=True)

View File

@@ -1,9 +1,12 @@
from typing import List
import pytest
from dbgpt.storage.metadata import db
from ..config import ServeConfig
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity, ServeDao
from ..config import ServeConfig
from ..models.models import ServeDao, ServeEntity
@pytest.fixture(autouse=True)
@@ -34,6 +37,8 @@ def default_entity_dict():
"content": "Write a qsort function in python.",
"user_name": "zhangsan",
"sys_code": "dbgpt",
"prompt_language": "zh",
"model": "vicuna-13b-v1.5",
}
@@ -60,7 +65,14 @@ def test_entity_create(default_entity_dict):
def test_entity_unique_key(default_entity_dict):
ServeEntity.create(**default_entity_dict)
with pytest.raises(Exception):
ServeEntity.create(**{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
ServeEntity.create(
**{
"prompt_name": "my_prompt_1",
"sys_code": "dbgpt",
"prompt_language": "zh",
"model": "vicuna-13b-v1.5",
}
)
def test_entity_get(default_entity_dict):

View File

@@ -0,0 +1,144 @@
import pytest
from dbgpt.core.interface.prompt import PromptManager, PromptTemplate
from dbgpt.storage.metadata import db
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
from ..models.prompt_template_adapter import PromptTemplateAdapter, ServeEntity
@pytest.fixture
def serializer():
return JsonSerializer()
@pytest.fixture
def db_url():
"""Use in-memory SQLite database for testing"""
return "sqlite:///:memory:"
@pytest.fixture
def db_manager(db_url):
db.init_db(db_url)
db.create_all()
return db
@pytest.fixture
def storage_adapter():
return PromptTemplateAdapter()
@pytest.fixture
def storage(db_manager, serializer, storage_adapter):
storage = SQLAlchemyStorage(
db_manager,
ServeEntity,
storage_adapter,
serializer,
)
return storage
@pytest.fixture
def prompt_manager(storage):
return PromptManager(storage)
def test_save(prompt_manager: PromptManager):
prompt_template = PromptTemplate(
template="hello {input}",
input_variables=["input"],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name="hello",
)
with db.session() as session:
# Query from database
result = (
session.query(ServeEntity).filter(ServeEntity.prompt_name == "hello").all()
)
assert len(result) == 1
assert result[0].prompt_name == "hello"
assert result[0].content == "hello {input}"
assert result[0].input_variables == "input"
with db.session() as session:
assert session.query(ServeEntity).count() == 1
assert (
session.query(ServeEntity)
.filter(ServeEntity.prompt_name == "not exist prompt name")
.count()
== 0
)
def test_prefer_query_language(prompt_manager: PromptManager):
for language in ["en", "zh"]:
prompt_template = PromptTemplate(
template="test",
input_variables=[],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name="test_prompt",
prompt_language=language,
)
# Prefer zh, and zh exists, will return zh prompt template
result = prompt_manager.prefer_query("test_prompt", prefer_prompt_language="zh")
assert len(result) == 1
assert result[0].content == "test"
assert result[0].prompt_language == "zh"
# Prefer language not exists, will return all prompt templates of this name
result = prompt_manager.prefer_query(
"test_prompt", prefer_prompt_language="not_exist"
)
assert len(result) == 2
def test_prefer_query_model(prompt_manager: PromptManager):
for model in ["model1", "model2"]:
prompt_template = PromptTemplate(
template="test",
input_variables=[],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name="test_prompt",
model=model,
)
# Prefer model1, and model1 exists, will return model1 prompt template
result = prompt_manager.prefer_query("test_prompt", prefer_model="model1")
assert len(result) == 1
assert result[0].content == "test"
assert result[0].model == "model1"
# Prefer model not exists, will return all prompt templates of this name
result = prompt_manager.prefer_query("test_prompt", prefer_model="not_exist")
assert len(result) == 2
def test_list(prompt_manager: PromptManager):
for i in range(10):
prompt_template = PromptTemplate(
template="test",
input_variables=[],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name=f"test_prompt_{i}",
sys_code="dbgpt" if i % 2 == 0 else "not_dbgpt",
)
# Test list all
result = prompt_manager.list()
assert len(result) == 10
for i in range(10):
assert len(prompt_manager.list(prompt_name=f"test_prompt_{i}")) == 1
assert len(prompt_manager.list(sys_code="dbgpt")) == 5

View File

@@ -1,11 +1,13 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.serve.core.tests.conftest import system_app
from ..models.models import ServeEntity
import pytest
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import system_app
from dbgpt.storage.metadata import db
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity
from ..service.service import Service