mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-28 04:44:14 +00:00
feat:knowledge rag graph
This commit is contained in:
parent
8a5e35c5f2
commit
fa6a9040d5
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
from pilot.component import ComponentType
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
@ -47,6 +48,11 @@ class ChatKnowledge(BaseChat):
|
|||||||
"vector_store_name": self.knowledge_space,
|
"vector_store_name": self.knowledge_space,
|
||||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||||
}
|
}
|
||||||
|
from pilot.graph_engine.graph_factory import RAGGraphFactory
|
||||||
|
|
||||||
|
self.rag_engine = CFG.SYSTEM_APP.get_component(
|
||||||
|
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
|
||||||
|
).create()
|
||||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||||
"embedding_factory", EmbeddingFactory
|
"embedding_factory", EmbeddingFactory
|
||||||
)
|
)
|
||||||
@ -82,6 +88,25 @@ class ChatKnowledge(BaseChat):
|
|||||||
if self.space_context:
|
if self.space_context:
|
||||||
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
|
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
|
||||||
self.prompt_template.template = self.space_context["prompt"]["template"]
|
self.prompt_template.template = self.space_context["prompt"]["template"]
|
||||||
|
# docs = self.rag_engine.search(query=self.current_user_input)
|
||||||
|
# import httpx
|
||||||
|
# with httpx.Client() as client:
|
||||||
|
# request = client.build_request(
|
||||||
|
# "post",
|
||||||
|
# "http://127.0.0.1/api/knowledge/entities/extract",
|
||||||
|
# json="application/json", # using json for data to ensure it sends as application/json
|
||||||
|
# params={"text": self.current_user_input},
|
||||||
|
# headers={},
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# response = client.send(request)
|
||||||
|
# if response.status_code != 200:
|
||||||
|
# error_msg = f"request /api/knowledge/entities/extract failed, error: {response.text}"
|
||||||
|
# raise Exception(error_msg)
|
||||||
|
# docs = response.json()
|
||||||
|
# import requests
|
||||||
|
# docs = requests.post("http://127.0.0.1:5000/api/knowledge/entities/extract", headers={}, json={"text": self.current_user_input})
|
||||||
|
|
||||||
docs = self.knowledge_embedding_client.similar_search(
|
docs = self.knowledge_embedding_client.similar_search(
|
||||||
self.current_user_input, self.top_k
|
self.current_user_input, self.top_k
|
||||||
)
|
)
|
||||||
|
@ -31,6 +31,7 @@ def initialize_components(
|
|||||||
|
|
||||||
# Register global default RAGGraphFactory
|
# Register global default RAGGraphFactory
|
||||||
from pilot.graph_engine.graph_factory import DefaultRAGGraphFactory
|
from pilot.graph_engine.graph_factory import DefaultRAGGraphFactory
|
||||||
|
|
||||||
system_app.register(DefaultRAGGraphFactory)
|
system_app.register(DefaultRAGGraphFactory)
|
||||||
|
|
||||||
_initialize_embedding_model(
|
_initialize_embedding_model(
|
||||||
|
@ -24,6 +24,7 @@ from pilot.server.knowledge.request.request import (
|
|||||||
ChunkQueryRequest,
|
ChunkQueryRequest,
|
||||||
DocumentQueryRequest,
|
DocumentQueryRequest,
|
||||||
SpaceArgumentRequest,
|
SpaceArgumentRequest,
|
||||||
|
EntityExtractRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
|
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
|
||||||
@ -198,3 +199,37 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
|
|||||||
for d in docs
|
for d in docs
|
||||||
]
|
]
|
||||||
return {"response": res}
|
return {"response": res}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/knowledge/entity/extract")
async def entity_extract(request: EntityExtractRequest):
    """Run the ExtractEntity chat scene over the text supplied in the request.

    Args:
        request: request body; ``request.text`` is passed to the scene as
            ``current_user_input`` and ``request.model_name`` as ``model_name``.

    Returns:
        ``Result.succ`` wrapping whatever ``llm_chat_response_nostream``
        returns, or ``Result.faild`` with code ``E000X`` if anything raises.
    """
    logger.info(f"Received params: {request}")
    try:
        # Kept as function-scope imports, as in the original handler.
        from pilot.scene.base import ChatScene
        from pilot.common.chat_util import llm_chat_response_nostream
        import uuid

        chat_param = {
            # A fresh id per call: each extraction is its own chat session.
            "chat_session_id": uuid.uuid1(),
            "current_user_input": request.text,
            "select_param": "entity",
            "model_name": request.model_name,
        }
        res = await llm_chat_response_nostream(
            ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
        )
        return Result.succ(res)
    except Exception as e:
        # The caller only receives the stringified error, so record the full
        # traceback here before converting the failure into a Result.
        logger.exception("entity extract error")
        return Result.faild(code="E000X", msg=f"entity extract error {e}")
|
||||||
|
@ -104,3 +104,10 @@ class SpaceArgumentRequest(BaseModel):
|
|||||||
"""argument: argument"""
|
"""argument: argument"""
|
||||||
|
|
||||||
argument: str
|
argument: str
|
||||||
|
|
||||||
|
|
||||||
|
class EntityExtractRequest(BaseModel):
    """Request body for entity extraction: the input text and the model to use."""

    # text: the text to extract entities from
    text: str
    # model_name: name of the model that should handle the request
    model_name: str
|
||||||
|
@ -58,7 +58,11 @@ class SyncStatus(Enum):
|
|||||||
# @singleton
|
# @singleton
|
||||||
class KnowledgeService:
|
class KnowledgeService:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
from pilot.graph_engine.graph_engine import RAGGraphEngine
|
||||||
|
|
||||||
|
# source = "/Users/chenketing/Desktop/project/llama_index/examples/paul_graham_essay/data/test/test_kg_text.txt"
|
||||||
|
|
||||||
|
# pass
|
||||||
|
|
||||||
"""create knowledge space"""
|
"""create knowledge space"""
|
||||||
|
|
||||||
@ -229,6 +233,10 @@ class KnowledgeService:
|
|||||||
pre_separator=sync_request.pre_separator,
|
pre_separator=sync_request.pre_separator,
|
||||||
text_splitter_impl=text_splitter,
|
text_splitter_impl=text_splitter,
|
||||||
)
|
)
|
||||||
|
from pilot.graph_engine.graph_engine import RAGGraphEngine
|
||||||
|
|
||||||
|
# source = "/Users/chenketing/Desktop/project/llama_index/examples/paul_graham_essay/data/test/test_kg_text.txt"
|
||||||
|
# engine = RAGGraphEngine(knowledge_source=source, model_name="proxyllm", text_splitter=text_splitter)
|
||||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||||
"embedding_factory", EmbeddingFactory
|
"embedding_factory", EmbeddingFactory
|
||||||
)
|
)
|
||||||
@ -244,6 +252,18 @@ class KnowledgeService:
|
|||||||
embedding_factory=embedding_factory,
|
embedding_factory=embedding_factory,
|
||||||
)
|
)
|
||||||
chunk_docs = client.read()
|
chunk_docs = client.read()
|
||||||
|
from pilot.graph_engine.graph_factory import RAGGraphFactory
|
||||||
|
|
||||||
|
rag_engine = CFG.SYSTEM_APP.get_component(
|
||||||
|
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
|
||||||
|
).create()
|
||||||
|
rag_engine.knowledge_graph(docs=chunk_docs)
|
||||||
|
# docs = engine.search(
|
||||||
|
# "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
|
||||||
|
# )
|
||||||
|
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||||
|
"embedding_factory", EmbeddingFactory
|
||||||
|
)
|
||||||
# update document status
|
# update document status
|
||||||
doc.status = SyncStatus.RUNNING.name
|
doc.status = SyncStatus.RUNNING.name
|
||||||
doc.chunk_size = len(chunk_docs)
|
doc.chunk_size = len(chunk_docs)
|
||||||
|
Loading…
Reference in New Issue
Block a user