mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
feat:knowledge rag graph
This commit is contained in:
parent
813e2260a6
commit
fc656e1c61
@ -41,7 +41,11 @@ class KnowledgeType(Enum):
|
||||
|
||||
|
||||
def get_knowledge_embedding(
|
||||
knowledge_type, knowledge_source, vector_store_config, source_reader, text_splitter
|
||||
knowledge_type,
|
||||
knowledge_source,
|
||||
vector_store_config=None,
|
||||
source_reader=None,
|
||||
text_splitter=None,
|
||||
):
|
||||
match knowledge_type:
|
||||
case KnowledgeType.DOCUMENT.value:
|
||||
|
@ -31,11 +31,11 @@ class SourceEmbedding(ABC):
|
||||
):
|
||||
"""Initialize with Loader url, model_name, vector_store_config"""
|
||||
self.file_path = file_path
|
||||
self.vector_store_config = vector_store_config
|
||||
self.vector_store_config = vector_store_config or {}
|
||||
self.source_reader = source_reader or None
|
||||
self.text_splitter = text_splitter or None
|
||||
self.embedding_args = embedding_args
|
||||
self.embeddings = vector_store_config["embeddings"]
|
||||
self.embeddings = self.vector_store_config.get("embeddings", None)
|
||||
|
||||
@abstractmethod
|
||||
@register
|
||||
|
@ -27,7 +27,7 @@ class URLEmbedding(SourceEmbedding):
|
||||
file_path, vector_store_config, source_reader=None, text_splitter=None
|
||||
)
|
||||
self.file_path = file_path
|
||||
self.vector_store_config = vector_store_config
|
||||
self.vector_store_config = vector_store_config or None
|
||||
self.source_reader = source_reader or None
|
||||
self.text_splitter = text_splitter or None
|
||||
|
||||
|
@ -75,6 +75,16 @@ class ChatScene(Enum):
|
||||
"Dialogue through natural language and private documents and knowledge bases.",
|
||||
["Knowledge Space Select"],
|
||||
)
|
||||
ExtractTriplet = Scene(
|
||||
"extract_triplet",
|
||||
"Extract Triplet",
|
||||
"Extract Triplet",
|
||||
["Extract Select"],
|
||||
True,
|
||||
)
|
||||
ExtractEntity = Scene(
|
||||
"extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def of_mode(mode):
|
||||
|
@ -205,6 +205,40 @@ class BaseChat(ABC):
|
||||
self.memory.append(self.current_message)
|
||||
return self.current_ai_response()
|
||||
|
||||
async def get_llm_response(self):
|
||||
payload = self.__call_base()
|
||||
logger.info(f"Request: \n{payload}")
|
||||
ai_response_text = ""
|
||||
try:
|
||||
from pilot.model.cluster import WorkerManagerFactory
|
||||
|
||||
worker_manager = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
|
||||
model_output = await worker_manager.generate(payload)
|
||||
|
||||
### output parse
|
||||
ai_response_text = (
|
||||
self.prompt_template.output_parser.parse_model_nostream_resp(
|
||||
model_output, self.prompt_template.sep
|
||||
)
|
||||
)
|
||||
### model result deal
|
||||
self.current_message.add_ai_message(ai_response_text)
|
||||
prompt_define_response = (
|
||||
self.prompt_template.output_parser.parse_prompt_response(
|
||||
ai_response_text
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error("model response parse failed!" + str(e))
|
||||
self.current_message.add_view_message(
|
||||
f"""model response parse failed!{str(e)}\n {ai_response_text} """
|
||||
)
|
||||
return prompt_define_response
|
||||
|
||||
def _blocking_stream_call(self):
|
||||
logger.warn(
|
||||
"_blocking_stream_call is only temporarily used in webserver and will be deleted soon, please use stream_call to replace it for higher performance"
|
||||
|
@ -13,6 +13,8 @@ class ChatFactory(metaclass=Singleton):
|
||||
from pilot.scene.chat_dashboard.chat import ChatDashboard
|
||||
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
|
||||
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
|
||||
from pilot.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
|
||||
from pilot.scene.chat_knowledge.extract_entity.chat import ExtractEntity
|
||||
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
|
||||
|
||||
chat_classes = BaseChat.__subclasses__()
|
||||
|
@ -168,10 +168,12 @@ def get_or_create_event_loop() -> asyncio.BaseEventLoop:
|
||||
assert loop is not None
|
||||
return loop
|
||||
except RuntimeError as e:
|
||||
if not "no running event loop" in str(e):
|
||||
if not "no running event loop" in str(e) and not "no current event loop" in str(
|
||||
e
|
||||
):
|
||||
raise e
|
||||
logging.warning("Cant not get running event loop, create new event loop now")
|
||||
return asyncio.get_event_loop_policy().get_event_loop()
|
||||
return asyncio.get_event_loop_policy().new_event_loop()
|
||||
|
||||
|
||||
def logging_str_to_uvicorn_level(log_level_str):
|
||||
|
Loading…
Reference in New Issue
Block a user