mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 21:51:25 +00:00
feat: (0.6)New UI (#1855)
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -1,14 +1,29 @@
|
||||
from typing import List, Optional
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.agent import ConversableAgent, get_agent_manager
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.core import ModelMetadata, ModelRequest, PromptTemplate
|
||||
from dbgpt.core.interface.prompt import (
|
||||
SystemPromptTemplate,
|
||||
_get_string_template_vars,
|
||||
get_template_vars,
|
||||
)
|
||||
from dbgpt.model import DefaultLLMClient
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.serve.core import BaseService
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.util.json_utils import compare_json_properties_ex, find_json_objects
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
from dbgpt.util.tracer import root_tracer
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..api.schemas import PromptDebugInput, PromptType, ServeRequest, ServerResponse
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""The service class for Prompt"""
|
||||
@@ -27,8 +42,6 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
Args:
|
||||
system_app (SystemApp): The system app
|
||||
"""
|
||||
super().init_app(system_app)
|
||||
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
@@ -72,7 +85,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""
|
||||
# Build the query request from the request
|
||||
query_request = {
|
||||
"prompt_name": request.prompt_name,
|
||||
"prompt_code": request.prompt_code,
|
||||
"sys_code": request.sys_code,
|
||||
}
|
||||
return self.dao.update(query_request, update_request=request)
|
||||
@@ -86,7 +99,6 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
# TODO: implement your own logic here
|
||||
# Build the query request from the request
|
||||
query_request = request
|
||||
return self.dao.get_one(query_request)
|
||||
@@ -134,4 +146,267 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
List[ServerResponse]: The response
|
||||
"""
|
||||
query_request = request
|
||||
return self.dao.get_list_page(query_request, page, page_size)
|
||||
return self.dao.get_list_page(
|
||||
query_request, page, page_size, ServeEntity.id.name
|
||||
)
|
||||
|
||||
def get_prompt_template(self, prompt_type: str, target: Optional[str] = None):
|
||||
request = ServeRequest()
|
||||
request.prompt_type = prompt_type
|
||||
request.chat_scene = target
|
||||
|
||||
return self.get_list(request)
|
||||
|
||||
def get_target_prompt(
|
||||
self, target: Optional[str] = None, language: Optional[str] = None
|
||||
):
|
||||
logger.info(f"get_target_prompt:{target}")
|
||||
request = ServeRequest()
|
||||
if target:
|
||||
request.chat_scene = target
|
||||
if language:
|
||||
request.prompt_language = language
|
||||
return self.get_list(request)
|
||||
|
||||
def get_type_targets(self, prompt_type: str):
|
||||
type = PromptType(prompt_type)
|
||||
if type == PromptType.AGENT:
|
||||
agent_manage = get_agent_manager()
|
||||
return agent_manage.list_agents()
|
||||
|
||||
elif type == PromptType.SCENE:
|
||||
from dbgpt.app.scene import ChatScene
|
||||
|
||||
return [
|
||||
{"name": item.value(), "desc": item.describe()} for item in ChatScene
|
||||
]
|
||||
elif type == PromptType.EVALUATE:
|
||||
from dbgpt.rag.evaluation.answer import LLMEvaluationMetric
|
||||
|
||||
return [
|
||||
{"name": item.name, "desc": item.prompt_template}
|
||||
for item in LLMEvaluationMetric.__subclasses__()
|
||||
]
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_template(self, prompt_code: str) -> Optional[PromptTemplate]:
|
||||
if not prompt_code:
|
||||
return None
|
||||
query_request = ServeRequest(prompt_code=prompt_code)
|
||||
template = self.get(query_request)
|
||||
if not template:
|
||||
return None
|
||||
return PromptTemplate(
|
||||
template=template.content,
|
||||
template_scene=template.chat_scene,
|
||||
input_variables=get_template_vars(template.content),
|
||||
response_format=template.response_schema,
|
||||
)
|
||||
|
||||
def load_template(
|
||||
self,
|
||||
prompt_type: str,
|
||||
target: Optional[str] = None,
|
||||
language: Optional[str] = "en",
|
||||
):
|
||||
logger.info(f"load_template:{prompt_type},{target}")
|
||||
type = PromptType(prompt_type)
|
||||
if type == PromptType.AGENT:
|
||||
if not target:
|
||||
raise ValueError("请选择一个Agent用来加载模版")
|
||||
agent_manage = get_agent_manager()
|
||||
target_agent_cls: Type[ConversableAgent] = agent_manage.get_by_name(target)
|
||||
target_agent = target_agent_cls()
|
||||
base_template = target_agent.prompt_template()
|
||||
|
||||
return PromptTemplate(
|
||||
template=base_template,
|
||||
input_variables=get_template_vars(base_template),
|
||||
response_format=target_agent.actions[0].ai_out_schema_json,
|
||||
)
|
||||
elif type == PromptType.SCENE:
|
||||
if not target:
|
||||
raise ValueError("请选择一个场景用来加载模版")
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
cfg = Config()
|
||||
from dbgpt.app.scene import AppScenePromptTemplateAdapter
|
||||
|
||||
try:
|
||||
app_prompt: AppScenePromptTemplateAdapter = (
|
||||
cfg.prompt_template_registry.get_prompt_template(
|
||||
target, cfg.LANGUAGE, None
|
||||
)
|
||||
)
|
||||
for item in app_prompt.prompt.messages:
|
||||
if isinstance(item, SystemPromptTemplate):
|
||||
return item.prompt
|
||||
raise ValueError(f"当前场景没有找到可用的Prompt模版,{target}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"当前场景没有找到可用的Prompt模版,{target}")
|
||||
elif type == PromptType.EVALUATE:
|
||||
if not target:
|
||||
raise ValueError("请选择一个场景用来加载模版")
|
||||
try:
|
||||
from dbgpt.rag.evaluation.answer import (
|
||||
AnswerRelevancyMetric,
|
||||
LLMEvaluationMetric,
|
||||
)
|
||||
|
||||
prompts = [
|
||||
item.prompt_template
|
||||
for item in LLMEvaluationMetric.__subclasses__()
|
||||
if target == item.name
|
||||
]
|
||||
if len(prompts) == 0:
|
||||
raise ValueError(f"当前场景没有找到可用的Prompt模版,{target}")
|
||||
prompt = prompts[0]
|
||||
return PromptTemplate(
|
||||
template=prompt, input_variables=get_template_vars(prompt)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"当前场景没有找到可用的Prompt模版,{target}")
|
||||
else:
|
||||
return None
|
||||
|
||||
async def debug_prompt(self, debug_input: PromptDebugInput):
|
||||
logger.info(f"debug_prompt:{debug_input}")
|
||||
if not debug_input.user_input:
|
||||
raise ValueError("请输入你的提问!")
|
||||
try:
|
||||
worker_manager = self._system_app.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
llm_client = DefaultLLMClient(worker_manager, auto_convert_message=True)
|
||||
except Exception as e:
|
||||
raise ValueError("LLM prepare failed!", e)
|
||||
|
||||
try:
|
||||
debug_messages = []
|
||||
from dbgpt.core import ModelMessageRoleType
|
||||
|
||||
prompt = debug_input.content
|
||||
|
||||
prompt_vars = debug_input.input_values
|
||||
type = PromptType(debug_input.prompt_type)
|
||||
if type == PromptType.AGENT:
|
||||
if debug_input.response_schema:
|
||||
prompt_vars.update(
|
||||
{
|
||||
"out_schema": f"请确保按照以下json格式回复:\n{debug_input.response_schema}\n确保响应是正确的json,并且可以被Python json.loads 解析。"
|
||||
}
|
||||
)
|
||||
elif type == PromptType.SCENE:
|
||||
if debug_input.response_schema:
|
||||
prompt_vars.update({"response": debug_input.response_schema})
|
||||
|
||||
if debug_input.input_values:
|
||||
prompt = prompt.format(**prompt_vars)
|
||||
|
||||
debug_messages.append(
|
||||
{"role": ModelMessageRoleType.SYSTEM, "content": prompt}
|
||||
)
|
||||
debug_messages.append(
|
||||
{"role": ModelMessageRoleType.HUMAN, "content": debug_input.user_input}
|
||||
)
|
||||
metadata: ModelMetadata = await llm_client.get_model_metadata(
|
||||
debug_input.debug_model
|
||||
)
|
||||
payload = {
|
||||
"model": debug_input.debug_model,
|
||||
"messages": debug_messages,
|
||||
"temperature": debug_input.temperature,
|
||||
"max_new_tokens": metadata.context_length,
|
||||
"echo": metadata.ext_metadata.prompt_sep,
|
||||
"stop": None,
|
||||
"stop_token_ids": None,
|
||||
"context_len": None,
|
||||
"span_id": None,
|
||||
"context": None,
|
||||
}
|
||||
logger.info(f"Request: \n{payload}")
|
||||
span = root_tracer.start_span(
|
||||
"Agent.llm_client.no_streaming_call",
|
||||
metadata=self._get_span_metadata(payload),
|
||||
)
|
||||
payload["span_id"] = span.span_id
|
||||
|
||||
# if params.get("context") is not None:
|
||||
# payload["context"] = ModelRequestContext(extra=params["context"])
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError("参数准备失败!" + str(e))
|
||||
|
||||
try:
|
||||
model_request = ModelRequest(**payload)
|
||||
|
||||
async for output in llm_client.generate_stream(model_request.copy()): # type: ignore
|
||||
yield f"data:{output.text}\n\n"
|
||||
yield f"data:[DONE]\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Call LLMClient error, {str(e)}, detail: {payload}")
|
||||
raise ValueError(e)
|
||||
finally:
|
||||
span.end()
|
||||
|
||||
def _get_span_metadata(self, payload: Dict) -> Dict:
|
||||
metadata = {k: v for k, v in payload.items()}
|
||||
|
||||
metadata["messages"] = list(
|
||||
map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
|
||||
)
|
||||
return metadata
|
||||
|
||||
def verify_response(
|
||||
self, llm_out: str, prompt_type: str, target: Optional[str] = None
|
||||
):
|
||||
logger.info(f"verify_response:{llm_out},{prompt_type},{target}")
|
||||
type = PromptType(prompt_type)
|
||||
ai_json = find_json_objects(llm_out)
|
||||
if type == PromptType.AGENT:
|
||||
try:
|
||||
if not target:
|
||||
raise ValueError("请选择一个Agent用来加载模版")
|
||||
from dbgpt.agent.core import agent_manage
|
||||
|
||||
target_agent_cls: Type[ConversableAgent] = agent_manage.get_by_name(
|
||||
target
|
||||
)
|
||||
target_agent = target_agent_cls()
|
||||
|
||||
return compare_json_properties_ex(
|
||||
ai_json, json.loads(target_agent.actions[0].ai_out_schema_json)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"模型返回不符合[{target}]输出定义,请调整prompt!")
|
||||
|
||||
elif type == PromptType.SCENE:
|
||||
if not target:
|
||||
raise ValueError("请选择一个场景用来加载模版")
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
cfg = Config()
|
||||
from dbgpt.app.scene import AppScenePromptTemplateAdapter
|
||||
|
||||
try:
|
||||
app_prompt: AppScenePromptTemplateAdapter = (
|
||||
cfg.prompt_template_registry.get_prompt_template(
|
||||
target, cfg.LANGUAGE, None
|
||||
)
|
||||
)
|
||||
|
||||
sys_prompt = None
|
||||
for item in app_prompt.prompt.messages:
|
||||
if isinstance(item, SystemPromptTemplate):
|
||||
sys_prompt = item.prompt
|
||||
if sys_prompt:
|
||||
return compare_json_properties_ex(
|
||||
ai_json, json.loads(sys_prompt.response_format)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"当前场景没有找到可用的Prompt模版,{target}")
|
||||
else:
|
||||
return True
|
||||
|
Reference in New Issue
Block a user