From fa6a9040d5e0670ba491b5574101e5de3e8f046a Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Fri, 13 Oct 2023 14:22:46 +0800
Subject: [PATCH] feat:knowledge rag graph

---
Reviewer notes (this free-form area after the three-dash separator is not
part of the commit message):

* Layout: the line structure of this patch was restored from a
  whitespace-collapsed copy. Line breaks and blank context lines were
  checked against every hunk header count and the diffstat totals;
  indentation was rebuilt from Python syntax -- confirm against the target
  tree before applying.
* chat.py: ChatKnowledge now builds self.rag_engine from the
  RAG_GRAPH_DEFAULT component. Retrieval still goes through
  knowledge_embedding_client.similar_search; the graph/HTTP alternatives
  are added only as commented-out code.
* component_configs.py: adds one blank line, no functional change.
* api.py: adds the POST /knowledge/entity/extract route (entity_extract).
  The commented-out client code in chat.py targets
  /api/knowledge/entities/extract, which does not match this route's path.
* request.py: adds EntityExtractRequest with fields text and model_name.
* service.py: document sync now calls rag_engine.knowledge_graph(docs=chunk_docs)
  right after client.read(). ComponentType is referenced there but no hunk
  adds an import for it in service.py -- confirm it is already imported.
  The added RAGGraphEngine imports are not used by any live code in the
  visible hunks, and embedding_factory is fetched a second time.

 pilot/scene/chat_knowledge/v1/chat.py     | 25 ++++++++++++++++
 pilot/server/component_configs.py         |  1 +
 pilot/server/knowledge/api.py             | 35 +++++++++++++++++++++++
 pilot/server/knowledge/request/request.py |  7 +++++
 pilot/server/knowledge/service.py         | 22 +++++++++++++-
 5 files changed, 89 insertions(+), 1 deletion(-)

diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index 8177a1a5a..ebecddd19 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -1,6 +1,7 @@
 import os
 from typing import Dict
 
+from pilot.component import ComponentType
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
 from pilot.configs.config import Config
@@ -47,6 +48,11 @@ class ChatKnowledge(BaseChat):
             "vector_store_name": self.knowledge_space,
             "vector_store_type": CFG.VECTOR_STORE_TYPE,
         }
+        from pilot.graph_engine.graph_factory import RAGGraphFactory
+
+        self.rag_engine = CFG.SYSTEM_APP.get_component(
+            ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+        ).create()
         embedding_factory = CFG.SYSTEM_APP.get_component(
             "embedding_factory", EmbeddingFactory
         )
@@ -82,6 +88,25 @@ class ChatKnowledge(BaseChat):
         if self.space_context:
             self.prompt_template.template_define = self.space_context["prompt"]["scene"]
             self.prompt_template.template = self.space_context["prompt"]["template"]
+        # docs = self.rag_engine.search(query=self.current_user_input)
+        # import httpx
+        # with httpx.Client() as client:
+        #     request = client.build_request(
+        #         "post",
+        #         "http://127.0.0.1/api/knowledge/entities/extract",
+        #         json="application/json",  # using json for data to ensure it sends as application/json
+        #         params={"text": self.current_user_input},
+        #         headers={},
+        #     )
+        #
+        #     response = client.send(request)
+        #     if response.status_code != 200:
+        #         error_msg = f"request /api/knowledge/entities/extract failed, error: {response.text}"
+        #         raise Exception(error_msg)
+        #     docs = response.json()
+        # import requests
+        # docs = requests.post("http://127.0.0.1:5000/api/knowledge/entities/extract", headers={}, json={"text": self.current_user_input})
+
         docs = self.knowledge_embedding_client.similar_search(
             self.current_user_input, self.top_k
         )
diff --git a/pilot/server/component_configs.py b/pilot/server/component_configs.py
index 7d306ada1..ba5c35ec6 100644
--- a/pilot/server/component_configs.py
+++ b/pilot/server/component_configs.py
@@ -31,6 +31,7 @@ def initialize_components(
 
     # Register global default RAGGraphFactory
     from pilot.graph_engine.graph_factory import DefaultRAGGraphFactory
+
     system_app.register(DefaultRAGGraphFactory)
 
     _initialize_embedding_model(
diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py
index 57fadb21e..e0f31031e 100644
--- a/pilot/server/knowledge/api.py
+++ b/pilot/server/knowledge/api.py
@@ -24,6 +24,7 @@ from pilot.server.knowledge.request.request import (
     ChunkQueryRequest,
     DocumentQueryRequest,
     SpaceArgumentRequest,
+    EntityExtractRequest,
 )
 
 from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
@@ -198,3 +199,37 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
         for d in docs
     ]
     return {"response": res}
+
+
+@router.post("/knowledge/entity/extract")
+async def entity_extract(request: EntityExtractRequest):
+    logger.info(f"Received params: {request}")
+    try:
+        # from pilot.graph_engine.graph_factory import RAGGraphFactory
+        # from pilot.component import ComponentType
+        # rag_engine = CFG.SYSTEM_APP.get_component(
+        #     ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+        # ).create()
+        # return Result.succ(await rag_engine.search(request.text))
+        from pilot.scene.base import ChatScene
+        from pilot.common.chat_util import llm_chat_response_nostream
+        import uuid
+
+        chat_param = {
+            "chat_session_id": uuid.uuid1(),
+            "current_user_input": request.text,
+            "select_param": "entity",
+            "model_name": request.model_name,
+        }
+
+        # import nest_asyncio
+        # nest_asyncio.apply()
+        # loop = asyncio.get_event_loop()
+        # loop.stop()
+        # loop = utils.get_or_create_event_loop()
+        res = await llm_chat_response_nostream(
+            ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
+        )
+        return Result.succ(res)
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"entity extract error {e}")
diff --git a/pilot/server/knowledge/request/request.py b/pilot/server/knowledge/request/request.py
index b83165c19..c6b94ff0d 100644
--- a/pilot/server/knowledge/request/request.py
+++ b/pilot/server/knowledge/request/request.py
@@ -104,3 +104,10 @@ class SpaceArgumentRequest(BaseModel):
     """argument: argument"""
 
     argument: str
+
+
+class EntityExtractRequest(BaseModel):
+    """argument: argument"""
+
+    text: str
+    model_name: str
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index c11fc3b46..f4150fa73 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -58,7 +58,11 @@ class SyncStatus(Enum):
 # @singleton
 class KnowledgeService:
     def __init__(self):
-        pass
+        from pilot.graph_engine.graph_engine import RAGGraphEngine
+
+        # source = "/Users/chenketing/Desktop/project/llama_index/examples/paul_graham_essay/data/test/test_kg_text.txt"
+
+        # pass
 
     """create knowledge space"""
 
@@ -229,6 +233,10 @@ class KnowledgeService:
                 pre_separator=sync_request.pre_separator,
                 text_splitter_impl=text_splitter,
             )
+        from pilot.graph_engine.graph_engine import RAGGraphEngine
+
+        # source = "/Users/chenketing/Desktop/project/llama_index/examples/paul_graham_essay/data/test/test_kg_text.txt"
+        # engine = RAGGraphEngine(knowledge_source=source, model_name="proxyllm", text_splitter=text_splitter)
         embedding_factory = CFG.SYSTEM_APP.get_component(
             "embedding_factory", EmbeddingFactory
         )
@@ -244,6 +252,18 @@ class KnowledgeService:
             embedding_factory=embedding_factory,
         )
         chunk_docs = client.read()
+        from pilot.graph_engine.graph_factory import RAGGraphFactory
+
+        rag_engine = CFG.SYSTEM_APP.get_component(
+            ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+        ).create()
+        rag_engine.knowledge_graph(docs=chunk_docs)
+        # docs = engine.search(
+        #     "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
+        # )
+        embedding_factory = CFG.SYSTEM_APP.get_component(
+            "embedding_factory", EmbeddingFactory
+        )
         # update document status
         doc.status = SyncStatus.RUNNING.name
         doc.chunk_size = len(chunk_docs)