mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-06 19:40:13 +00:00
feat: (0.6)New UI (#1855)
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import logging
|
||||
from functools import cache
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import Result
|
||||
@@ -10,7 +12,15 @@ from dbgpt.util import PaginationResult
|
||||
|
||||
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..service.service import Service
|
||||
from .schemas import ServeRequest, ServerResponse
|
||||
from .schemas import (
|
||||
PromptDebugInput,
|
||||
PromptType,
|
||||
PromptVerifyInput,
|
||||
ServeRequest,
|
||||
ServerResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -66,25 +76,6 @@ async def check_api_key(
|
||||
if request.url.path.startswith(f"/api/v1"):
|
||||
return None
|
||||
|
||||
# if service.config.api_keys:
|
||||
# api_keys = _parse_api_keys(service.config.api_keys)
|
||||
# if auth is None or (token := auth.credentials) not in api_keys:
|
||||
# raise HTTPException(
|
||||
# status_code=401,
|
||||
# detail={
|
||||
# "error": {
|
||||
# "message": "",
|
||||
# "type": "invalid_request_error",
|
||||
# "param": None,
|
||||
# "code": "invalid_api_key",
|
||||
# }
|
||||
# },
|
||||
# )
|
||||
# return token
|
||||
# else:
|
||||
# # api_keys not set; allow all
|
||||
# return None
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
@@ -132,7 +123,12 @@ async def update(
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.update(request))
|
||||
try:
|
||||
data = service.update(request)
|
||||
return Result.succ(data)
|
||||
except Exception as e:
|
||||
logger.exception("Update prompt failed!")
|
||||
return Result.failed(msg=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -195,6 +191,113 @@ async def query_page(
|
||||
return Result.succ(service.get_list_by_page(request, page, page_size))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/type/targets",
|
||||
response_model=Result,
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def prompt_type_targets(
|
||||
prompt_type: str = Query(
|
||||
default=PromptType.NORMAL, description="Prompt template type"
|
||||
),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Result:
|
||||
"""get Prompt type
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.get_type_targets(prompt_type))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/template/load",
|
||||
response_model=Result,
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def load_template(
|
||||
prompt_type: str = Query(
|
||||
default=PromptType.NORMAL, description="Prompt template type"
|
||||
),
|
||||
target: Optional[str] = Query(
|
||||
default=None, description="The target to load the template from"
|
||||
),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Result:
|
||||
"""load Prompt from target
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.load_template(prompt_type, target))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/template/debug",
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def template_debug(
|
||||
debug_input: PromptDebugInput,
|
||||
service: Service = Depends(get_service),
|
||||
):
|
||||
"""test Prompt
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
try:
|
||||
return StreamingResponse(
|
||||
service.debug_prompt(
|
||||
debug_input=debug_input,
|
||||
),
|
||||
headers=headers,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
return Result.failed(msg=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/response/verify",
|
||||
response_model=Result[bool],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def response_verify(
|
||||
request: PromptVerifyInput,
|
||||
service: Service = Depends(get_service),
|
||||
) -> Result[bool]:
|
||||
"""test Prompt
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
try:
|
||||
return Result.succ(
|
||||
service.verify_response(
|
||||
request.llm_out, request.prompt_type, request.chat_scene
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return Result.failed(msg=str(e))
|
||||
|
||||
|
||||
def init_endpoints(system_app: SystemApp) -> None:
|
||||
"""Initialize the endpoints"""
|
||||
global global_system_app
|
||||
|
@@ -1,4 +1,5 @@
|
||||
# Define your Pydantic schemas here
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -22,7 +23,11 @@ class ServeRequest(BaseModel):
|
||||
description="The sub chat scene.",
|
||||
examples=["sub_scene_1", "sub_scene_2", "sub_scene_3"],
|
||||
)
|
||||
|
||||
prompt_code: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt code.",
|
||||
examples=["test123", "test456"],
|
||||
)
|
||||
prompt_type: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt type, either common or private.",
|
||||
@@ -51,7 +56,39 @@ class ServeRequest(BaseModel):
|
||||
"This is a prompt for data analysis expert.",
|
||||
],
|
||||
)
|
||||
response_schema: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt response schema.",
|
||||
examples=[
|
||||
"None",
|
||||
'{"xx": "123"}',
|
||||
],
|
||||
)
|
||||
input_variables: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt variables.",
|
||||
examples=[
|
||||
"display_type",
|
||||
"resources",
|
||||
],
|
||||
)
|
||||
|
||||
model: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt can use model.",
|
||||
examples=["vicuna13b", "chatgpt"],
|
||||
)
|
||||
|
||||
prompt_language: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt language.",
|
||||
examples=["en", "zh"],
|
||||
)
|
||||
user_code: Optional[str] = Field(
|
||||
None,
|
||||
description="The user id.",
|
||||
examples=[""],
|
||||
)
|
||||
user_name: Optional[str] = Field(
|
||||
None,
|
||||
description="The user name.",
|
||||
@@ -75,6 +112,11 @@ class ServerResponse(ServeRequest):
|
||||
description="The prompt id.",
|
||||
examples=[1, 2, 3],
|
||||
)
|
||||
prompt_code: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt code.",
|
||||
examples=["xxxx1", "xxxx2", "xxxx3"],
|
||||
)
|
||||
gmt_created: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt created time.",
|
||||
@@ -85,3 +127,37 @@ class ServerResponse(ServeRequest):
|
||||
description="The prompt modified time.",
|
||||
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
|
||||
)
|
||||
|
||||
|
||||
class PromptVerifyInput(ServeRequest):
|
||||
llm_out: Optional[str] = Field(
|
||||
None,
|
||||
description="The llm out of prompt.",
|
||||
)
|
||||
|
||||
|
||||
class PromptDebugInput(ServeRequest):
|
||||
input_values: Optional[dict] = Field(
|
||||
None,
|
||||
description="The prompt variables debug value.",
|
||||
)
|
||||
temperature: Optional[float] = Field(
|
||||
default=0.5,
|
||||
description="The prompt debug temperature.",
|
||||
)
|
||||
debug_model: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt debug model.",
|
||||
examples=["vicuna13b", "chatgpt"],
|
||||
)
|
||||
user_input: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt debug user input.",
|
||||
)
|
||||
|
||||
|
||||
class PromptType(Enum):
|
||||
AGENT = "Agent"
|
||||
SCENE = "Scene"
|
||||
NORMAL = "Normal"
|
||||
EVALUATE = "Evaluate"
|
||||
|
@@ -1,12 +1,12 @@
|
||||
"""This is an auto-generated model file
|
||||
You can define your own models and DAOs here
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
|
||||
|
||||
from dbgpt._private.pydantic import model_to_dict
|
||||
from dbgpt.storage.metadata import BaseDao, Model, db
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
@@ -28,12 +28,14 @@ class ServeEntity(Model):
|
||||
|
||||
chat_scene = Column(String(100), comment="Chat scene")
|
||||
sub_chat_scene = Column(String(100), comment="Sub chat scene")
|
||||
prompt_code = Column(String(256), comment="Prompt Code")
|
||||
prompt_type = Column(String(100), comment="Prompt type(eg: common, private)")
|
||||
prompt_name = Column(String(256), comment="Prompt name")
|
||||
content = Column(Text, comment="Prompt content")
|
||||
input_variables = Column(
|
||||
String(1024), nullable=True, comment="Prompt input variables(split by comma))"
|
||||
)
|
||||
response_schema = Column(Text, comment="Prompt response schema")
|
||||
model = Column(
|
||||
String(128),
|
||||
nullable=True,
|
||||
@@ -50,6 +52,7 @@ class ServeEntity(Model):
|
||||
comment="Prompt format(eg: f-string, jinja2)",
|
||||
)
|
||||
prompt_desc = Column(String(512), nullable=True, comment="Prompt description")
|
||||
user_code = Column(String(128), index=True, nullable=True, comment="User code")
|
||||
user_name = Column(String(128), index=True, nullable=True, comment="User name")
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
|
||||
@@ -59,7 +62,7 @@ class ServeEntity(Model):
|
||||
return (
|
||||
f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', "
|
||||
f"prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',"
|
||||
f"user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
f"user_code='{self.user_code}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
)
|
||||
|
||||
|
||||
@@ -79,10 +82,10 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
Returns:
|
||||
T: The entity
|
||||
"""
|
||||
request_dict = (
|
||||
model_to_dict(request) if isinstance(request, ServeRequest) else request
|
||||
)
|
||||
request_dict = request.dict() if isinstance(request, ServeRequest) else request
|
||||
entity = ServeEntity(**request_dict)
|
||||
if not entity.prompt_code:
|
||||
entity.prompt_code = uuid.uuid4().hex
|
||||
return entity
|
||||
|
||||
def to_request(self, entity: ServeEntity) -> ServeRequest:
|
||||
@@ -99,8 +102,10 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
sub_chat_scene=entity.sub_chat_scene,
|
||||
prompt_type=entity.prompt_type,
|
||||
prompt_name=entity.prompt_name,
|
||||
prompt_code=entity.prompt_code,
|
||||
content=entity.content,
|
||||
prompt_desc=entity.prompt_desc,
|
||||
user_code=entity.user_code,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
)
|
||||
@@ -121,11 +126,16 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
id=entity.id,
|
||||
chat_scene=entity.chat_scene,
|
||||
sub_chat_scene=entity.sub_chat_scene,
|
||||
prompt_code=entity.prompt_code,
|
||||
prompt_type=entity.prompt_type,
|
||||
prompt_name=entity.prompt_name,
|
||||
content=entity.content,
|
||||
prompt_desc=entity.prompt_desc,
|
||||
user_name=entity.user_name,
|
||||
user_code=entity.user_code,
|
||||
model=entity.model,
|
||||
input_variables=entity.input_variables,
|
||||
prompt_language=entity.prompt_language,
|
||||
sys_code=entity.sys_code,
|
||||
gmt_created=gmt_created_str,
|
||||
gmt_modified=gmt_modified_str,
|
||||
|
@@ -1,14 +1,29 @@
|
||||
from typing import List, Optional
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.agent import ConversableAgent, get_agent_manager
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.core import ModelMetadata, ModelRequest, PromptTemplate
|
||||
from dbgpt.core.interface.prompt import (
|
||||
SystemPromptTemplate,
|
||||
_get_string_template_vars,
|
||||
get_template_vars,
|
||||
)
|
||||
from dbgpt.model import DefaultLLMClient
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.serve.core import BaseService
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.util.json_utils import compare_json_properties_ex, find_json_objects
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
from dbgpt.util.tracer import root_tracer
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..api.schemas import PromptDebugInput, PromptType, ServeRequest, ServerResponse
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""The service class for Prompt"""
|
||||
@@ -27,8 +42,6 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
Args:
|
||||
system_app (SystemApp): The system app
|
||||
"""
|
||||
super().init_app(system_app)
|
||||
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
@@ -72,7 +85,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""
|
||||
# Build the query request from the request
|
||||
query_request = {
|
||||
"prompt_name": request.prompt_name,
|
||||
"prompt_code": request.prompt_code,
|
||||
"sys_code": request.sys_code,
|
||||
}
|
||||
return self.dao.update(query_request, update_request=request)
|
||||
@@ -86,7 +99,6 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
# TODO: implement your own logic here
|
||||
# Build the query request from the request
|
||||
query_request = request
|
||||
return self.dao.get_one(query_request)
|
||||
@@ -134,4 +146,267 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
List[ServerResponse]: The response
|
||||
"""
|
||||
query_request = request
|
||||
return self.dao.get_list_page(query_request, page, page_size)
|
||||
return self.dao.get_list_page(
|
||||
query_request, page, page_size, ServeEntity.id.name
|
||||
)
|
||||
|
||||
def get_prompt_template(self, prompt_type: str, target: Optional[str] = None):
|
||||
request = ServeRequest()
|
||||
request.prompt_type = prompt_type
|
||||
request.chat_scene = target
|
||||
|
||||
return self.get_list(request)
|
||||
|
||||
def get_target_prompt(
|
||||
self, target: Optional[str] = None, language: Optional[str] = None
|
||||
):
|
||||
logger.info(f"get_target_prompt:{target}")
|
||||
request = ServeRequest()
|
||||
if target:
|
||||
request.chat_scene = target
|
||||
if language:
|
||||
request.prompt_language = language
|
||||
return self.get_list(request)
|
||||
|
||||
def get_type_targets(self, prompt_type: str):
|
||||
type = PromptType(prompt_type)
|
||||
if type == PromptType.AGENT:
|
||||
agent_manage = get_agent_manager()
|
||||
return agent_manage.list_agents()
|
||||
|
||||
elif type == PromptType.SCENE:
|
||||
from dbgpt.app.scene import ChatScene
|
||||
|
||||
return [
|
||||
{"name": item.value(), "desc": item.describe()} for item in ChatScene
|
||||
]
|
||||
elif type == PromptType.EVALUATE:
|
||||
from dbgpt.rag.evaluation.answer import LLMEvaluationMetric
|
||||
|
||||
return [
|
||||
{"name": item.name, "desc": item.prompt_template}
|
||||
for item in LLMEvaluationMetric.__subclasses__()
|
||||
]
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_template(self, prompt_code: str) -> Optional[PromptTemplate]:
|
||||
if not prompt_code:
|
||||
return None
|
||||
query_request = ServeRequest(prompt_code=prompt_code)
|
||||
template = self.get(query_request)
|
||||
if not template:
|
||||
return None
|
||||
return PromptTemplate(
|
||||
template=template.content,
|
||||
template_scene=template.chat_scene,
|
||||
input_variables=get_template_vars(template.content),
|
||||
response_format=template.response_schema,
|
||||
)
|
||||
|
||||
def load_template(
|
||||
self,
|
||||
prompt_type: str,
|
||||
target: Optional[str] = None,
|
||||
language: Optional[str] = "en",
|
||||
):
|
||||
logger.info(f"load_template:{prompt_type},{target}")
|
||||
type = PromptType(prompt_type)
|
||||
if type == PromptType.AGENT:
|
||||
if not target:
|
||||
raise ValueError("请选择一个Agent用来加载模版")
|
||||
agent_manage = get_agent_manager()
|
||||
target_agent_cls: Type[ConversableAgent] = agent_manage.get_by_name(target)
|
||||
target_agent = target_agent_cls()
|
||||
base_template = target_agent.prompt_template()
|
||||
|
||||
return PromptTemplate(
|
||||
template=base_template,
|
||||
input_variables=get_template_vars(base_template),
|
||||
response_format=target_agent.actions[0].ai_out_schema_json,
|
||||
)
|
||||
elif type == PromptType.SCENE:
|
||||
if not target:
|
||||
raise ValueError("请选择一个场景用来加载模版")
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
cfg = Config()
|
||||
from dbgpt.app.scene import AppScenePromptTemplateAdapter
|
||||
|
||||
try:
|
||||
app_prompt: AppScenePromptTemplateAdapter = (
|
||||
cfg.prompt_template_registry.get_prompt_template(
|
||||
target, cfg.LANGUAGE, None
|
||||
)
|
||||
)
|
||||
for item in app_prompt.prompt.messages:
|
||||
if isinstance(item, SystemPromptTemplate):
|
||||
return item.prompt
|
||||
raise ValueError(f"当前场景没有找到可用的Prompt模版,{target}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"当前场景没有找到可用的Prompt模版,{target}")
|
||||
elif type == PromptType.EVALUATE:
|
||||
if not target:
|
||||
raise ValueError("请选择一个场景用来加载模版")
|
||||
try:
|
||||
from dbgpt.rag.evaluation.answer import (
|
||||
AnswerRelevancyMetric,
|
||||
LLMEvaluationMetric,
|
||||
)
|
||||
|
||||
prompts = [
|
||||
item.prompt_template
|
||||
for item in LLMEvaluationMetric.__subclasses__()
|
||||
if target == item.name
|
||||
]
|
||||
if len(prompts) == 0:
|
||||
raise ValueError(f"当前场景没有找到可用的Prompt模版,{target}")
|
||||
prompt = prompts[0]
|
||||
return PromptTemplate(
|
||||
template=prompt, input_variables=get_template_vars(prompt)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"当前场景没有找到可用的Prompt模版,{target}")
|
||||
else:
|
||||
return None
|
||||
|
||||
async def debug_prompt(self, debug_input: PromptDebugInput):
|
||||
logger.info(f"debug_prompt:{debug_input}")
|
||||
if not debug_input.user_input:
|
||||
raise ValueError("请输入你的提问!")
|
||||
try:
|
||||
worker_manager = self._system_app.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
llm_client = DefaultLLMClient(worker_manager, auto_convert_message=True)
|
||||
except Exception as e:
|
||||
raise ValueError("LLM prepare failed!", e)
|
||||
|
||||
try:
|
||||
debug_messages = []
|
||||
from dbgpt.core import ModelMessageRoleType
|
||||
|
||||
prompt = debug_input.content
|
||||
|
||||
prompt_vars = debug_input.input_values
|
||||
type = PromptType(debug_input.prompt_type)
|
||||
if type == PromptType.AGENT:
|
||||
if debug_input.response_schema:
|
||||
prompt_vars.update(
|
||||
{
|
||||
"out_schema": f"请确保按照以下json格式回复:\n{debug_input.response_schema}\n确保响应是正确的json,并且可以被Python json.loads 解析。"
|
||||
}
|
||||
)
|
||||
elif type == PromptType.SCENE:
|
||||
if debug_input.response_schema:
|
||||
prompt_vars.update({"response": debug_input.response_schema})
|
||||
|
||||
if debug_input.input_values:
|
||||
prompt = prompt.format(**prompt_vars)
|
||||
|
||||
debug_messages.append(
|
||||
{"role": ModelMessageRoleType.SYSTEM, "content": prompt}
|
||||
)
|
||||
debug_messages.append(
|
||||
{"role": ModelMessageRoleType.HUMAN, "content": debug_input.user_input}
|
||||
)
|
||||
metadata: ModelMetadata = await llm_client.get_model_metadata(
|
||||
debug_input.debug_model
|
||||
)
|
||||
payload = {
|
||||
"model": debug_input.debug_model,
|
||||
"messages": debug_messages,
|
||||
"temperature": debug_input.temperature,
|
||||
"max_new_tokens": metadata.context_length,
|
||||
"echo": metadata.ext_metadata.prompt_sep,
|
||||
"stop": None,
|
||||
"stop_token_ids": None,
|
||||
"context_len": None,
|
||||
"span_id": None,
|
||||
"context": None,
|
||||
}
|
||||
logger.info(f"Request: \n{payload}")
|
||||
span = root_tracer.start_span(
|
||||
"Agent.llm_client.no_streaming_call",
|
||||
metadata=self._get_span_metadata(payload),
|
||||
)
|
||||
payload["span_id"] = span.span_id
|
||||
|
||||
# if params.get("context") is not None:
|
||||
# payload["context"] = ModelRequestContext(extra=params["context"])
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError("参数准备失败!" + str(e))
|
||||
|
||||
try:
|
||||
model_request = ModelRequest(**payload)
|
||||
|
||||
async for output in llm_client.generate_stream(model_request.copy()): # type: ignore
|
||||
yield f"data:{output.text}\n\n"
|
||||
yield f"data:[DONE]\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Call LLMClient error, {str(e)}, detail: {payload}")
|
||||
raise ValueError(e)
|
||||
finally:
|
||||
span.end()
|
||||
|
||||
def _get_span_metadata(self, payload: Dict) -> Dict:
|
||||
metadata = {k: v for k, v in payload.items()}
|
||||
|
||||
metadata["messages"] = list(
|
||||
map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
|
||||
)
|
||||
return metadata
|
||||
|
||||
def verify_response(
|
||||
self, llm_out: str, prompt_type: str, target: Optional[str] = None
|
||||
):
|
||||
logger.info(f"verify_response:{llm_out},{prompt_type},{target}")
|
||||
type = PromptType(prompt_type)
|
||||
ai_json = find_json_objects(llm_out)
|
||||
if type == PromptType.AGENT:
|
||||
try:
|
||||
if not target:
|
||||
raise ValueError("请选择一个Agent用来加载模版")
|
||||
from dbgpt.agent.core import agent_manage
|
||||
|
||||
target_agent_cls: Type[ConversableAgent] = agent_manage.get_by_name(
|
||||
target
|
||||
)
|
||||
target_agent = target_agent_cls()
|
||||
|
||||
return compare_json_properties_ex(
|
||||
ai_json, json.loads(target_agent.actions[0].ai_out_schema_json)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"模型返回不符合[{target}]输出定义,请调整prompt!")
|
||||
|
||||
elif type == PromptType.SCENE:
|
||||
if not target:
|
||||
raise ValueError("请选择一个场景用来加载模版")
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
cfg = Config()
|
||||
from dbgpt.app.scene import AppScenePromptTemplateAdapter
|
||||
|
||||
try:
|
||||
app_prompt: AppScenePromptTemplateAdapter = (
|
||||
cfg.prompt_template_registry.get_prompt_template(
|
||||
target, cfg.LANGUAGE, None
|
||||
)
|
||||
)
|
||||
|
||||
sys_prompt = None
|
||||
for item in app_prompt.prompt.messages:
|
||||
if isinstance(item, SystemPromptTemplate):
|
||||
sys_prompt = item.prompt
|
||||
if sys_prompt:
|
||||
return compare_json_properties_ex(
|
||||
ai_json, json.loads(sys_prompt.response_format)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"当前场景没有找到可用的Prompt模版,{target}")
|
||||
else:
|
||||
return True
|
||||
|
@@ -8,7 +8,7 @@ from dbgpt.storage.metadata import db
|
||||
from dbgpt.util import PaginationResult
|
||||
|
||||
from ..api.endpoints import init_endpoints, router
|
||||
from ..api.schemas import ServerResponse
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX
|
||||
|
||||
|
||||
@@ -42,55 +42,6 @@ async def _create_and_validate(
|
||||
assert res_obj.content == content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client, asystem_app, has_auth",
|
||||
[
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "test_token1",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "error_token",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
False,
|
||||
),
|
||||
],
|
||||
indirect=["client", "asystem_app"],
|
||||
)
|
||||
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
|
||||
response = await client.get("/test_auth")
|
||||
if has_auth:
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_auth(client: AsyncClient):
|
||||
response = await client.get("/health")
|
||||
response.raise_for_status()
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
|
@@ -139,18 +139,3 @@ def test_service_get_list(service: Service):
|
||||
for i, entity in enumerate(entities):
|
||||
assert entity.sys_code == "dbgpt"
|
||||
assert entity.prompt_name == f"prompt_{i}"
|
||||
|
||||
|
||||
def test_service_get_list_by_page(service: Service):
|
||||
for i in range(3):
|
||||
service.create(
|
||||
ServeRequest(**{"prompt_name": f"prompt_{i}", "sys_code": "dbgpt"})
|
||||
)
|
||||
res = service.get_list_by_page(ServeRequest(sys_code="dbgpt"), page=1, page_size=2)
|
||||
assert res is not None
|
||||
assert res.total_count == 3
|
||||
assert res.total_pages == 2
|
||||
assert len(res.items) == 2
|
||||
for i, entity in enumerate(res.items):
|
||||
assert entity.sys_code == "dbgpt"
|
||||
assert entity.prompt_name == f"prompt_{i}"
|
||||
|
Reference in New Issue
Block a user