feat: (0.6)New UI (#1855)

Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com>
Co-authored-by: csunny <cfqsunny@163.com>
明天
2024-08-21 17:37:45 +08:00
committed by GitHub
parent 3fc82693ba
commit b124ecc10b
824 changed files with 93371 additions and 2515 deletions

View File

@@ -1,8 +1,10 @@
import logging
from functools import cache
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from starlette.responses import StreamingResponse
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
@@ -10,7 +12,15 @@ from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import ServeRequest, ServerResponse
from .schemas import (
PromptDebugInput,
PromptType,
PromptVerifyInput,
ServeRequest,
ServerResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -66,25 +76,6 @@ async def check_api_key(
if request.url.path.startswith("/api/v1"):
return None
# if service.config.api_keys:
# api_keys = _parse_api_keys(service.config.api_keys)
# if auth is None or (token := auth.credentials) not in api_keys:
# raise HTTPException(
# status_code=401,
# detail={
# "error": {
# "message": "",
# "type": "invalid_request_error",
# "param": None,
# "code": "invalid_api_key",
# }
# },
# )
# return token
# else:
# # api_keys not set; allow all
# return None
@router.get("/health")
async def health():
@@ -132,7 +123,12 @@ async def update(
Returns:
ServerResponse: The response
"""
return Result.succ(service.update(request))
try:
data = service.update(request)
return Result.succ(data)
except Exception as e:
logger.exception("Update prompt failed!")
return Result.failed(msg=str(e))
@router.post(
@@ -195,6 +191,113 @@ async def query_page(
return Result.succ(service.get_list_by_page(request, page, page_size))
@router.get(
"/type/targets",
response_model=Result,
dependencies=[Depends(check_api_key)],
)
async def prompt_type_targets(
prompt_type: str = Query(
default=PromptType.NORMAL, description="Prompt template type"
),
service: Service = Depends(get_service),
) -> Result:
"""get Prompt type
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
return Result.succ(service.get_type_targets(prompt_type))
@router.post(
"/template/load",
response_model=Result,
dependencies=[Depends(check_api_key)],
)
async def load_template(
prompt_type: str = Query(
default=PromptType.NORMAL, description="Prompt template type"
),
target: Optional[str] = Query(
default=None, description="The target to load the template from"
),
service: Service = Depends(get_service),
) -> Result:
"""load Prompt from target
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
return Result.succ(service.load_template(prompt_type, target))
@router.post(
"/template/debug",
dependencies=[Depends(check_api_key)],
)
async def template_debug(
debug_input: PromptDebugInput,
service: Service = Depends(get_service),
):
"""test Prompt
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
try:
return StreamingResponse(
service.debug_prompt(
debug_input=debug_input,
),
headers=headers,
media_type="text/event-stream",
)
except Exception as e:
return Result.failed(msg=str(e))
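
Because template_debug returns a StreamingResponse with text/event-stream headers, a client must read the body incrementally. Below is a minimal consumer sketch; the mount prefix /api/v2/serve/prompt and port 5670 are assumptions for illustration, not confirmed by this diff:

import asyncio
import httpx

async def consume_debug_stream(payload: dict) -> None:
    # URL prefix and port are illustrative assumptions.
    url = "http://localhost:5670/api/v2/serve/prompt/template/debug"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=payload) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data:"):
                    continue
                chunk = line[len("data:"):]
                if chunk == "[DONE]":  # sentinel emitted at the end of debug_prompt
                    break
                print(chunk, flush=True)

asyncio.run(consume_debug_stream({
    "prompt_type": "Normal",
    "content": "You are a helpful assistant.",
    "user_input": "hello",
    "debug_model": "vicuna13b",  # illustrative model name
}))
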
@router.post(
"/response/verify",
response_model=Result[bool],
dependencies=[Depends(check_api_key)],
)
async def response_verify(
request: PromptVerifyInput,
service: Service = Depends(get_service),
) -> Result[bool]:
"""test Prompt
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
try:
return Result.succ(
service.verify_response(
request.llm_out, request.prompt_type, request.chat_scene
)
)
except Exception as e:
return Result.failed(msg=str(e))
def init_endpoints(system_app: SystemApp) -> None:
"""Initialize the endpoints"""
global global_system_app
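
Taken together, the new routes form a prompt-authoring loop: list the targets for a prompt type, load a template, debug it against a model, then verify the model output against the template's schema. A hedged sketch of the non-streaming calls follows; the mount prefix and the Result envelope shape are assumptions for illustration:

import httpx

BASE = "http://localhost:5670/api/v2/serve/prompt"  # mount prefix is an assumption

def list_targets(prompt_type: str = "Agent") -> list:
    resp = httpx.get(f"{BASE}/type/targets", params={"prompt_type": prompt_type})
    resp.raise_for_status()
    return resp.json().get("data", [])  # Result envelope shape assumed

def load_template(prompt_type: str, target: str) -> dict:
    # prompt_type and target are query parameters on this POST route.
    resp = httpx.post(
        f"{BASE}/template/load",
        params={"prompt_type": prompt_type, "target": target},
    )
    resp.raise_for_status()
    return resp.json().get("data", {})

def verify(llm_out: str, prompt_type: str, chat_scene: str) -> bool:
    body = {"llm_out": llm_out, "prompt_type": prompt_type, "chat_scene": chat_scene}
    resp = httpx.post(f"{BASE}/response/verify", json=body)
    resp.raise_for_status()
    return bool(resp.json().get("data"))
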

View File

@@ -1,4 +1,5 @@
# Define your Pydantic schemas here
from enum import Enum
from typing import Optional
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
@@ -22,7 +23,11 @@ class ServeRequest(BaseModel):
description="The sub chat scene.",
examples=["sub_scene_1", "sub_scene_2", "sub_scene_3"],
)
prompt_code: Optional[str] = Field(
None,
description="The prompt code.",
examples=["test123", "test456"],
)
prompt_type: Optional[str] = Field(
None,
description="The prompt type, either common or private.",
@@ -51,7 +56,39 @@ class ServeRequest(BaseModel):
"This is a prompt for data analysis expert.",
],
)
response_schema: Optional[str] = Field(
None,
description="The prompt response schema.",
examples=[
"None",
'{"xx": "123"}',
],
)
input_variables: Optional[str] = Field(
None,
description="The prompt variables.",
examples=[
"display_type",
"resources",
],
)
model: Optional[str] = Field(
None,
description="The prompt can use model.",
examples=["vicuna13b", "chatgpt"],
)
prompt_language: Optional[str] = Field(
None,
description="The prompt language.",
examples=["en", "zh"],
)
user_code: Optional[str] = Field(
None,
description="The user id.",
examples=[""],
)
user_name: Optional[str] = Field(
None,
description="The user name.",
@@ -75,6 +112,11 @@ class ServerResponse(ServeRequest):
description="The prompt id.",
examples=[1, 2, 3],
)
prompt_code: Optional[str] = Field(
None,
description="The prompt code.",
examples=["xxxx1", "xxxx2", "xxxx3"],
)
gmt_created: Optional[str] = Field(
None,
description="The prompt created time.",
@@ -85,3 +127,37 @@ class ServerResponse(ServeRequest):
description="The prompt modified time.",
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
)
class PromptVerifyInput(ServeRequest):
llm_out: Optional[str] = Field(
None,
description="The llm out of prompt.",
)
class PromptDebugInput(ServeRequest):
input_values: Optional[dict] = Field(
None,
description="The prompt variables debug value.",
)
temperature: Optional[float] = Field(
default=0.5,
description="The prompt debug temperature.",
)
debug_model: Optional[str] = Field(
None,
description="The prompt debug model.",
examples=["vicuna13b", "chatgpt"],
)
user_input: Optional[str] = Field(
None,
description="The prompt debug user input.",
)
class PromptType(Enum):
AGENT = "Agent"
SCENE = "Scene"
NORMAL = "Normal"
EVALUATE = "Evaluate"

View File

@@ -1,12 +1,12 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
import uuid
from datetime import datetime
from typing import Any, Dict, Union
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from dbgpt._private.pydantic import model_to_dict
from dbgpt.storage.metadata import BaseDao, Model, db
from ..api.schemas import ServeRequest, ServerResponse
@@ -28,12 +28,14 @@ class ServeEntity(Model):
chat_scene = Column(String(100), comment="Chat scene")
sub_chat_scene = Column(String(100), comment="Sub chat scene")
prompt_code = Column(String(256), comment="Prompt Code")
prompt_type = Column(String(100), comment="Prompt type(eg: common, private)")
prompt_name = Column(String(256), comment="Prompt name")
content = Column(Text, comment="Prompt content")
input_variables = Column(
String(1024), nullable=True, comment="Prompt input variables (split by comma)"
)
response_schema = Column(Text, comment="Prompt response schema")
model = Column(
String(128),
nullable=True,
@@ -50,6 +52,7 @@ class ServeEntity(Model):
comment="Prompt format(eg: f-string, jinja2)",
)
prompt_desc = Column(String(512), nullable=True, comment="Prompt description")
user_code = Column(String(128), index=True, nullable=True, comment="User code")
user_name = Column(String(128), index=True, nullable=True, comment="User name")
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
@@ -59,7 +62,7 @@ class ServeEntity(Model):
return (
f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', "
f"prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',"
f"user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
f"user_code='{self.user_code}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
)
@@ -79,10 +82,10 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
Returns:
T: The entity
"""
request_dict = (
model_to_dict(request) if isinstance(request, ServeRequest) else request
)
request_dict = request.dict() if isinstance(request, ServeRequest) else request
entity = ServeEntity(**request_dict)
if not entity.prompt_code:
entity.prompt_code = uuid.uuid4().hex
return entity
def to_request(self, entity: ServeEntity) -> ServeRequest:
@@ -99,8 +102,10 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
sub_chat_scene=entity.sub_chat_scene,
prompt_type=entity.prompt_type,
prompt_name=entity.prompt_name,
prompt_code=entity.prompt_code,
content=entity.content,
prompt_desc=entity.prompt_desc,
user_code=entity.user_code,
user_name=entity.user_name,
sys_code=entity.sys_code,
)
@@ -121,11 +126,16 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
id=entity.id,
chat_scene=entity.chat_scene,
sub_chat_scene=entity.sub_chat_scene,
prompt_code=entity.prompt_code,
prompt_type=entity.prompt_type,
prompt_name=entity.prompt_name,
content=entity.content,
prompt_desc=entity.prompt_desc,
user_name=entity.user_name,
user_code=entity.user_code,
model=entity.model,
input_variables=entity.input_variables,
prompt_language=entity.prompt_language,
sys_code=entity.sys_code,
gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str,
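
Two details of the DAO changes are worth noting: from_request now falls back to a generated prompt_code when the request omits one, and the request-to-dict conversion goes through model_to_dict so it works across Pydantic versions (.dict() is deprecated in Pydantic v2 in favor of model_dump()). A small sketch of an equivalent compatibility shim, illustrative rather than the dbgpt helper itself:

import uuid
from typing import Optional

from pydantic import BaseModel

def to_dict_compat(model: BaseModel) -> dict:
    # Prefer the Pydantic v2 API; fall back to v1's .dict().
    if hasattr(model, "model_dump"):
        return model.model_dump()
    return model.dict()

def default_prompt_code(existing: Optional[str]) -> str:
    # Mirrors from_request above: a 32-char hex uuid when no code is supplied.
    return existing or uuid.uuid4().hex
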

View File

@@ -1,14 +1,29 @@
from typing import List, Optional
import json
import logging
from typing import Dict, List, Optional, Type
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.agent import ConversableAgent, get_agent_manager
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelMetadata, ModelRequest, PromptTemplate
from dbgpt.core.interface.prompt import (
SystemPromptTemplate,
_get_string_template_vars,
get_template_vars,
)
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.util.json_utils import compare_json_properties_ex, find_json_objects
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.tracer import root_tracer
from ..api.schemas import ServeRequest, ServerResponse
from ..api.schemas import PromptDebugInput, PromptType, ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
logger = logging.getLogger(__name__)
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""The service class for Prompt"""
@@ -27,8 +42,6 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
Args:
system_app (SystemApp): The system app
"""
super().init_app(system_app)
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
@@ -72,7 +85,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""
# Build the query request from the request
query_request = {
"prompt_name": request.prompt_name,
"prompt_code": request.prompt_code,
"sys_code": request.sys_code,
}
return self.dao.update(query_request, update_request=request)
@@ -86,7 +99,6 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
Returns:
ServerResponse: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = request
return self.dao.get_one(query_request)
@@ -134,4 +146,267 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
List[ServerResponse]: The response
"""
query_request = request
return self.dao.get_list_page(query_request, page, page_size)
return self.dao.get_list_page(
query_request, page, page_size, ServeEntity.id.name
)
def get_prompt_template(self, prompt_type: str, target: Optional[str] = None):
request = ServeRequest()
request.prompt_type = prompt_type
request.chat_scene = target
return self.get_list(request)
def get_target_prompt(
self, target: Optional[str] = None, language: Optional[str] = None
):
logger.info(f"get_target_prompt:{target}")
request = ServeRequest()
if target:
request.chat_scene = target
if language:
request.prompt_language = language
return self.get_list(request)
def get_type_targets(self, prompt_type: str):
type = PromptType(prompt_type)
if type == PromptType.AGENT:
agent_manage = get_agent_manager()
return agent_manage.list_agents()
elif type == PromptType.SCENE:
from dbgpt.app.scene import ChatScene
return [
{"name": item.value(), "desc": item.describe()} for item in ChatScene
]
elif type == PromptType.EVALUATE:
from dbgpt.rag.evaluation.answer import LLMEvaluationMetric
return [
{"name": item.name, "desc": item.prompt_template}
for item in LLMEvaluationMetric.__subclasses__()
]
else:
return None
def get_template(self, prompt_code: str) -> Optional[PromptTemplate]:
if not prompt_code:
return None
query_request = ServeRequest(prompt_code=prompt_code)
template = self.get(query_request)
if not template:
return None
return PromptTemplate(
template=template.content,
template_scene=template.chat_scene,
input_variables=get_template_vars(template.content),
response_format=template.response_schema,
)
def load_template(
self,
prompt_type: str,
target: Optional[str] = None,
language: Optional[str] = "en",
):
logger.info(f"load_template:{prompt_type},{target}")
type = PromptType(prompt_type)
if type == PromptType.AGENT:
if not target:
raise ValueError("请选择一个Agent用来加载模版")
agent_manage = get_agent_manager()
target_agent_cls: Type[ConversableAgent] = agent_manage.get_by_name(target)
target_agent = target_agent_cls()
base_template = target_agent.prompt_template()
return PromptTemplate(
template=base_template,
input_variables=get_template_vars(base_template),
response_format=target_agent.actions[0].ai_out_schema_json,
)
elif type == PromptType.SCENE:
if not target:
raise ValueError("请选择一个场景用来加载模版")
from dbgpt._private.config import Config
cfg = Config()
from dbgpt.app.scene import AppScenePromptTemplateAdapter
try:
app_prompt: AppScenePromptTemplateAdapter = (
cfg.prompt_template_registry.get_prompt_template(
target, cfg.LANGUAGE, None
)
)
for item in app_prompt.prompt.messages:
if isinstance(item, SystemPromptTemplate):
return item.prompt
raise ValueError(f"当前场景没有找到可用的Prompt模版{target}")
except Exception as e:
raise ValueError(f"当前场景没有找到可用的Prompt模版{target}")
elif type == PromptType.EVALUATE:
if not target:
raise ValueError("请选择一个场景用来加载模版")
try:
from dbgpt.rag.evaluation.answer import (
AnswerRelevancyMetric,
LLMEvaluationMetric,
)
prompts = [
item.prompt_template
for item in LLMEvaluationMetric.__subclasses__()
if target == item.name
]
if len(prompts) == 0:
raise ValueError(f"当前场景没有找到可用的Prompt模版{target}")
prompt = prompts[0]
return PromptTemplate(
template=prompt, input_variables=get_template_vars(prompt)
)
except Exception as e:
raise ValueError(f"当前场景没有找到可用的Prompt模版{target}")
else:
return None
async def debug_prompt(self, debug_input: PromptDebugInput):
logger.info(f"debug_prompt:{debug_input}")
if not debug_input.user_input:
raise ValueError("请输入你的提问!")
try:
worker_manager = self._system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
llm_client = DefaultLLMClient(worker_manager, auto_convert_message=True)
except Exception as e:
raise ValueError("LLM prepare failed!", e)
try:
debug_messages = []
from dbgpt.core import ModelMessageRoleType
prompt = debug_input.content
prompt_vars = debug_input.input_values or {}  # guard: .update below would fail on None
type = PromptType(debug_input.prompt_type)
if type == PromptType.AGENT:
if debug_input.response_schema:
prompt_vars.update(
{
"out_schema": f"请确保按照以下json格式回复:\n{debug_input.response_schema}\n确保响应是正确的json,并且可以被Python json.loads 解析。"
}
)
elif type == PromptType.SCENE:
if debug_input.response_schema:
prompt_vars.update({"response": debug_input.response_schema})
if debug_input.input_values:
prompt = prompt.format(**prompt_vars)
debug_messages.append(
{"role": ModelMessageRoleType.SYSTEM, "content": prompt}
)
debug_messages.append(
{"role": ModelMessageRoleType.HUMAN, "content": debug_input.user_input}
)
metadata: ModelMetadata = await llm_client.get_model_metadata(
debug_input.debug_model
)
payload = {
"model": debug_input.debug_model,
"messages": debug_messages,
"temperature": debug_input.temperature,
"max_new_tokens": metadata.context_length,
"echo": metadata.ext_metadata.prompt_sep,
"stop": None,
"stop_token_ids": None,
"context_len": None,
"span_id": None,
"context": None,
}
logger.info(f"Request: \n{payload}")
span = root_tracer.start_span(
"Agent.llm_client.no_streaming_call",
metadata=self._get_span_metadata(payload),
)
payload["span_id"] = span.span_id
# if params.get("context") is not None:
# payload["context"] = ModelRequestContext(extra=params["context"])
except Exception as e:
raise ValueError("参数准备失败!" + str(e))
try:
model_request = ModelRequest(**payload)
async for output in llm_client.generate_stream(model_request.copy()): # type: ignore
yield f"data:{output.text}\n\n"
yield f"data:[DONE]\n\n"
except Exception as e:
logger.error(f"Call LLMClient error, {str(e)}, detail: {payload}")
raise ValueError(e)
finally:
span.end()
def _get_span_metadata(self, payload: Dict) -> Dict:
metadata = {k: v for k, v in payload.items()}
metadata["messages"] = list(
map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
)
return metadata
def verify_response(
self, llm_out: str, prompt_type: str, target: Optional[str] = None
):
logger.info(f"verify_response:{llm_out},{prompt_type},{target}")
type = PromptType(prompt_type)
ai_json = find_json_objects(llm_out)
if type == PromptType.AGENT:
try:
if not target:
raise ValueError("请选择一个Agent用来加载模版")
from dbgpt.agent.core import agent_manage
target_agent_cls: Type[ConversableAgent] = agent_manage.get_by_name(
target
)
target_agent = target_agent_cls()
return compare_json_properties_ex(
ai_json, json.loads(target_agent.actions[0].ai_out_schema_json)
)
except Exception as e:
raise ValueError(f"模型返回不符合[{target}]输出定义请调整prompt")
elif type == PromptType.SCENE:
if not target:
raise ValueError("请选择一个场景用来加载模版")
from dbgpt._private.config import Config
cfg = Config()
from dbgpt.app.scene import AppScenePromptTemplateAdapter
try:
app_prompt: AppScenePromptTemplateAdapter = (
cfg.prompt_template_registry.get_prompt_template(
target, cfg.LANGUAGE, None
)
)
sys_prompt = None
for item in app_prompt.prompt.messages:
if isinstance(item, SystemPromptTemplate):
sys_prompt = item.prompt
if sys_prompt:
return compare_json_properties_ex(
ai_json, json.loads(sys_prompt.response_format)
)
except Exception as e:
raise ValueError(f"当前场景没有找到可用的Prompt模版{target}")
else:
return True
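
verify_response extracts JSON from the raw model output (find_json_objects) and compares its keys against the target's declared schema (compare_json_properties_ex). A simplified, illustrative re-implementation of that check follows; it is not the dbgpt helpers themselves, whose behavior may differ:

import json
import re

def extract_first_json(text: str):
    # Naive extraction of the first {...} block from free-form LLM output.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    return json.loads(match.group(0)) if match else None

def keys_match(ai_json, schema_json: dict) -> bool:
    # Treat the output as valid when it carries every property the schema declares.
    return isinstance(ai_json, dict) and set(schema_json) <= set(ai_json)

out = extract_first_json('Thought: done. {"xx": "123"}')
print(keys_match(out, json.loads('{"xx": "123"}')))  # True
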

View File

@@ -8,7 +8,7 @@ from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..api.endpoints import init_endpoints, router
from ..api.schemas import ServerResponse
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX
@@ -42,55 +42,6 @@ async def _create_and_validate(
assert res_obj.content == content
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, asystem_app, has_auth",
[
(
{
"app_caller": client_init_caller,
"client_api_key": "test_token1",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
True,
),
(
{
"app_caller": client_init_caller,
"client_api_key": "error_token",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
False,
),
],
indirect=["client", "asystem_app"],
)
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
response = await client.get("/test_auth")
if has_auth:
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_auth(client: AsyncClient):
response = await client.get("/health")
response.raise_for_status()
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]

View File

@@ -139,18 +139,3 @@ def test_service_get_list(service: Service):
for i, entity in enumerate(entities):
assert entity.sys_code == "dbgpt"
assert entity.prompt_name == f"prompt_{i}"
def test_service_get_list_by_page(service: Service):
for i in range(3):
service.create(
ServeRequest(**{"prompt_name": f"prompt_{i}", "sys_code": "dbgpt"})
)
res = service.get_list_by_page(ServeRequest(sys_code="dbgpt"), page=1, page_size=2)
assert res is not None
assert res.total_count == 3
assert res.total_pages == 2
assert len(res.items) == 2
for i, entity in enumerate(res.items):
assert entity.sys_code == "dbgpt"
assert entity.prompt_name == f"prompt_{i}"