mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-28 21:12:13 +00:00
refactor:adapt rag storage and add integration documents. (#2361)
@@ -7,12 +7,16 @@ from dbgpt.model.parameter import (
     ModelServiceConfig,
 )
 from dbgpt.storage.cache.manager import ModelCacheParameters
+from dbgpt.storage.vector_store.base import VectorStoreConfig
 from dbgpt.util.configure import HookConfig
 from dbgpt.util.i18n_utils import _
 from dbgpt.util.parameter_utils import BaseParameters
 from dbgpt.util.tracer import TracerParameters
 from dbgpt.util.utils import LoggingParameters
 from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteConnectorParameters
+from dbgpt_ext.storage.knowledge_graph.knowledge_graph import (
+    BuiltinKnowledgeGraphConfig,
+)
 from dbgpt_serve.core import BaseServeConfig
@@ -68,14 +72,20 @@ class StorageGraphConfig(BaseParameters):
 @dataclass
 class StorageConfig(BaseParameters):
-    vector: StorageVectorConfig = field(
-        default_factory=StorageVectorConfig,
+    vector: VectorStoreConfig = field(
+        default_factory=VectorStoreConfig,
         metadata={
             "help": _("default vector type"),
         },
     )
-    graph: StorageGraphConfig = field(
-        default_factory=StorageGraphConfig,
+    graph: BuiltinKnowledgeGraphConfig = field(
+        default_factory=BuiltinKnowledgeGraphConfig,
         metadata={
             "help": _("default graph type"),
         },
     )
+    full_text: BuiltinKnowledgeGraphConfig = field(
+        default_factory=BuiltinKnowledgeGraphConfig,
+        metadata={
+            "help": _("default graph type"),
+        },
+    )
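
This hunk swaps the serve-local StorageVectorConfig/StorageGraphConfig types for the shared VectorStoreConfig and BuiltinKnowledgeGraphConfig, and adds a full_text slot. Below is a minimal, self-contained sketch of the resulting dataclass shape; the stub classes and their default "type" values are illustrative stand-ins, not the real DB-GPT config types.

from dataclasses import dataclass, field

@dataclass
class VectorStoreConfigStub:
    # stand-in for dbgpt.storage.vector_store.base.VectorStoreConfig
    type: str = "Chroma"

@dataclass
class KnowledgeGraphConfigStub:
    # stand-in for BuiltinKnowledgeGraphConfig
    type: str = "TuGraph"

@dataclass
class StorageConfigSketch:
    # mirrors the refactored StorageConfig: one default instance per storage backend
    vector: VectorStoreConfigStub = field(default_factory=VectorStoreConfigStub)
    graph: KnowledgeGraphConfigStub = field(default_factory=KnowledgeGraphConfigStub)
    full_text: KnowledgeGraphConfigStub = field(default_factory=KnowledgeGraphConfigStub)

print(StorageConfigSketch())  # every backend gets its own default config instance

Using default_factory keeps each field independent, so mutating one StorageConfig instance cannot leak state into another.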
@@ -77,6 +77,11 @@ class KnowledgeService:
         ).create()
         return DefaultLLMClient(worker_manager, True)

+    @property
+    def rag_config(self):
+        rag_config = CFG.SYSTEM_APP.config.configs.get("app_config").rag
+        return rag_config
+
     def create_knowledge_space(self, request: KnowledgeSpaceRequest):
         """create knowledge space
         Args:
@@ -86,7 +91,7 @@ class KnowledgeService:
             name=request.name,
         )
         if request.vector_type == "VectorStore":
-            request.vector_type = CFG.VECTOR_STORE_TYPE
+            request.vector_type = self.rag_config.storage.vector.get("type")
         if request.vector_type == "KnowledgeGraph":
             knowledge_space_name_pattern = r"^[a-zA-Z0-9\u4e00-\u9fa5]+$"
             if not re.match(knowledge_space_name_pattern, request.name):
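
With the new rag_config property, the default vector store type now comes from the application's RAG storage config instead of the global CFG.VECTOR_STORE_TYPE. A hedged sketch of that resolution follows, using a plain dict in place of the real config object; the helper name and the "Chroma" value are illustrative assumptions.

# Illustrative stand-in for rag_config.storage.vector; the real object is read
# with .get("type") in the hunk above, which a mapping models well enough here.
rag_storage_vector = {"type": "Chroma"}

def resolve_vector_type(requested: str) -> str:
    """Mimics create_knowledge_space: map the generic 'VectorStore' request
    onto the concrete type configured for RAG storage."""
    if requested == "VectorStore":
        return rag_storage_vector.get("type", "Chroma")
    return requested

print(resolve_vector_type("VectorStore"))     # resolved from config
print(resolve_vector_type("KnowledgeGraph"))  # unchanged; the graph path keeps its own checks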
@@ -9,7 +9,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from starlette.responses import JSONResponse, StreamingResponse

 from dbgpt._private.pydantic import model_to_dict, model_to_json
-from dbgpt.component import logger
+from dbgpt.component import SystemApp, logger
 from dbgpt.core.awel import CommonLLMHttpRequestBody
 from dbgpt.core.schema.api import (
     ChatCompletionResponse,
@@ -72,6 +72,7 @@ async def check_api_key(
 @router.post("/v2/chat/completions", dependencies=[Depends(check_api_key)])
 async def chat_completions(
     request: ChatCompletionRequestBody = Body(),
+    service=Depends(get_service),
 ):
     """Chat V2 completions
     Args:
@@ -133,7 +134,7 @@ async def chat_completions(
         span_type=SpanType.CHAT,
         metadata=model_to_dict(request),
     ):
-        chat: BaseChat = await get_chat_instance(request)
+        chat: BaseChat = await get_chat_instance(request, service.system_app)

         if not request.stream:
             return await no_stream_wrapper(request, chat)
@@ -158,11 +159,14 @@ async def chat_completions(
     )


-async def get_chat_instance(dialogue: ChatCompletionRequestBody = Body()) -> BaseChat:
+async def get_chat_instance(
+    dialogue: ChatCompletionRequestBody = Body(), system_app: SystemApp = None
+) -> BaseChat:
     """
     Get chat instance
     Args:
         dialogue (OpenAPIChatCompletionRequest): The chat request.
+        system_app (SystemApp): system app.
     """
     logger.info(f"get_chat_instance:{dialogue}")
     if not dialogue.chat_mode:
@@ -191,6 +195,7 @@ async def get_chat_instance(dialogue: ChatCompletionRequestBody = Body()) -> Bas
         get_executor(),
         CHAT_FACTORY.get_implementation,
         dialogue.chat_mode,
+        system_app,
         **{"chat_param": chat_param},
     )
     return chat
@@ -42,11 +42,20 @@ class ChatKnowledge(BaseChat):
         self.knowledge_space = chat_param["select_param"]
         chat_param["chat_mode"] = ChatScene.ChatKnowledge
         super().__init__(chat_param=chat_param, system_app=system_app)
+        from dbgpt_serve.rag.models.models import (
+            KnowledgeSpaceDao,
+        )
+
+        space_dao = KnowledgeSpaceDao()
+        space = space_dao.get_one({"name": self.knowledge_space})
+        if not space:
+            space = space_dao.get_one({"id": self.knowledge_space})
+        if not space:
+            raise Exception(f"have not found knowledge space:{self.knowledge_space}")
         self.rag_config = self.app_config.rag
-        self.space_context = self.get_space_context(self.knowledge_space)
+        self.space_context = self.get_space_context(space.name)
         self.top_k = (
-            self.get_knowledge_search_top_size(self.knowledge_space)
+            self.get_knowledge_search_top_size(space.name)
             if self.space_context is None
             else int(self.space_context["embedding"]["topk"])
         )
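
ChatKnowledge now resolves the selected space through KnowledgeSpaceDao.get_one, first by name and then by id, and fails fast when neither matches; everything downstream works with space.name. Here is a self-contained sketch of that fallback lookup, with an in-memory DAO assumed purely for illustration.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Space:
    id: str
    name: str

class InMemorySpaceDao:
    """Toy stand-in for KnowledgeSpaceDao.get_one({...})."""
    def __init__(self, spaces):
        self._spaces = spaces

    def get_one(self, query: dict) -> Optional[Space]:
        # return the first space whose attributes match every key in the query
        return next(
            (s for s in self._spaces
             if all(getattr(s, k, None) == v for k, v in query.items())),
            None,
        )

def resolve_space(dao: InMemorySpaceDao, selector: str) -> Space:
    # Same fallback order as the hunk above: try name first, then id.
    space = dao.get_one({"name": selector}) or dao.get_one({"id": selector})
    if not space:
        raise Exception(f"have not found knowledge space:{selector}")
    return space

dao = InMemorySpaceDao([Space(id="1", name="docs")])
print(resolve_space(dao, "docs").name)  # matched by name
print(resolve_space(dao, "1").name)     # falls back to id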
@@ -55,17 +64,6 @@ class ChatKnowledge(BaseChat):
             if self.space_context is None
             else float(self.space_context["embedding"]["recall_score"])
         )
-        from dbgpt_serve.rag.models.models import (
-            KnowledgeSpaceDao,
-            KnowledgeSpaceEntity,
-        )
-
-        spaces = KnowledgeSpaceDao().get_knowledge_space(
-            KnowledgeSpaceEntity(name=self.knowledge_space)
-        )
-        if len(spaces) != 1:
-            raise Exception(f"invalid space name:{self.knowledge_space}")
-        space = spaces[0]

         query_rewrite = None
         if self.rag_config.query_rewrite:
@@ -230,9 +228,9 @@ class ChatKnowledge(BaseChat):
         request = KnowledgeSpaceRequest(name=space_name)
         spaces = service.get_knowledge_space(request)
         if len(spaces) == 1:
-            from dbgpt_ext.storage import vector_store
+            from dbgpt_ext.storage import __knowledge_graph__ as graph_storages

-            if spaces[0].vector_type in vector_store.__knowledge_graph__:
+            if spaces[0].vector_type in graph_storages:
                 return self.rag_config.graph_search_top_k

         return self.rag_config.similarity_top_k
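
get_knowledge_search_top_size now checks the space's vector_type against the __knowledge_graph__ registry exported by dbgpt_ext.storage instead of reaching into the vector_store module, and uses graph_search_top_k only for graph-backed spaces. A minimal sketch of that selection follows; the registry contents and default values are assumptions for illustration.

# Assumed registry contents; the real list is dbgpt_ext.storage.__knowledge_graph__.
graph_storages = ["KnowledgeGraph", "CommunitySummaryKnowledgeGraph"]

def search_top_size(vector_type: str, graph_search_top_k: int = 5,
                    similarity_top_k: int = 10) -> int:
    """Graph-backed spaces use the graph top_k, everything else the vector top_k."""
    if vector_type in graph_storages:
        return graph_search_top_k
    return similarity_top_k

print(search_top_size("KnowledgeGraph"))  # graph-backed space
print(search_top_size("Chroma"))          # plain vector store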