✨ feat(GraphRAG): enhance GraphRAG by graph community summary (#1801)

Co-authored-by: Florian <fanzhidongyzby@163.com>
Co-authored-by: KingSkyLi <15566300566@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: yvonneyx <zhuyuxin0627@gmail.com>
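For a quick local run of the updated example, a sketch (not part of the commit): it assumes DB-GPT and its dependencies are already installed, and that the `.env` LLM settings and TuGraph service are configured as the example's docstring below describes.

```
# Sketch: clone the repository and run the GraphRAG example tests.
git clone https://github.com/csunny/DB-GPT.git
cd DB-GPT
pip install pytest pytest-asyncio
pytest -s examples/rag/graph_rag_example.py
```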
@@ -1,12 +1,19 @@
 import asyncio
 import os
 
+import pytest
+
 from dbgpt.configs.model_config import ROOT_PATH
 from dbgpt.core import Chunk, HumanPromptTemplate, ModelMessage, ModelRequest
 from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
 from dbgpt.rag import ChunkParameters
 from dbgpt.rag.assembler import EmbeddingAssembler
+from dbgpt.rag.embedding import DefaultEmbeddingFactory
 from dbgpt.rag.knowledge import KnowledgeFactory
 from dbgpt.rag.retriever import RetrieverStrategy
+from dbgpt.storage.knowledge_graph.community_summary import (
+    CommunitySummaryKnowledgeGraph,
+    CommunitySummaryKnowledgeGraphConfig,
+)
 from dbgpt.storage.knowledge_graph.knowledge_graph import (
     BuiltinKnowledgeGraph,
     BuiltinKnowledgeGraphConfig,
@@ -15,7 +22,7 @@ from dbgpt.storage.knowledge_graph.knowledge_graph import (
 """GraphRAG example.
     pre-requirements:
     * Set LLM config (url/sk) in `.env`.
     * Setup/startup TuGraph from: https://github.com/TuGraph-family/tugraph-db
+    * Install pytest utils: `pip install pytest pytest-asyncio`
     * Config TuGraph following the format below.
     ```
     GRAPH_STORE_TYPE=TuGraph
@@ -24,46 +31,100 @@ from dbgpt.storage.knowledge_graph.knowledge_graph import (
     TUGRAPH_USERNAME=admin
     TUGRAPH_PASSWORD=73@TuGraph
     ```
 
     Examples:
         ..code-block:: shell
-            python examples/rag/graph_rag_example.py
+            pytest -s examples/rag/graph_rag_example.py
 """
 
+llm_client = OpenAILLMClient()
+model_name = "gpt-4o-mini"
 
-def _create_kg_connector():
+
+@pytest.mark.asyncio
+async def test_naive_graph_rag():
+    await __run_graph_rag(
+        knowledge_file="examples/test_files/graphrag-mini.md",
+        chunk_strategy="CHUNK_BY_SIZE",
+        knowledge_graph=__create_naive_kg_connector(),
+        question="What's the relationship between TuGraph and DB-GPT ?",
+    )
+
+
+@pytest.mark.asyncio
+async def test_community_graph_rag():
+    await __run_graph_rag(
+        knowledge_file="examples/test_files/graphrag-mini.md",
+        chunk_strategy="CHUNK_BY_MARKDOWN_HEADER",
+        knowledge_graph=__create_community_kg_connector(),
+        question="What's the relationship between TuGraph and DB-GPT ?",
+    )
+
+
+def __create_naive_kg_connector():
     """Create knowledge graph connector."""
     return BuiltinKnowledgeGraph(
         config=BuiltinKnowledgeGraphConfig(
-            name="graph_rag_test",
+            name="naive_graph_rag_test",
             embedding_fn=None,
-            llm_client=OpenAILLMClient(),
-            model_name="gpt-3.5-turbo",
+            llm_client=llm_client,
+            model_name=model_name,
+            graph_store_type="MemoryGraph",
         ),
     )
 
 
-async def main():
-    file_path = os.path.join(ROOT_PATH, "examples/test_files/tranformers_story.md")
+def __create_community_kg_connector():
+    """Create community knowledge graph connector."""
+    return CommunitySummaryKnowledgeGraph(
+        config=CommunitySummaryKnowledgeGraphConfig(
+            name="community_graph_rag_test",
+            embedding_fn=DefaultEmbeddingFactory.openai(),
+            llm_client=llm_client,
+            model_name=model_name,
+            graph_store_type="TuGraphGraph",
+        ),
+    )
+
+
+async def ask_chunk(chunk: Chunk, question) -> str:
+    rag_template = (
+        "Based on the following [Context] {context}, " "answer [Question] {question}."
+    )
+    template = HumanPromptTemplate.from_template(rag_template)
+    messages = template.format_messages(context=chunk.content, question=question)
+    model_messages = ModelMessage.from_base_messages(messages)
+    request = ModelRequest(model=model_name, messages=model_messages)
+    response = await llm_client.generate(request=request)
+
+    if not response.success:
+        code = str(response.error_code)
+        reason = response.text
+        raise Exception(f"request llm failed ({code}) {reason}")
+
+    return response.text
+
+
+async def __run_graph_rag(knowledge_file, chunk_strategy, knowledge_graph, question):
+    file_path = os.path.join(ROOT_PATH, knowledge_file).format()
     knowledge = KnowledgeFactory.from_file_path(file_path)
-    graph_store = _create_kg_connector()
-    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
-    # get embedding assembler
-    assembler = await EmbeddingAssembler.aload_from_knowledge(
-        knowledge=knowledge,
-        chunk_parameters=chunk_parameters,
-        index_store=graph_store,
-        retrieve_strategy=RetrieverStrategy.GRAPH,
-    )
-    await assembler.apersist()
-    # get embeddings retriever
-    retriever = assembler.as_retriever(3)
-    chunks = await retriever.aretrieve_with_scores(
-        "What actions has Megatron taken ?", score_threshold=0.3
-    )
-    print(f"embedding rag example results:{chunks}")
-    graph_store.delete_vector_name("graph_rag_test")
+    try:
+        chunk_parameters = ChunkParameters(chunk_strategy=chunk_strategy)
+
+        # get embedding assembler
+        assembler = await EmbeddingAssembler.aload_from_knowledge(
+            knowledge=knowledge,
+            chunk_parameters=chunk_parameters,
+            index_store=knowledge_graph,
+            retrieve_strategy=RetrieverStrategy.GRAPH,
+        )
+        await assembler.apersist()
 
-if __name__ == "__main__":
-    asyncio.run(main())
+        # get embeddings retriever
+        retriever = assembler.as_retriever(1)
+        chunks = await retriever.aretrieve_with_scores(question, score_threshold=0.3)
+
+        # chat
+        print(f"{await ask_chunk(chunks[0], question)}")
+
+    finally:
+        knowledge_graph.delete_vector_name(knowledge_graph.get_config().name)
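For illustration only, a minimal standalone driver for the new community-summary path, sketched strictly from the APIs this commit introduces (CommunitySummaryKnowledgeGraph, EmbeddingAssembler.aload_from_knowledge, RetrieverStrategy.GRAPH). The index name "community_graph_rag_demo" is hypothetical; everything else mirrors the example above and assumes the same `.env` and TuGraph setup.

```
import asyncio
import os

from dbgpt.configs.model_config import ROOT_PATH
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.rag.retriever import RetrieverStrategy
from dbgpt.storage.knowledge_graph.community_summary import (
    CommunitySummaryKnowledgeGraph,
    CommunitySummaryKnowledgeGraphConfig,
)


async def main():
    # Community-summary graph store, mirroring __create_community_kg_connector()
    # in the diff; "community_graph_rag_demo" is a hypothetical index name.
    knowledge_graph = CommunitySummaryKnowledgeGraph(
        config=CommunitySummaryKnowledgeGraphConfig(
            name="community_graph_rag_demo",
            embedding_fn=DefaultEmbeddingFactory.openai(),
            llm_client=OpenAILLMClient(),
            model_name="gpt-4o-mini",
            graph_store_type="TuGraphGraph",
        ),
    )
    knowledge = KnowledgeFactory.from_file_path(
        os.path.join(ROOT_PATH, "examples/test_files/graphrag-mini.md")
    )
    try:
        # Ingest: chunk by markdown header and persist into the graph store,
        # as test_community_graph_rag() does above.
        assembler = await EmbeddingAssembler.aload_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_MARKDOWN_HEADER"),
            index_store=knowledge_graph,
            retrieve_strategy=RetrieverStrategy.GRAPH,
        )
        await assembler.apersist()

        # Retrieve the top chunk for the same sample question.
        retriever = assembler.as_retriever(1)
        chunks = await retriever.aretrieve_with_scores(
            "What's the relationship between TuGraph and DB-GPT ?",
            score_threshold=0.3,
        )
        print(chunks)
    finally:
        # Clean up the index, like the example's finally block.
        knowledge_graph.delete_vector_name(knowledge_graph.get_config().name)


if __name__ == "__main__":
    asyncio.run(main())
```

The try/finally mirrors the example's own cleanup, so a failed run does not leave a stale index behind.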