feat: knowledge rag graph

This commit is contained in:
aries_ckt 2023-10-13 14:22:46 +08:00
parent 8a5e35c5f2
commit fa6a9040d5
5 changed files with 89 additions and 1 deletions

View File

@ -1,6 +1,7 @@
import os
from typing import Dict
from pilot.component import ComponentType
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
@ -47,6 +48,11 @@ class ChatKnowledge(BaseChat):
"vector_store_name": self.knowledge_space,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
}
from pilot.graph_engine.graph_factory import RAGGraphFactory
self.rag_engine = CFG.SYSTEM_APP.get_component(
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
).create()
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
@ -82,6 +88,25 @@ class ChatKnowledge(BaseChat):
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
self.prompt_template.template = self.space_context["prompt"]["template"]
# docs = self.rag_engine.search(query=self.current_user_input)
# import httpx
# with httpx.Client() as client:
# request = client.build_request(
# "post",
# "http://127.0.0.1/api/knowledge/entities/extract",
# json="application/json", # using json for data to ensure it sends as application/json
# params={"text": self.current_user_input},
# headers={},
# )
#
# response = client.send(request)
# if response.status_code != 200:
# error_msg = f"request /api/knowledge/entities/extract failed, error: {response.text}"
# raise Exception(error_msg)
# docs = response.json()
# import requests
# docs = requests.post("http://127.0.0.1:5000/api/knowledge/entities/extract", headers={}, json={"text": self.current_user_input})
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, self.top_k
)

View File

@ -31,6 +31,7 @@ def initialize_components(
# Register global default RAGGraphFactory
from pilot.graph_engine.graph_factory import DefaultRAGGraphFactory
system_app.register(DefaultRAGGraphFactory)
_initialize_embedding_model(

View File

@ -24,6 +24,7 @@ from pilot.server.knowledge.request.request import (
ChunkQueryRequest,
DocumentQueryRequest,
SpaceArgumentRequest,
EntityExtractRequest,
)
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
@ -198,3 +199,37 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
for d in docs
]
return {"response": res}
@router.post("/knowledge/entity/extract")
async def entity_extract(request: EntityExtractRequest):
    """Extract entities from free text via a no-stream LLM chat.

    Args:
        request: payload carrying the input ``text`` to analyze and the
            ``model_name`` of the LLM to run the extraction with.

    Returns:
        ``Result.succ`` wrapping the extraction response on success, or a
        failed ``Result`` with code "E000X" if anything raises.
    """
    logger.info(f"Received params: {request}")
    try:
        # Imported lazily so the router module does not pay for these
        # dependencies unless this endpoint is actually hit.
        from pilot.scene.base import ChatScene
        from pilot.common.chat_util import llm_chat_response_nostream
        import uuid

        # NOTE(review): uuid1() is passed as a UUID object, not a str —
        # presumably the chat layer stringifies it; confirm downstream.
        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": request.text,
            "select_param": "entity",
            "model_name": request.model_name,
        }
        res = await llm_chat_response_nostream(
            ChatScene.ExtractEntity.value(), chat_param=chat_param
        )
        return Result.succ(res)
    except Exception as e:
        return Result.faild(code="E000X", msg=f"entity extract error {e}")

View File

@ -104,3 +104,10 @@ class SpaceArgumentRequest(BaseModel):
"""argument: argument"""
argument: str
class EntityExtractRequest(BaseModel):
    """Request payload for the entity-extraction endpoint."""
    # Free text to extract entities from.
    text: str
    # Name of the LLM model to run the extraction with.
    model_name: str

View File

@ -58,7 +58,11 @@ class SyncStatus(Enum):
# @singleton
class KnowledgeService:
def __init__(self):
pass
from pilot.graph_engine.graph_engine import RAGGraphEngine
# source = "/Users/chenketing/Desktop/project/llama_index/examples/paul_graham_essay/data/test/test_kg_text.txt"
# pass
"""create knowledge space"""
@ -229,6 +233,10 @@ class KnowledgeService:
pre_separator=sync_request.pre_separator,
text_splitter_impl=text_splitter,
)
from pilot.graph_engine.graph_engine import RAGGraphEngine
# source = "/Users/chenketing/Desktop/project/llama_index/examples/paul_graham_essay/data/test/test_kg_text.txt"
# engine = RAGGraphEngine(knowledge_source=source, model_name="proxyllm", text_splitter=text_splitter)
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
@ -244,6 +252,18 @@ class KnowledgeService:
embedding_factory=embedding_factory,
)
chunk_docs = client.read()
from pilot.graph_engine.graph_factory import RAGGraphFactory
rag_engine = CFG.SYSTEM_APP.get_component(
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
).create()
rag_engine.knowledge_graph(docs=chunk_docs)
# docs = engine.search(
# "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
# )
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
# update document status
doc.status = SyncStatus.RUNNING.name
doc.chunk_size = len(chunk_docs)