feat(feedback): feedback upgrade — add knowledge-space domain_type (model, DAO, query API), route chat_flow AWEL teams through the flow service, harden prompt update and dbgpts package loading with explicit error handling, and fix the prompt_code field typo

This commit is contained in:
yhjun1026
2024-08-16 11:56:25 +08:00
parent fd4991a9d2
commit b20eaacf5e
11 changed files with 174 additions and 137 deletions

View File

@@ -1,7 +1,6 @@
"""Role class for role-based conversation."""
from abc import ABC
from datetime import datetime
from typing import Dict, List, Optional
from jinja2 import Environment, Template, meta
@@ -108,7 +107,6 @@ class Role(ABC, BaseModel):
"retry_constraints": self.retry_constraints,
"examples": self.examples,
"is_retry_chat": is_retry_chat,
"now_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}
param = role_params.copy()
runtime_param_names = []

View File

@@ -117,6 +117,8 @@ class SpaceQueryResponse(BaseModel):
name: Optional[str] = None
"""vector_type: vector type"""
vector_type: Optional[str] = None
"""domain_type"""
domain_type: Optional[str] = None
"""desc: description"""
desc: Optional[str] = None
"""context: context"""

View File

@@ -123,7 +123,10 @@ class KnowledgeService:
- request: KnowledgeSpaceRequest
"""
query = KnowledgeSpaceEntity(
name=request.name, vector_type=request.vector_type, owner=request.owner
id=request.id,
name=request.name,
vector_type=request.vector_type,
owner=request.owner,
)
spaces = knowledge_space_dao.get_knowledge_space(query)
space_names = [space.name for space in spaces]
@@ -295,44 +298,6 @@ class KnowledgeService:
"""
return knowledge_space_dao.get_knowledge_space_by_ids(ids)
def get_knowledge_space(self, request: KnowledgeSpaceRequest):
"""get knowledge space
Args:
- request: KnowledgeSpaceRequest
"""
query = KnowledgeSpaceEntity(
name=request.name,
vector_type=request.vector_type,
owner=request.owner,
)
spaces = knowledge_space_dao.get_knowledge_space(query)
space_names = [space.name for space in spaces]
docs_count = knowledge_document_dao.get_knowledge_documents_count_bulk_by_ids(
space_names
)
responses = []
for space in spaces:
res = SpaceQueryResponse()
res.id = space.id
res.name = space.name
res.vector_type = space.vector_type
res.desc = space.desc
res.owner = space.owner
res.gmt_created = (
space.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
if space.gmt_created
else None
)
res.gmt_modified = (
space.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
if space.gmt_modified
else None
)
res.context = space.context
res.docs = docs_count.get(space.name, 0)
responses.append(res)
return responses
def update_knowledge_space(
self, space_id: int, space_request: KnowledgeSpaceRequest
):

View File

@@ -466,9 +466,13 @@ async def chat_completions(
"Transfer-Encoding": "chunked",
}
try:
domain_type = _parse_domain_type(dialogue)
if dialogue.chat_mode == ChatScene.ChatAgent.value():
from dbgpt.serve.agent.agents.controller import multi_agents
dialogue.ext_info.update({"model_name": dialogue.model_name})
dialogue.ext_info.update({"incremental": dialogue.incremental})
dialogue.ext_info.update({"temperature": dialogue.temperature})
return StreamingResponse(
multi_agents.app_agent_chat(
conv_uid=dialogue.conv_uid,
@@ -503,6 +507,13 @@ async def chat_completions(
headers=headers,
media_type="text/event-stream",
)
elif domain_type is not None and domain_type != "Normal":
return StreamingResponse(
chat_with_domain_flow(dialogue, domain_type),
headers=headers,
media_type="text/event-stream",
)
else:
with root_tracer.start_span(
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
@@ -672,9 +683,7 @@ def _parse_domain_type(dialogue: ConversationVo) -> Optional[str]:
KnowledgeSpaceRequest(name=space_name)
)
if len(spaces) == 0:
return Result.failed(
code="E000X", msg=f"Knowledge space {space_name} not found"
)
raise ValueError(f"Knowledge space {space_name} not found")
if spaces[0].domain_type:
return spaces[0].domain_type
else:

View File

@@ -135,6 +135,7 @@ _OPERATOR_CATEGORY_DETAIL = {
"common": _CategoryDetail("Common", "The common operator"),
"agent": _CategoryDetail("Agent", "The agent operator"),
"rag": _CategoryDetail("RAG", "The RAG operator"),
"experimental": _CategoryDetail("EXPERIMENTAL", "EXPERIMENTAL operator"),
}

View File

@@ -40,6 +40,7 @@ from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.serve.prompt.api.endpoints import get_service
from dbgpt.serve.prompt.service import service as PromptService
from dbgpt.util.json_utils import serialize
from dbgpt.util.tracer import TracerManager
from ..db import GptsMessagesDao
from ..db.gpts_app import GptsApp, GptsAppDao, GptsAppQuery
@@ -51,6 +52,7 @@ CFG = Config()
router = APIRouter()
logger = logging.getLogger(__name__)
root_tracer: TracerManager = TracerManager()
def _build_conversation(
@@ -219,84 +221,120 @@ class MultiAgents(BaseComponent, ABC):
)
)
# init gpts memory
self.memory.init(
agent_conv_id,
enable_vis_message=enable_verbose,
history_messages=history_messages,
start_round=history_message_count,
)
# init agent memory
agent_memory = self.get_or_build_agent_memory(conv_id, gpts_name)
is_agent_flow = True
if TeamMode.AWEL_LAYOUT.value == gpt_app.team_mode:
from dbgpt.agent import AWELTeamContext
try:
task = asyncio.create_task(
multi_agents.agent_team_chat_new(
user_query,
agent_conv_id,
gpt_app,
agent_memory,
is_retry_chat,
last_speaker_name=last_speaker_name,
init_message_rounds=message_round,
**ext_info,
)
team_context: AWELTeamContext = gpt_app.team_context
if team_context.flow_category == "chat_flow":
is_agent_flow = False
if not is_agent_flow:
from dbgpt.core.awel import CommonLLMHttpRequestBody
flow_req = CommonLLMHttpRequestBody(
model=ext_info.get("model_name", None),
messages=user_query,
stream=True,
# context=flow_ctx,
# temperature=
# max_new_tokens=
# enable_vis=
conv_uid=agent_conv_id,
span_id=root_tracer.get_current_span_id(),
chat_mode=ext_info.get("chat_mode", None),
chat_param=team_context.uid,
user_name=user_code,
sys_code=sys_code,
incremental=ext_info.get("incremental", True),
)
if enable_verbose:
async for chunk in multi_agents.chat_messages(agent_conv_id):
if chunk:
try:
chunk = json.dumps(
{"vis": chunk}, default=serialize, ensure_ascii=False
)
if chunk is None or len(chunk) <= 0:
continue
resp = f"data:{chunk}\n\n"
yield task, resp, agent_conv_id
except Exception as e:
logger.exception(
f"get messages {gpts_name} Exception!" + str(e)
)
yield f"data: {str(e)}\n\n"
from dbgpt.app.openapi.api_v1.api_v1 import get_chat_flow
yield task, f'data:{json.dumps({"vis": "[DONE]"}, default=serialize, ensure_ascii=False)} \n\n', agent_conv_id
flow_service = get_chat_flow()
async for chunk in flow_service.chat_stream_flow_str(
team_context.uid, flow_req
):
yield None, chunk, agent_conv_id
else:
# init gpts memory
self.memory.init(
agent_conv_id,
enable_vis_message=enable_verbose,
history_messages=history_messages,
start_round=history_message_count,
)
# init agent memory
agent_memory = self.get_or_build_agent_memory(conv_id, gpts_name)
else:
logger.info(f"{agent_conv_id}开启简略消息模式不进行vis协议封装获取极简流式消息直接输出")
# 开启简略消息模式不进行vis协议封装获取极简流式消息直接输出
final_message_chunk = None
async for chunk in multi_agents.chat_messages(agent_conv_id):
if chunk:
try:
if chunk is None or len(chunk) <= 0:
continue
final_message_chunk = chunk[-1]
if stream:
yield task, final_message_chunk, agent_conv_id
logger.info(
f"agent_chat_v2 executing, timestamp={int(time.time() * 1000)}"
)
except Exception as e:
logger.exception(
f"get messages {gpts_name} Exception!" + str(e)
)
final_message_chunk = str(e)
logger.info(
f"agent_chat_v2 finish, timestamp={int(time.time() * 1000)}"
try:
task = asyncio.create_task(
multi_agents.agent_team_chat_new(
user_query,
agent_conv_id,
gpt_app,
agent_memory,
is_retry_chat,
last_speaker_name=last_speaker_name,
init_message_rounds=message_round,
**ext_info,
)
)
yield task, final_message_chunk, agent_conv_id
if enable_verbose:
async for chunk in multi_agents.chat_messages(agent_conv_id):
if chunk:
try:
chunk = json.dumps(
{"vis": chunk},
default=serialize,
ensure_ascii=False,
)
if chunk is None or len(chunk) <= 0:
continue
resp = f"data:{chunk}\n\n"
yield task, resp, agent_conv_id
except Exception as e:
logger.exception(
f"get messages {gpts_name} Exception!" + str(e)
)
yield f"data: {str(e)}\n\n"
except Exception as e:
logger.exception(f"Agent chat have error!{str(e)}")
if enable_verbose:
yield task, f'data:{json.dumps({"vis": f"{str(e)}"}, default=serialize, ensure_ascii=False)} \n\n', agent_conv_id
yield task, f'data:{json.dumps({"vis": "[DONE]"}, default=serialize, ensure_ascii=False)} \n\n', agent_conv_id
else:
yield task, str(e), agent_conv_id
yield task, f'data:{json.dumps({"vis": "[DONE]"}, default=serialize, ensure_ascii=False)} \n\n', agent_conv_id
finally:
self.memory.clear(agent_conv_id)
else:
logger.info(f"{agent_conv_id}开启简略消息模式不进行vis协议封装获取极简流式消息直接输出")
# 开启简略消息模式不进行vis协议封装获取极简流式消息直接输出
final_message_chunk = None
async for chunk in multi_agents.chat_messages(agent_conv_id):
if chunk:
try:
if chunk is None or len(chunk) <= 0:
continue
final_message_chunk = chunk[-1]
if stream:
yield task, final_message_chunk, agent_conv_id
logger.info(
f"agent_chat_v2 executing, timestamp={int(time.time() * 1000)}"
)
except Exception as e:
logger.exception(
f"get messages {gpts_name} Exception!" + str(e)
)
final_message_chunk = str(e)
logger.info(
f"agent_chat_v2 finish, timestamp={int(time.time() * 1000)}"
)
yield task, final_message_chunk, agent_conv_id
except Exception as e:
logger.exception(f"Agent chat have error!{str(e)}")
if enable_verbose:
yield task, f'data:{json.dumps({"vis": f"{str(e)}"}, default=serialize, ensure_ascii=False)} \n\n', agent_conv_id
yield task, f'data:{json.dumps({"vis": "[DONE]"}, default=serialize, ensure_ascii=False)} \n\n', agent_conv_id
else:
yield task, str(e), agent_conv_id
finally:
self.memory.clear(agent_conv_id)
async def app_agent_chat(
self,

View File

@@ -1,3 +1,4 @@
import logging
from functools import cache
from typing import List, Optional
@@ -19,6 +20,8 @@ from .schemas import (
ServerResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter()
# Add your API endpoints here
@@ -120,7 +123,12 @@ async def update(
Returns:
ServerResponse: The response
"""
return Result.succ(service.update(request))
try:
data = service.update(request)
return Result.succ(data)
except Exception as e:
logger.exception("Update prompt failed!")
return Result.failed(msg=str(e))
@router.post(

View File

@@ -102,7 +102,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
sub_chat_scene=entity.sub_chat_scene,
prompt_type=entity.prompt_type,
prompt_name=entity.prompt_name,
promt_code=entity.prompt_code,
prompt_code=entity.prompt_code,
content=entity.content,
prompt_desc=entity.prompt_desc,
user_code=entity.user_code,

View File

@@ -32,6 +32,7 @@ class KnowledgeSpaceDao(BaseDao):
knowledge_space = KnowledgeSpaceEntity(
name=space.name,
vector_type=space.vector_type,
domain_type=space.domain_type,
desc=space.desc,
owner=space.owner,
gmt_created=datetime.now(),
@@ -58,6 +59,7 @@ class KnowledgeSpaceDao(BaseDao):
return knowledge_spaces_list
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
"""Get knowledge space by query"""
session = self.get_raw_session()
knowledge_spaces = session.query(KnowledgeSpaceEntity)
if query.id is not None:
@@ -72,6 +74,10 @@ class KnowledgeSpaceDao(BaseDao):
knowledge_spaces = knowledge_spaces.filter(
KnowledgeSpaceEntity.vector_type == query.vector_type
)
if query.domain_type is not None:
knowledge_spaces = knowledge_spaces.filter(
KnowledgeSpaceEntity.domain_type == query.domain_type
)
if query.desc is not None:
knowledge_spaces = knowledge_spaces.filter(
KnowledgeSpaceEntity.desc == query.desc
@@ -88,7 +94,6 @@ class KnowledgeSpaceDao(BaseDao):
knowledge_spaces = knowledge_spaces.filter(
KnowledgeSpaceEntity.gmt_modified == query.gmt_modified
)
knowledge_spaces = knowledge_spaces.order_by(
KnowledgeSpaceEntity.gmt_created.desc()
)
@@ -134,7 +139,9 @@ class KnowledgeSpaceDao(BaseDao):
T: The entity
"""
request_dict = (
request.dict() if isinstance(request, SpaceServeRequest) else request
model_to_dict(request)
if isinstance(request, SpaceServeRequest)
else request
)
entity = KnowledgeSpaceEntity(**request_dict)
return entity
@@ -173,4 +180,5 @@ class KnowledgeSpaceDao(BaseDao):
desc=entity.desc,
owner=entity.owner,
context=entity.context,
domain_type=entity.domain_type,
)

View File

@@ -171,10 +171,10 @@ class BaseDao(Generic[T, REQ, RES]):
if value is not None:
setattr(entry, key, value)
session.merge(entry)
res = self.get_one(self.to_request(entry))
if not res:
raise Exception("Update failed")
return res
# res = self.get_one(self.to_request(entry))
# if not res:
# raise Exception("Update failed")
return self.to_response(entry)
def delete(self, query_request: QUERY_SPEC) -> None:
"""Delete an entity object.

View File

@@ -91,18 +91,22 @@ class BasePackage(BaseModel):
raise ValueError("The root is required")
if root not in sys.path:
sys.path.append(root)
with pkg_resources.path(name, "__init__.py") as path:
mods = _load_modules_from_file(str(path), name, show_log=False)
all_cls = [_get_classes_from_module(m) for m in mods]
all_predicate_results = []
for m in mods:
all_predicate_results.extend(_get_from_module(m, predicates))
module_cls = []
for list_cls in all_cls:
for c in list_cls:
if issubclass(c, expected_cls):
module_cls.append(c)
return module_cls, all_predicate_results, mods
try:
with pkg_resources.path(name, "__init__.py") as path:
mods = _load_modules_from_file(str(path), name, show_log=False)
all_cls = [_get_classes_from_module(m) for m in mods]
all_predicate_results = []
for m in mods:
all_predicate_results.extend(_get_from_module(m, predicates))
module_cls = []
for list_cls in all_cls:
for c in list_cls:
if issubclass(c, expected_cls):
module_cls.append(c)
return module_cls, all_predicate_results, mods
except Exception as e:
logger.warning(f"load_module_class error!{str(e)}", e)
raise e
class FlowPackage(BasePackage):
@@ -316,7 +320,11 @@ def _load_package_from_path(path: str):
packages = _load_installed_package(path)
parsed_packages = []
for package in packages:
parsed_packages.append(_parse_package_metadata(package))
try:
parsed_packages.append(_parse_package_metadata(package))
except Exception as e:
logger.warning(f"Load package failed!{str(e)}", e)
return parsed_packages
@@ -405,7 +413,7 @@ class DBGPTsLoader(BaseComponent):
self._packages[package.name] = package
self._register_packages(package)
except Exception as e:
logger.warning(f"Load dbgpts package error: {e}")
logger.warning(f"Load dbgpts package error: {e}", e)
def get_flows(self) -> List[FlowPanel]:
"""Get the flows.