feat: (0.6)New UI (#1855)

Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com>
Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
明天
2024-08-21 17:37:45 +08:00
committed by GitHub
parent 3fc82693ba
commit b124ecc10b
824 changed files with 93371 additions and 2515 deletions

View File

@@ -21,6 +21,7 @@ from dbgpt.core import (
)
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
from dbgpt.util.tracer import root_tracer, trace
CFG = Config()
@@ -39,10 +40,7 @@ class ChatKnowledge(BaseChat):
- model_name:(str) llm model name
- select_param:(str) space name
"""
from dbgpt.rag.embedding.embedding_factory import (
EmbeddingFactory,
RerankEmbeddingFactory,
)
from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory
self.knowledge_space = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatKnowledge
@@ -65,20 +63,10 @@ class ChatKnowledge(BaseChat):
if self.space_context is None or self.space_context.get("prompt") is None
else int(self.space_context["prompt"]["max_token"])
)
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.serve.rag.connector import VectorStoreConnector
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
from dbgpt.serve.rag.models.models import (
KnowledgeSpaceDao,
KnowledgeSpaceEntity,
)
from dbgpt.storage.vector_store.base import VectorStoreConfig
spaces = KnowledgeSpaceDao().get_knowledge_space(
KnowledgeSpaceEntity(name=self.knowledge_space)
@@ -87,15 +75,6 @@ class ChatKnowledge(BaseChat):
raise Exception(f"invalid space name:{self.knowledge_space}")
space = spaces[0]
config = VectorStoreConfig(
name=space.name,
embedding_fn=embedding_fn,
llm_client=self.llm_client,
llm_model=self.llm_model,
)
vector_store_connector = VectorStoreConnector(
vector_store_type=space.vector_type, vector_store_config=config
)
query_rewrite = None
if CFG.KNOWLEDGE_SEARCH_REWRITE:
query_rewrite = QueryRewrite(
@@ -114,18 +93,19 @@ class ChatKnowledge(BaseChat):
# We use reranker, so if the top_k is less than 20,
# we need to set it to 20
retriever_top_k = max(CFG.RERANK_TOP_K, 20)
self.embedding_retriever = EmbeddingRetriever(
self._space_retriever = KnowledgeSpaceRetriever(
space_id=space.id,
top_k=retriever_top_k,
index_store=vector_store_connector.index_client,
query_rewrite=query_rewrite,
rerank=reranker,
)
self.prompt_template.template_is_strict = False
self.relations = None
self.chunk_dao = DocumentChunkDao()
document_dao = KnowledgeDocumentDao()
documents = document_dao.get_documents(
query=KnowledgeDocumentEntity(space=self.knowledge_space)
query=KnowledgeDocumentEntity(space=space.name)
)
if len(documents) > 0:
self.document_ids = [document.id for document in documents]
@@ -251,6 +231,10 @@ class ChatKnowledge(BaseChat):
def chat_type(self) -> str:
return ChatScene.ChatKnowledge.value()
def get_space_context_by_id(self, space_id):
service = KnowledgeService()
return service.get_space_context_by_space_id(space_id)
def get_space_context(self, space_name):
service = KnowledgeService()
return service.get_space_context(space_name)
@@ -272,6 +256,6 @@ class ChatKnowledge(BaseChat):
with root_tracer.start_span(
"execute_similar_search", metadata={"query": query}
):
return await self.embedding_retriever.aretrieve_with_scores(
return await self._space_retriever.aretrieve_with_scores(
query, self.recall_score
)