mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 20:10:08 +00:00
feat(core): More AWEL operators and new prompt manager API (#972)
Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -1,15 +1,16 @@
|
||||
from typing import Optional, List
|
||||
from functools import cache
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import Result
|
||||
from dbgpt.util import PaginationResult
|
||||
from .schemas import ServeRequest, ServerResponse
|
||||
|
||||
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..service.service import Service
|
||||
from ..config import APP_NAME, SERVE_APP_NAME, ServeConfig, SERVE_SERVICE_COMPONENT_NAME
|
||||
from .schemas import ServeRequest, ServerResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@@ -1,6 +1,8 @@
|
||||
# Define your Pydantic schemas here
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
|
||||
from ..config import SERVE_APP_NAME_HUMP
|
||||
|
||||
|
||||
|
@@ -1,9 +1,8 @@
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.serve.core import BaseServeConfig
|
||||
|
||||
|
||||
APP_NAME = "prompt"
|
||||
SERVE_APP_NAME = "dbgpt_serve_prompt"
|
||||
SERVE_APP_NAME_HUMP = "dbgpt_serve_Prompt"
|
||||
|
@@ -1,33 +1,64 @@
|
||||
"""This is an auto-generated model file
|
||||
You can define your own models and DAOs here
|
||||
"""
|
||||
from typing import Union, Any, Dict
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Index, Text, DateTime, UniqueConstraint
|
||||
from dbgpt.storage.metadata import Model, BaseDao, db
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model, db
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import ServeConfig, SERVER_APP_TABLE_NAME
|
||||
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
|
||||
|
||||
|
||||
class ServeEntity(Model):
|
||||
__tablename__ = "prompt_manage"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("prompt_name", "sys_code", name="uk_prompt_name_sys_code"),
|
||||
UniqueConstraint(
|
||||
"prompt_name",
|
||||
"sys_code",
|
||||
"prompt_language",
|
||||
"model",
|
||||
name="uk_prompt_name_sys_code",
|
||||
),
|
||||
)
|
||||
id = Column(Integer, primary_key=True, comment="Auto increment id")
|
||||
|
||||
chat_scene = Column(String(100))
|
||||
sub_chat_scene = Column(String(100))
|
||||
prompt_type = Column(String(100))
|
||||
prompt_name = Column(String(512))
|
||||
content = Column(Text)
|
||||
user_name = Column(String(128))
|
||||
chat_scene = Column(String(100), comment="Chat scene")
|
||||
sub_chat_scene = Column(String(100), comment="Sub chat scene")
|
||||
prompt_type = Column(String(100), comment="Prompt type(eg: common, private)")
|
||||
prompt_name = Column(String(256), comment="Prompt name")
|
||||
content = Column(Text, comment="Prompt content")
|
||||
input_variables = Column(
|
||||
String(1024), nullable=True, comment="Prompt input variables(split by comma))"
|
||||
)
|
||||
model = Column(
|
||||
String(128),
|
||||
nullable=True,
|
||||
comment="Prompt model name(we can use different models for different prompt",
|
||||
)
|
||||
prompt_language = Column(
|
||||
String(32), index=True, nullable=True, comment="Prompt language(eg:en, zh-cn)"
|
||||
)
|
||||
prompt_format = Column(
|
||||
String(32),
|
||||
index=True,
|
||||
nullable=True,
|
||||
default="f-string",
|
||||
comment="Prompt format(eg: f-string, jinja2)",
|
||||
)
|
||||
user_name = Column(String(128), index=True, nullable=True, comment="User name")
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
|
||||
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
|
||||
|
||||
def __repr__(self):
|
||||
return f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
return (
|
||||
f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', "
|
||||
f"prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',"
|
||||
f"user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
)
|
||||
|
||||
|
||||
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
|
56
dbgpt/serve/prompt/models/prompt_template_adapter.py
Normal file
56
dbgpt/serve/prompt/models/prompt_template_adapter.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from typing import Type
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dbgpt.core.interface.prompt import PromptTemplateIdentifier, StoragePromptTemplate
|
||||
from dbgpt.core.interface.storage import StorageItemAdapter
|
||||
|
||||
from .models import ServeEntity
|
||||
|
||||
|
||||
class PromptTemplateAdapter(StorageItemAdapter[StoragePromptTemplate, ServeEntity]):
|
||||
def to_storage_format(self, item: StoragePromptTemplate) -> ServeEntity:
|
||||
return ServeEntity(
|
||||
chat_scene=item.chat_scene,
|
||||
sub_chat_scene=item.sub_chat_scene,
|
||||
prompt_type=item.prompt_type,
|
||||
prompt_name=item.prompt_name,
|
||||
content=item.content,
|
||||
input_variables=item.input_variables,
|
||||
model=item.model,
|
||||
prompt_language=item.prompt_language,
|
||||
prompt_format=item.prompt_format,
|
||||
user_name=item.user_name,
|
||||
sys_code=item.sys_code,
|
||||
)
|
||||
|
||||
def from_storage_format(self, model: ServeEntity) -> StoragePromptTemplate:
|
||||
return StoragePromptTemplate(
|
||||
chat_scene=model.chat_scene,
|
||||
sub_chat_scene=model.sub_chat_scene,
|
||||
prompt_type=model.prompt_type,
|
||||
prompt_name=model.prompt_name,
|
||||
content=model.content,
|
||||
input_variables=model.input_variables,
|
||||
model=model.model,
|
||||
prompt_language=model.prompt_language,
|
||||
prompt_format=model.prompt_format,
|
||||
user_name=model.user_name,
|
||||
sys_code=model.sys_code,
|
||||
)
|
||||
|
||||
def get_query_for_identifier(
|
||||
self,
|
||||
storage_format: Type[ServeEntity],
|
||||
resource_id: PromptTemplateIdentifier,
|
||||
**kwargs,
|
||||
):
|
||||
session: Session = kwargs.get("session")
|
||||
if session is None:
|
||||
raise Exception("session is None")
|
||||
query_obj = session.query(ServeEntity)
|
||||
for key, value in resource_id.to_dict().items():
|
||||
if value is None:
|
||||
continue
|
||||
query_obj = query_obj.filter(getattr(ServeEntity, key) == value)
|
||||
return query_obj
|
@@ -1,17 +1,80 @@
|
||||
from typing import List, Optional
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from .api.endpoints import router, init_endpoints
|
||||
from sqlalchemy import URL
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.core import PromptManager
|
||||
|
||||
from ...storage.metadata import DatabaseManager
|
||||
from .api.endpoints import init_endpoints, router
|
||||
from .config import (
|
||||
APP_NAME,
|
||||
SERVE_APP_NAME,
|
||||
SERVE_APP_NAME_HUMP,
|
||||
APP_NAME,
|
||||
SERVE_CONFIG_KEY_PREFIX,
|
||||
ServeConfig,
|
||||
)
|
||||
from .models.prompt_template_adapter import PromptTemplateAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Serve(BaseComponent):
|
||||
"""Serve component
|
||||
|
||||
Examples:
|
||||
|
||||
Register the serve component to the system app
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from fastapi import FastAPI
|
||||
from dbgpt import SystemApp
|
||||
from dbgpt.core import PromptTemplate
|
||||
from dbgpt.serve.prompt.serve import Serve, SERVE_APP_NAME
|
||||
|
||||
app = FastAPI()
|
||||
system_app = SystemApp(app)
|
||||
system_app.register(Serve, api_prefix="/api/v1/prompt")
|
||||
# Run before start hook
|
||||
system_app.before_start()
|
||||
|
||||
prompt_serve = system_app.get_component(SERVE_APP_NAME, Serve)
|
||||
|
||||
# Get the prompt manager
|
||||
prompt_manager = prompt_serve.prompt_manager
|
||||
prompt_manager.save(
|
||||
PromptTemplate(template="Hello {name}", input_variables=["name"]),
|
||||
prompt_name="prompt_name",
|
||||
)
|
||||
|
||||
With your database url
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from fastapi import FastAPI
|
||||
from dbgpt import SystemApp
|
||||
from dbgpt.core import PromptTemplate
|
||||
from dbgpt.serve.prompt.serve import Serve, SERVE_APP_NAME
|
||||
|
||||
app = FastAPI()
|
||||
system_app = SystemApp(app)
|
||||
system_app.register(Serve, api_prefix="/api/v1/prompt", db_url_or_db="sqlite:///:memory:", try_create_tables=True)
|
||||
# Run before start hook
|
||||
system_app.before_start()
|
||||
|
||||
prompt_serve = system_app.get_component(SERVE_APP_NAME, Serve)
|
||||
|
||||
# Get the prompt manager
|
||||
prompt_manager = prompt_serve.prompt_manager
|
||||
prompt_manager.save(
|
||||
PromptTemplate(template="Hello {name}", input_variables=["name"]),
|
||||
prompt_name="prompt_name",
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
name = SERVE_APP_NAME
|
||||
|
||||
def __init__(
|
||||
@@ -19,12 +82,17 @@ class Serve(BaseComponent):
|
||||
system_app: SystemApp,
|
||||
api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}",
|
||||
tags: Optional[List[str]] = None,
|
||||
db_url_or_db: Union[str, URL, DatabaseManager] = None,
|
||||
try_create_tables: Optional[bool] = False,
|
||||
):
|
||||
if tags is None:
|
||||
tags = [SERVE_APP_NAME_HUMP]
|
||||
self._system_app = None
|
||||
self._api_prefix = api_prefix
|
||||
self._tags = tags
|
||||
self._prompt_manager = None
|
||||
self._db_url_or_db = db_url_or_db
|
||||
self._try_create_tables = try_create_tables
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
self._system_app = system_app
|
||||
@@ -33,10 +101,37 @@ class Serve(BaseComponent):
|
||||
)
|
||||
init_endpoints(self._system_app)
|
||||
|
||||
@property
|
||||
def prompt_manager(self) -> PromptManager:
|
||||
"""Get the prompt manager of the serve app with db storage"""
|
||||
return self._prompt_manager
|
||||
|
||||
def before_start(self):
|
||||
"""Called before the start of the application.
|
||||
|
||||
You can do some initialization here.
|
||||
"""
|
||||
# import your own module here to ensure the module is loaded before the application starts
|
||||
from dbgpt.core.interface.prompt import PromptManager
|
||||
from dbgpt.storage.metadata import Model, db
|
||||
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
from .models.models import ServeEntity
|
||||
|
||||
init_db = self._db_url_or_db or db
|
||||
init_db = DatabaseManager.build_from(init_db, base=Model)
|
||||
if self._try_create_tables:
|
||||
try:
|
||||
init_db.create_all()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create tables: {e}")
|
||||
storage_adapter = PromptTemplateAdapter()
|
||||
serializer = JsonSerializer()
|
||||
storage = SQLAlchemyStorage(
|
||||
init_db,
|
||||
ServeEntity,
|
||||
storage_adapter,
|
||||
serializer,
|
||||
)
|
||||
self._prompt_manager = PromptManager(storage)
|
||||
|
@@ -1,11 +1,13 @@
|
||||
from typing import Optional, List
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.serve.core import BaseService
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
from dbgpt.serve.core import BaseService
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVE_SERVICE_COMPONENT_NAME, SERVE_CONFIG_KEY_PREFIX, ServeConfig
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
|
||||
|
||||
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
|
@@ -1,15 +1,15 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import AsyncClient
|
||||
|
||||
from fastapi import FastAPI
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core.tests.conftest import asystem_app, client
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.util import PaginationResult
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX
|
||||
from ..api.endpoints import router, init_endpoints
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
|
||||
from dbgpt.serve.core.tests.conftest import client, asystem_app
|
||||
from ..api.endpoints import init_endpoints, router
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
@@ -1,9 +1,12 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.storage.metadata import db
|
||||
from ..config import ServeConfig
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..models.models import ServeEntity, ServeDao
|
||||
from ..config import ServeConfig
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -34,6 +37,8 @@ def default_entity_dict():
|
||||
"content": "Write a qsort function in python.",
|
||||
"user_name": "zhangsan",
|
||||
"sys_code": "dbgpt",
|
||||
"prompt_language": "zh",
|
||||
"model": "vicuna-13b-v1.5",
|
||||
}
|
||||
|
||||
|
||||
@@ -60,7 +65,14 @@ def test_entity_create(default_entity_dict):
|
||||
def test_entity_unique_key(default_entity_dict):
|
||||
ServeEntity.create(**default_entity_dict)
|
||||
with pytest.raises(Exception):
|
||||
ServeEntity.create(**{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
|
||||
ServeEntity.create(
|
||||
**{
|
||||
"prompt_name": "my_prompt_1",
|
||||
"sys_code": "dbgpt",
|
||||
"prompt_language": "zh",
|
||||
"model": "vicuna-13b-v1.5",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_entity_get(default_entity_dict):
|
||||
|
144
dbgpt/serve/prompt/tests/test_prompt_template_adapter.py
Normal file
144
dbgpt/serve/prompt/tests/test_prompt_template_adapter.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.interface.prompt import PromptManager, PromptTemplate
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
from ..models.prompt_template_adapter import PromptTemplateAdapter, ServeEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def serializer():
|
||||
return JsonSerializer()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_url():
|
||||
"""Use in-memory SQLite database for testing"""
|
||||
return "sqlite:///:memory:"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_manager(db_url):
|
||||
db.init_db(db_url)
|
||||
db.create_all()
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage_adapter():
|
||||
return PromptTemplateAdapter()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage(db_manager, serializer, storage_adapter):
|
||||
storage = SQLAlchemyStorage(
|
||||
db_manager,
|
||||
ServeEntity,
|
||||
storage_adapter,
|
||||
serializer,
|
||||
)
|
||||
return storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_manager(storage):
|
||||
return PromptManager(storage)
|
||||
|
||||
|
||||
def test_save(prompt_manager: PromptManager):
|
||||
prompt_template = PromptTemplate(
|
||||
template="hello {input}",
|
||||
input_variables=["input"],
|
||||
template_scene="chat_normal",
|
||||
)
|
||||
prompt_manager.save(
|
||||
prompt_template,
|
||||
prompt_name="hello",
|
||||
)
|
||||
|
||||
with db.session() as session:
|
||||
# Query from database
|
||||
result = (
|
||||
session.query(ServeEntity).filter(ServeEntity.prompt_name == "hello").all()
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0].prompt_name == "hello"
|
||||
assert result[0].content == "hello {input}"
|
||||
assert result[0].input_variables == "input"
|
||||
with db.session() as session:
|
||||
assert session.query(ServeEntity).count() == 1
|
||||
assert (
|
||||
session.query(ServeEntity)
|
||||
.filter(ServeEntity.prompt_name == "not exist prompt name")
|
||||
.count()
|
||||
== 0
|
||||
)
|
||||
|
||||
|
||||
def test_prefer_query_language(prompt_manager: PromptManager):
|
||||
for language in ["en", "zh"]:
|
||||
prompt_template = PromptTemplate(
|
||||
template="test",
|
||||
input_variables=[],
|
||||
template_scene="chat_normal",
|
||||
)
|
||||
prompt_manager.save(
|
||||
prompt_template,
|
||||
prompt_name="test_prompt",
|
||||
prompt_language=language,
|
||||
)
|
||||
# Prefer zh, and zh exists, will return zh prompt template
|
||||
result = prompt_manager.prefer_query("test_prompt", prefer_prompt_language="zh")
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "test"
|
||||
assert result[0].prompt_language == "zh"
|
||||
# Prefer language not exists, will return all prompt templates of this name
|
||||
result = prompt_manager.prefer_query(
|
||||
"test_prompt", prefer_prompt_language="not_exist"
|
||||
)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
def test_prefer_query_model(prompt_manager: PromptManager):
|
||||
for model in ["model1", "model2"]:
|
||||
prompt_template = PromptTemplate(
|
||||
template="test",
|
||||
input_variables=[],
|
||||
template_scene="chat_normal",
|
||||
)
|
||||
prompt_manager.save(
|
||||
prompt_template,
|
||||
prompt_name="test_prompt",
|
||||
model=model,
|
||||
)
|
||||
# Prefer model1, and model1 exists, will return model1 prompt template
|
||||
result = prompt_manager.prefer_query("test_prompt", prefer_model="model1")
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "test"
|
||||
assert result[0].model == "model1"
|
||||
# Prefer model not exists, will return all prompt templates of this name
|
||||
result = prompt_manager.prefer_query("test_prompt", prefer_model="not_exist")
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
def test_list(prompt_manager: PromptManager):
|
||||
for i in range(10):
|
||||
prompt_template = PromptTemplate(
|
||||
template="test",
|
||||
input_variables=[],
|
||||
template_scene="chat_normal",
|
||||
)
|
||||
prompt_manager.save(
|
||||
prompt_template,
|
||||
prompt_name=f"test_prompt_{i}",
|
||||
sys_code="dbgpt" if i % 2 == 0 else "not_dbgpt",
|
||||
)
|
||||
# Test list all
|
||||
result = prompt_manager.list()
|
||||
assert len(result) == 10
|
||||
|
||||
for i in range(10):
|
||||
assert len(prompt_manager.list(prompt_name=f"test_prompt_{i}")) == 1
|
||||
assert len(prompt_manager.list(sys_code="dbgpt")) == 5
|
@@ -1,11 +1,13 @@
|
||||
from typing import List
|
||||
import pytest
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.serve.core.tests.conftest import system_app
|
||||
|
||||
from ..models.models import ServeEntity
|
||||
import pytest
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core.tests.conftest import system_app
|
||||
from dbgpt.storage.metadata import db
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..models.models import ServeEntity
|
||||
from ..service.service import Service
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user