feat:knowledge rag graph
parent 8a5e35c5f2
commit fa6a9040d5
@@ -1,6 +1,7 @@
import os
from typing import Dict

from pilot.component import ComponentType
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config

@@ -47,6 +48,11 @@ class ChatKnowledge(BaseChat):
            "vector_store_name": self.knowledge_space,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
        }
        from pilot.graph_engine.graph_factory import RAGGraphFactory

        self.rag_engine = CFG.SYSTEM_APP.get_component(
            ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
        ).create()
        embedding_factory = CFG.SYSTEM_APP.get_component(
            "embedding_factory", EmbeddingFactory
        )
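
For context, CFG.SYSTEM_APP.get_component(ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory).create() resolves a previously registered factory and asks it to build the engine. The sketch below is a minimal, self-contained illustration of that register-then-resolve pattern; the registry and factory classes here are simplified stand-ins, not the actual pilot.component implementation.

# Minimal sketch of the register-then-resolve component pattern (illustrative
# stand-ins only; the real pilot.component SystemApp and RAGGraphFactory
# differ in detail).
from typing import Dict, Type


class BaseComponent:
    name: str = "base"


class SystemAppSketch:
    """Tiny stand-in for an application-level component registry."""

    def __init__(self) -> None:
        self._components: Dict[str, BaseComponent] = {}

    def register(self, component_cls: Type[BaseComponent]) -> None:
        instance = component_cls()
        self._components[instance.name] = instance

    def get_component(self, name: str, expected_type: Type[BaseComponent]) -> BaseComponent:
        component = self._components[name]
        assert isinstance(component, expected_type)
        return component


class RAGGraphFactorySketch(BaseComponent):
    """Hypothetical factory; create() stands in for building a RAGGraphEngine."""

    name = "rag_graph_default"

    def create(self) -> object:
        return object()  # placeholder for the real graph engine


system_app = SystemAppSketch()
system_app.register(RAGGraphFactorySketch)
rag_engine = system_app.get_component("rag_graph_default", RAGGraphFactorySketch).create()
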
@@ -82,6 +88,25 @@ class ChatKnowledge(BaseChat):
        if self.space_context:
            self.prompt_template.template_define = self.space_context["prompt"]["scene"]
            self.prompt_template.template = self.space_context["prompt"]["template"]
        # docs = self.rag_engine.search(query=self.current_user_input)
        # import httpx
        # with httpx.Client() as client:
        #     request = client.build_request(
        #         "post",
        #         "http://127.0.0.1/api/knowledge/entities/extract",
        #         json="application/json",  # using json for data to ensure it sends as application/json
        #         params={"text": self.current_user_input},
        #         headers={},
        #     )
        #
        #     response = client.send(request)
        #     if response.status_code != 200:
        #         error_msg = f"request /api/knowledge/entities/extract failed, error: {response.text}"
        #         raise Exception(error_msg)
        #     docs = response.json()
        # import requests
        # docs = requests.post("http://127.0.0.1:5000/api/knowledge/entities/extract", headers={}, json={"text": self.current_user_input})

        docs = self.knowledge_embedding_client.similar_search(
            self.current_user_input, self.top_k
        )
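
A side note on the commented-out httpx fallback above: it passes json="application/json" (which would send that string as the request body) while the actual text rides in the query params. If that HTTP path were ever revived, the payload would more naturally travel as a JSON body, roughly as sketched below; the path and port come from the commented-out snippets and are not guaranteed to match a deployed server. The committed code path sidesteps this entirely and calls similar_search in-process.

# Hedged sketch of the HTTP fallback with the text sent as a JSON body.
# The path and port are taken from the commented-out snippets above and may
# not match an actual deployment.
import httpx


def extract_entities_via_http(text: str, base_url: str = "http://127.0.0.1:5000") -> dict:
    with httpx.Client() as client:
        response = client.post(
            f"{base_url}/api/knowledge/entities/extract",
            json={"text": text},
        )
        if response.status_code != 200:
            raise Exception(
                f"request /api/knowledge/entities/extract failed, error: {response.text}"
            )
        return response.json()
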
@@ -31,6 +31,7 @@ def initialize_components(

    # Register global default RAGGraphFactory
    from pilot.graph_engine.graph_factory import DefaultRAGGraphFactory

    system_app.register(DefaultRAGGraphFactory)

    _initialize_embedding_model(

@@ -24,6 +24,7 @@ from pilot.server.knowledge.request.request import (
    ChunkQueryRequest,
    DocumentQueryRequest,
    SpaceArgumentRequest,
    EntityExtractRequest,
)

from pilot.server.knowledge.request.request import KnowledgeSpaceRequest

@@ -198,3 +199,37 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
        for d in docs
    ]
    return {"response": res}


@router.post("/knowledge/entity/extract")
async def entity_extract(request: EntityExtractRequest):
    logger.info(f"Received params: {request}")
    try:
        # from pilot.graph_engine.graph_factory import RAGGraphFactory
        # from pilot.component import ComponentType
        # rag_engine = CFG.SYSTEM_APP.get_component(
        #     ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
        # ).create()
        # return Result.succ(await rag_engine.search(request.text))
        from pilot.scene.base import ChatScene
        from pilot.common.chat_util import llm_chat_response_nostream
        import uuid

        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": request.text,
            "select_param": "entity",
            "model_name": request.model_name,
        }

        # import nest_asyncio
        # nest_asyncio.apply()
        # loop = asyncio.get_event_loop()
        # loop.stop()
        # loop = utils.get_or_create_event_loop()
        res = await llm_chat_response_nostream(
            ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
        )
        return Result.succ(res)
    except Exception as e:
        return Result.faild(code="E000X", msg=f"entity extract error {e}")
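
The new route accepts an EntityExtractRequest as a JSON body and runs the extraction through the ExtractEntity chat scene. A client-side usage sketch follows; the base URL, port, path prefix, and model name are assumptions (the commented-out snippets elsewhere in this commit suggest an /api prefix and port 5000) and should be adjusted to the actual deployment.

# Hedged client example for the new endpoint; URL, port, and model name are
# assumptions, not values confirmed by this commit.
import httpx

payload = {
    "text": "Stephen Curry plays for the Golden State Warriors.",
    "model_name": "proxyllm",  # assumed; any model served by the deployment
}
response = httpx.post(
    "http://127.0.0.1:5000/api/knowledge/entity/extract",
    json=payload,
    timeout=60.0,
)
response.raise_for_status()
print(response.json())  # Result envelope containing the extracted entities
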
@@ -104,3 +104,10 @@ class SpaceArgumentRequest(BaseModel):
    """argument: argument"""

    argument: str


class EntityExtractRequest(BaseModel):
    """argument: argument"""

    text: str
    model_name: str
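
For reference, the request body the new route expects mirrors this Pydantic model. The class below is redefined locally purely for illustration (the real one lives in pilot.server.knowledge.request.request), and the example values are made up.

# Standalone illustration of the request schema; field names come from the
# diff, the example values are hypothetical.
from pydantic import BaseModel, ValidationError


class EntityExtractRequest(BaseModel):
    text: str
    model_name: str


req = EntityExtractRequest(text="LeBron James joined the Lakers in 2018.", model_name="proxyllm")
print(req.dict())  # {'text': ..., 'model_name': 'proxyllm'}

try:
    EntityExtractRequest(text="missing model_name")
except ValidationError as err:
    print(err)  # model_name: field required
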
@@ -58,7 +58,11 @@ class SyncStatus(Enum):
# @singleton
class KnowledgeService:
    def __init__(self):
        pass
        from pilot.graph_engine.graph_engine import RAGGraphEngine

        # source = "/Users/chenketing/Desktop/project/llama_index/examples/paul_graham_essay/data/test/test_kg_text.txt"

        # pass

    """create knowledge space"""

@@ -229,6 +233,10 @@ class KnowledgeService:
            pre_separator=sync_request.pre_separator,
            text_splitter_impl=text_splitter,
        )
        from pilot.graph_engine.graph_engine import RAGGraphEngine

        # source = "/Users/chenketing/Desktop/project/llama_index/examples/paul_graham_essay/data/test/test_kg_text.txt"
        # engine = RAGGraphEngine(knowledge_source=source, model_name="proxyllm", text_splitter=text_splitter)
        embedding_factory = CFG.SYSTEM_APP.get_component(
            "embedding_factory", EmbeddingFactory
        )

@@ -244,6 +252,18 @@ class KnowledgeService:
            embedding_factory=embedding_factory,
        )
        chunk_docs = client.read()
        from pilot.graph_engine.graph_factory import RAGGraphFactory

        rag_engine = CFG.SYSTEM_APP.get_component(
            ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
        ).create()
        rag_engine.knowledge_graph(docs=chunk_docs)
        # docs = engine.search(
        #     "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
        # )
        embedding_factory = CFG.SYSTEM_APP.get_component(
            "embedding_factory", EmbeddingFactory
        )
        # update document status
        doc.status = SyncStatus.RUNNING.name
        doc.chunk_size = len(chunk_docs)
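
Taken together, the sync path now reads the source into chunk documents, hands those chunks to the RAG graph engine to build a knowledge graph, and then continues with the existing embedding-based indexing and status bookkeeping. A condensed sketch of that flow follows; the read, knowledge_graph, and create names come from the diff, while the Protocols and the orchestration function are illustrative stand-ins rather than DB-GPT code.

# Condensed, illustrative sketch of the new sync path; not DB-GPT code.
from typing import List, Protocol


class KnowledgeClientLike(Protocol):
    def read(self) -> List[object]: ...


class RAGGraphEngineLike(Protocol):
    def knowledge_graph(self, docs: List[object]) -> None: ...


def sync_with_knowledge_graph(client: KnowledgeClientLike, engine: RAGGraphEngineLike) -> int:
    chunk_docs = client.read()               # split the source into chunk documents
    engine.knowledge_graph(docs=chunk_docs)  # build the graph from the chunks
    return len(chunk_docs)                   # later stored as doc.chunk_size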