mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 13:58:58 +00:00
feat(feedback): feedback upgrade
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
"""Role class for role-based conversation."""
|
||||
|
||||
from abc import ABC
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from jinja2 import Environment, Template, meta
|
||||
@@ -108,7 +107,6 @@ class Role(ABC, BaseModel):
|
||||
"retry_constraints": self.retry_constraints,
|
||||
"examples": self.examples,
|
||||
"is_retry_chat": is_retry_chat,
|
||||
"now_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
param = role_params.copy()
|
||||
runtime_param_names = []
|
||||
|
@@ -117,6 +117,8 @@ class SpaceQueryResponse(BaseModel):
|
||||
name: Optional[str] = None
|
||||
"""vector_type: vector type"""
|
||||
vector_type: Optional[str] = None
|
||||
"""domain_type"""
|
||||
domain_type: Optional[str] = None
|
||||
"""desc: description"""
|
||||
desc: Optional[str] = None
|
||||
"""context: context"""
|
||||
|
@@ -123,7 +123,10 @@ class KnowledgeService:
|
||||
- request: KnowledgeSpaceRequest
|
||||
"""
|
||||
query = KnowledgeSpaceEntity(
|
||||
name=request.name, vector_type=request.vector_type, owner=request.owner
|
||||
id=request.id,
|
||||
name=request.name,
|
||||
vector_type=request.vector_type,
|
||||
owner=request.owner,
|
||||
)
|
||||
spaces = knowledge_space_dao.get_knowledge_space(query)
|
||||
space_names = [space.name for space in spaces]
|
||||
@@ -295,44 +298,6 @@ class KnowledgeService:
|
||||
"""
|
||||
return knowledge_space_dao.get_knowledge_space_by_ids(ids)
|
||||
|
||||
def get_knowledge_space(self, request: KnowledgeSpaceRequest):
|
||||
"""get knowledge space
|
||||
Args:
|
||||
- request: KnowledgeSpaceRequest
|
||||
"""
|
||||
query = KnowledgeSpaceEntity(
|
||||
name=request.name,
|
||||
vector_type=request.vector_type,
|
||||
owner=request.owner,
|
||||
)
|
||||
spaces = knowledge_space_dao.get_knowledge_space(query)
|
||||
space_names = [space.name for space in spaces]
|
||||
docs_count = knowledge_document_dao.get_knowledge_documents_count_bulk_by_ids(
|
||||
space_names
|
||||
)
|
||||
responses = []
|
||||
for space in spaces:
|
||||
res = SpaceQueryResponse()
|
||||
res.id = space.id
|
||||
res.name = space.name
|
||||
res.vector_type = space.vector_type
|
||||
res.desc = space.desc
|
||||
res.owner = space.owner
|
||||
res.gmt_created = (
|
||||
space.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if space.gmt_created
|
||||
else None
|
||||
)
|
||||
res.gmt_modified = (
|
||||
space.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if space.gmt_modified
|
||||
else None
|
||||
)
|
||||
res.context = space.context
|
||||
res.docs = docs_count.get(space.name, 0)
|
||||
responses.append(res)
|
||||
return responses
|
||||
|
||||
def update_knowledge_space(
|
||||
self, space_id: int, space_request: KnowledgeSpaceRequest
|
||||
):
|
||||
|
@@ -466,9 +466,13 @@ async def chat_completions(
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
try:
|
||||
domain_type = _parse_domain_type(dialogue)
|
||||
if dialogue.chat_mode == ChatScene.ChatAgent.value():
|
||||
from dbgpt.serve.agent.agents.controller import multi_agents
|
||||
|
||||
dialogue.ext_info.update({"model_name": dialogue.model_name})
|
||||
dialogue.ext_info.update({"incremental": dialogue.incremental})
|
||||
dialogue.ext_info.update({"temperature": dialogue.temperature})
|
||||
return StreamingResponse(
|
||||
multi_agents.app_agent_chat(
|
||||
conv_uid=dialogue.conv_uid,
|
||||
@@ -503,6 +507,13 @@ async def chat_completions(
|
||||
headers=headers,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
elif domain_type is not None and domain_type != "Normal":
|
||||
return StreamingResponse(
|
||||
chat_with_domain_flow(dialogue, domain_type),
|
||||
headers=headers,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
else:
|
||||
with root_tracer.start_span(
|
||||
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
|
||||
@@ -672,9 +683,7 @@ def _parse_domain_type(dialogue: ConversationVo) -> Optional[str]:
|
||||
KnowledgeSpaceRequest(name=space_name)
|
||||
)
|
||||
if len(spaces) == 0:
|
||||
return Result.failed(
|
||||
code="E000X", msg=f"Knowledge space {space_name} not found"
|
||||
)
|
||||
raise ValueError(f"Knowledge space {space_name} not found")
|
||||
if spaces[0].domain_type:
|
||||
return spaces[0].domain_type
|
||||
else:
|
||||
|
@@ -135,6 +135,7 @@ _OPERATOR_CATEGORY_DETAIL = {
|
||||
"common": _CategoryDetail("Common", "The common operator"),
|
||||
"agent": _CategoryDetail("Agent", "The agent operator"),
|
||||
"rag": _CategoryDetail("RAG", "The RAG operator"),
|
||||
"experimental": _CategoryDetail("EXPERIMENTAL", "EXPERIMENTAL operator"),
|
||||
}
|
||||
|
||||
|
||||
|
@@ -40,6 +40,7 @@ from dbgpt.serve.conversation.serve import Serve as ConversationServe
|
||||
from dbgpt.serve.prompt.api.endpoints import get_service
|
||||
from dbgpt.serve.prompt.service import service as PromptService
|
||||
from dbgpt.util.json_utils import serialize
|
||||
from dbgpt.util.tracer import TracerManager
|
||||
|
||||
from ..db import GptsMessagesDao
|
||||
from ..db.gpts_app import GptsApp, GptsAppDao, GptsAppQuery
|
||||
@@ -51,6 +52,7 @@ CFG = Config()
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
root_tracer: TracerManager = TracerManager()
|
||||
|
||||
|
||||
def _build_conversation(
|
||||
@@ -219,84 +221,120 @@ class MultiAgents(BaseComponent, ABC):
|
||||
)
|
||||
)
|
||||
|
||||
# init gpts memory
|
||||
self.memory.init(
|
||||
agent_conv_id,
|
||||
enable_vis_message=enable_verbose,
|
||||
history_messages=history_messages,
|
||||
start_round=history_message_count,
|
||||
)
|
||||
# init agent memory
|
||||
agent_memory = self.get_or_build_agent_memory(conv_id, gpts_name)
|
||||
is_agent_flow = True
|
||||
if TeamMode.AWEL_LAYOUT.value == gpt_app.team_mode:
|
||||
from dbgpt.agent import AWELTeamContext
|
||||
|
||||
try:
|
||||
task = asyncio.create_task(
|
||||
multi_agents.agent_team_chat_new(
|
||||
user_query,
|
||||
agent_conv_id,
|
||||
gpt_app,
|
||||
agent_memory,
|
||||
is_retry_chat,
|
||||
last_speaker_name=last_speaker_name,
|
||||
init_message_rounds=message_round,
|
||||
**ext_info,
|
||||
)
|
||||
team_context: AWELTeamContext = gpt_app.team_context
|
||||
if team_context.flow_category == "chat_flow":
|
||||
is_agent_flow = False
|
||||
if not is_agent_flow:
|
||||
from dbgpt.core.awel import CommonLLMHttpRequestBody
|
||||
|
||||
flow_req = CommonLLMHttpRequestBody(
|
||||
model=ext_info.get("model_name", None),
|
||||
messages=user_query,
|
||||
stream=True,
|
||||
# context=flow_ctx,
|
||||
# temperature=
|
||||
# max_new_tokens=
|
||||
# enable_vis=
|
||||
conv_uid=agent_conv_id,
|
||||
span_id=root_tracer.get_current_span_id(),
|
||||
chat_mode=ext_info.get("chat_mode", None),
|
||||
chat_param=team_context.uid,
|
||||
user_name=user_code,
|
||||
sys_code=sys_code,
|
||||
incremental=ext_info.get("incremental", True),
|
||||
)
|
||||
if enable_verbose:
|
||||
async for chunk in multi_agents.chat_messages(agent_conv_id):
|
||||
if chunk:
|
||||
try:
|
||||
chunk = json.dumps(
|
||||
{"vis": chunk}, default=serialize, ensure_ascii=False
|
||||
)
|
||||
if chunk is None or len(chunk) <= 0:
|
||||
continue
|
||||
resp = f"data:{chunk}\n\n"
|
||||
yield task, resp, agent_conv_id
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"get messages {gpts_name} Exception!" + str(e)
|
||||
)
|
||||
yield f"data: {str(e)}\n\n"
|
||||
from dbgpt.app.openapi.api_v1.api_v1 import get_chat_flow
|
||||
|
||||
yield task, f'data:{json.dumps({"vis": "[DONE]"}, default=serialize, ensure_ascii=False)} \n\n', agent_conv_id
|
||||
flow_service = get_chat_flow()
|
||||
async for chunk in flow_service.chat_stream_flow_str(
|
||||
team_context.uid, flow_req
|
||||
):
|
||||
yield None, chunk, agent_conv_id
|
||||
else:
|
||||
# init gpts memory
|
||||
self.memory.init(
|
||||
agent_conv_id,
|
||||
enable_vis_message=enable_verbose,
|
||||
history_messages=history_messages,
|
||||
start_round=history_message_count,
|
||||
)
|
||||
# init agent memory
|
||||
agent_memory = self.get_or_build_agent_memory(conv_id, gpts_name)
|
||||
|
||||
else:
|
||||
logger.info(f"{agent_conv_id}开启简略消息模式,不进行vis协议封装,获取极简流式消息直接输出")
|
||||
# 开启简略消息模式,不进行vis协议封装,获取极简流式消息直接输出
|
||||
final_message_chunk = None
|
||||
async for chunk in multi_agents.chat_messages(agent_conv_id):
|
||||
if chunk:
|
||||
try:
|
||||
if chunk is None or len(chunk) <= 0:
|
||||
continue
|
||||
final_message_chunk = chunk[-1]
|
||||
if stream:
|
||||
yield task, final_message_chunk, agent_conv_id
|
||||
logger.info(
|
||||
f"agent_chat_v2 executing, timestamp={int(time.time() * 1000)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"get messages {gpts_name} Exception!" + str(e)
|
||||
)
|
||||
final_message_chunk = str(e)
|
||||
|
||||
logger.info(
|
||||
f"agent_chat_v2 finish, timestamp={int(time.time() * 1000)}"
|
||||
try:
|
||||
task = asyncio.create_task(
|
||||
multi_agents.agent_team_chat_new(
|
||||
user_query,
|
||||
agent_conv_id,
|
||||
gpt_app,
|
||||
agent_memory,
|
||||
is_retry_chat,
|
||||
last_speaker_name=last_speaker_name,
|
||||
init_message_rounds=message_round,
|
||||
**ext_info,
|
||||
)
|
||||
)
|
||||
yield task, final_message_chunk, agent_conv_id
|
||||
if enable_verbose:
|
||||
async for chunk in multi_agents.chat_messages(agent_conv_id):
|
||||
if chunk:
|
||||
try:
|
||||
chunk = json.dumps(
|
||||
{"vis": chunk},
|
||||
default=serialize,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
if chunk is None or len(chunk) <= 0:
|
||||
continue
|
||||
resp = f"data:{chunk}\n\n"
|
||||
yield task, resp, agent_conv_id
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"get messages {gpts_name} Exception!" + str(e)
|
||||
)
|
||||
yield f"data: {str(e)}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Agent chat have error!{str(e)}")
|
||||
if enable_verbose:
|
||||
yield task, f'data:{json.dumps({"vis": f"{str(e)}"}, default=serialize, ensure_ascii=False)} \n\n', agent_conv_id
|
||||
yield task, f'data:{json.dumps({"vis": "[DONE]"}, default=serialize, ensure_ascii=False)} \n\n', agent_conv_id
|
||||
else:
|
||||
yield task, str(e), agent_conv_id
|
||||
yield task, f'data:{json.dumps({"vis": "[DONE]"}, default=serialize, ensure_ascii=False)} \n\n', agent_conv_id
|
||||
|
||||
finally:
|
||||
self.memory.clear(agent_conv_id)
|
||||
else:
|
||||
logger.info(f"{agent_conv_id}开启简略消息模式,不进行vis协议封装,获取极简流式消息直接输出")
|
||||
# 开启简略消息模式,不进行vis协议封装,获取极简流式消息直接输出
|
||||
final_message_chunk = None
|
||||
async for chunk in multi_agents.chat_messages(agent_conv_id):
|
||||
if chunk:
|
||||
try:
|
||||
if chunk is None or len(chunk) <= 0:
|
||||
continue
|
||||
final_message_chunk = chunk[-1]
|
||||
if stream:
|
||||
yield task, final_message_chunk, agent_conv_id
|
||||
logger.info(
|
||||
f"agent_chat_v2 executing, timestamp={int(time.time() * 1000)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"get messages {gpts_name} Exception!" + str(e)
|
||||
)
|
||||
final_message_chunk = str(e)
|
||||
|
||||
logger.info(
|
||||
f"agent_chat_v2 finish, timestamp={int(time.time() * 1000)}"
|
||||
)
|
||||
yield task, final_message_chunk, agent_conv_id
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Agent chat have error!{str(e)}")
|
||||
if enable_verbose:
|
||||
yield task, f'data:{json.dumps({"vis": f"{str(e)}"}, default=serialize, ensure_ascii=False)} \n\n', agent_conv_id
|
||||
yield task, f'data:{json.dumps({"vis": "[DONE]"}, default=serialize, ensure_ascii=False)} \n\n', agent_conv_id
|
||||
else:
|
||||
yield task, str(e), agent_conv_id
|
||||
|
||||
finally:
|
||||
self.memory.clear(agent_conv_id)
|
||||
|
||||
async def app_agent_chat(
|
||||
self,
|
||||
|
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from functools import cache
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -19,6 +20,8 @@ from .schemas import (
|
||||
ServerResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Add your API endpoints here
|
||||
@@ -120,7 +123,12 @@ async def update(
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.update(request))
|
||||
try:
|
||||
data = service.update(request)
|
||||
return Result.succ(data)
|
||||
except Exception as e:
|
||||
logger.exception("Update prompt failed!")
|
||||
return Result.failed(msg=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
|
@@ -102,7 +102,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
sub_chat_scene=entity.sub_chat_scene,
|
||||
prompt_type=entity.prompt_type,
|
||||
prompt_name=entity.prompt_name,
|
||||
promt_code=entity.prompt_code,
|
||||
prompt_code=entity.prompt_code,
|
||||
content=entity.content,
|
||||
prompt_desc=entity.prompt_desc,
|
||||
user_code=entity.user_code,
|
||||
|
@@ -32,6 +32,7 @@ class KnowledgeSpaceDao(BaseDao):
|
||||
knowledge_space = KnowledgeSpaceEntity(
|
||||
name=space.name,
|
||||
vector_type=space.vector_type,
|
||||
domain_type=space.domain_type,
|
||||
desc=space.desc,
|
||||
owner=space.owner,
|
||||
gmt_created=datetime.now(),
|
||||
@@ -58,6 +59,7 @@ class KnowledgeSpaceDao(BaseDao):
|
||||
return knowledge_spaces_list
|
||||
|
||||
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
|
||||
"""Get knowledge space by query"""
|
||||
session = self.get_raw_session()
|
||||
knowledge_spaces = session.query(KnowledgeSpaceEntity)
|
||||
if query.id is not None:
|
||||
@@ -72,6 +74,10 @@ class KnowledgeSpaceDao(BaseDao):
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.vector_type == query.vector_type
|
||||
)
|
||||
if query.domain_type is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.domain_type == query.domain_type
|
||||
)
|
||||
if query.desc is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.desc == query.desc
|
||||
@@ -88,7 +94,6 @@ class KnowledgeSpaceDao(BaseDao):
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.gmt_modified == query.gmt_modified
|
||||
)
|
||||
|
||||
knowledge_spaces = knowledge_spaces.order_by(
|
||||
KnowledgeSpaceEntity.gmt_created.desc()
|
||||
)
|
||||
@@ -134,7 +139,9 @@ class KnowledgeSpaceDao(BaseDao):
|
||||
T: The entity
|
||||
"""
|
||||
request_dict = (
|
||||
request.dict() if isinstance(request, SpaceServeRequest) else request
|
||||
model_to_dict(request)
|
||||
if isinstance(request, SpaceServeRequest)
|
||||
else request
|
||||
)
|
||||
entity = KnowledgeSpaceEntity(**request_dict)
|
||||
return entity
|
||||
@@ -173,4 +180,5 @@ class KnowledgeSpaceDao(BaseDao):
|
||||
desc=entity.desc,
|
||||
owner=entity.owner,
|
||||
context=entity.context,
|
||||
domain_type=entity.domain_type,
|
||||
)
|
||||
|
@@ -171,10 +171,10 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
if value is not None:
|
||||
setattr(entry, key, value)
|
||||
session.merge(entry)
|
||||
res = self.get_one(self.to_request(entry))
|
||||
if not res:
|
||||
raise Exception("Update failed")
|
||||
return res
|
||||
# res = self.get_one(self.to_request(entry))
|
||||
# if not res:
|
||||
# raise Exception("Update failed")
|
||||
return self.to_response(entry)
|
||||
|
||||
def delete(self, query_request: QUERY_SPEC) -> None:
|
||||
"""Delete an entity object.
|
||||
|
@@ -91,18 +91,22 @@ class BasePackage(BaseModel):
|
||||
raise ValueError("The root is required")
|
||||
if root not in sys.path:
|
||||
sys.path.append(root)
|
||||
with pkg_resources.path(name, "__init__.py") as path:
|
||||
mods = _load_modules_from_file(str(path), name, show_log=False)
|
||||
all_cls = [_get_classes_from_module(m) for m in mods]
|
||||
all_predicate_results = []
|
||||
for m in mods:
|
||||
all_predicate_results.extend(_get_from_module(m, predicates))
|
||||
module_cls = []
|
||||
for list_cls in all_cls:
|
||||
for c in list_cls:
|
||||
if issubclass(c, expected_cls):
|
||||
module_cls.append(c)
|
||||
return module_cls, all_predicate_results, mods
|
||||
try:
|
||||
with pkg_resources.path(name, "__init__.py") as path:
|
||||
mods = _load_modules_from_file(str(path), name, show_log=False)
|
||||
all_cls = [_get_classes_from_module(m) for m in mods]
|
||||
all_predicate_results = []
|
||||
for m in mods:
|
||||
all_predicate_results.extend(_get_from_module(m, predicates))
|
||||
module_cls = []
|
||||
for list_cls in all_cls:
|
||||
for c in list_cls:
|
||||
if issubclass(c, expected_cls):
|
||||
module_cls.append(c)
|
||||
return module_cls, all_predicate_results, mods
|
||||
except Exception as e:
|
||||
logger.warning(f"load_module_class error!{str(e)}", e)
|
||||
raise e
|
||||
|
||||
|
||||
class FlowPackage(BasePackage):
|
||||
@@ -316,7 +320,11 @@ def _load_package_from_path(path: str):
|
||||
packages = _load_installed_package(path)
|
||||
parsed_packages = []
|
||||
for package in packages:
|
||||
parsed_packages.append(_parse_package_metadata(package))
|
||||
try:
|
||||
parsed_packages.append(_parse_package_metadata(package))
|
||||
except Exception as e:
|
||||
logger.warning(f"Load package failed!{str(e)}", e)
|
||||
|
||||
return parsed_packages
|
||||
|
||||
|
||||
@@ -405,7 +413,7 @@ class DBGPTsLoader(BaseComponent):
|
||||
self._packages[package.name] = package
|
||||
self._register_packages(package)
|
||||
except Exception as e:
|
||||
logger.warning(f"Load dbgpts package error: {e}")
|
||||
logger.warning(f"Load dbgpts package error: {e}", e)
|
||||
|
||||
def get_flows(self) -> List[FlowPanel]:
|
||||
"""Get the flows.
|
||||
|
Reference in New Issue
Block a user