import os

import pytest

from dbgpt.configs.model_config import ROOT_PATH
from dbgpt.core import Chunk, HumanPromptTemplate, ModelMessage, ModelRequest
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.retriever import RetrieverStrategy
from dbgpt_ext.rag import ChunkParameters
from dbgpt_ext.rag.assembler import EmbeddingAssembler
from dbgpt_ext.rag.knowledge import KnowledgeFactory
from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig
from dbgpt_ext.storage.knowledge_graph.community_summary import (
CommunitySummaryKnowledgeGraph,
)
from dbgpt_ext.storage.knowledge_graph.knowledge_graph import (
BuiltinKnowledgeGraph,
)
"""GraphRAG example.
```
# Set LLM config (url/sk) in `.env`.
# Install pytest utils: `pip install pytest pytest-asyncio`
GRAPH_STORE_TYPE=TuGraph
TUGRAPH_HOST=127.0.0.1
TUGRAPH_PORT=7687
TUGRAPH_USERNAME=admin
TUGRAPH_PASSWORD=73@TuGraph
```
Examples:
..code-block:: shell
pytest -s examples/rag/graph_rag_example.py
"""
llm_client = OpenAILLMClient()
model_name = "gpt-4o-mini"


@pytest.mark.asyncio
async def test_naive_graph_rag():
    """Naive GraphRAG: size-based chunks over a triplet knowledge graph."""
    await __run_graph_rag(
        knowledge_file="examples/test_files/graphrag-mini.md",
        chunk_strategy="CHUNK_BY_SIZE",
        knowledge_graph=__create_naive_kg_connector(),
        question="What's the relationship between TuGraph and DB-GPT?",
    )


@pytest.mark.asyncio
async def test_community_graph_rag():
    """Community GraphRAG: markdown-header chunks over a community-summary graph."""
    await __run_graph_rag(
        knowledge_file="examples/test_files/graphrag-mini.md",
        chunk_strategy="CHUNK_BY_MARKDOWN_HEADER",
        knowledge_graph=__create_community_kg_connector(),
        question="What's the relationship between TuGraph and DB-GPT?",
    )


def __create_naive_kg_connector():
    """Create a builtin (triplet-based) knowledge graph connector."""
    return BuiltinKnowledgeGraph(
        config=TuGraphStoreConfig(),
        name="naive_graph_rag_test",
        embedding_fn=None,
        llm_client=llm_client,
        llm_model=model_name,
    )


def __create_community_kg_connector():
    """Create a community-summary knowledge graph connector."""
    return CommunitySummaryKnowledgeGraph(
        config=TuGraphStoreConfig(),
        name="community_graph_rag_test",
        embedding_fn=DefaultEmbeddingFactory.openai(),
        llm_client=llm_client,
        llm_model=model_name,
    )
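
# Why the two connectors differ (a rough summary, not an exhaustive one):
# BuiltinKnowledgeGraph extracts triplets with the LLM and retrieves them via
# keyword matching, so it needs no embedding function, while
# CommunitySummaryKnowledgeGraph additionally clusters the graph into
# communities and summarizes them, and those summaries are matched by vector
# similarity -- which is why it takes an embedding function above.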


async def ask_chunk(chunk: Chunk, question: str) -> str:
    """Ask the LLM to answer `question`, using the chunk content as context."""
    rag_template = (
        "Based on the following [Context] {context}, answer [Question] {question}."
    )
    template = HumanPromptTemplate.from_template(rag_template)
    messages = template.format_messages(context=chunk.content, question=question)
    model_messages = ModelMessage.from_base_messages(messages)
    request = ModelRequest(model=model_name, messages=model_messages)
    response = await llm_client.generate(request=request)
    if not response.success:
        code = str(response.error_code)
        reason = response.text
        raise RuntimeError(f"LLM request failed ({code}): {reason}")
    return response.text


async def __run_graph_rag(knowledge_file, chunk_strategy, knowledge_graph, question):
    file_path = os.path.join(ROOT_PATH, knowledge_file)
    knowledge = KnowledgeFactory.from_file_path(file_path)
    try:
        chunk_parameters = ChunkParameters(chunk_strategy=chunk_strategy)

        # Chunk the knowledge and persist it into the graph store.
        assembler = await EmbeddingAssembler.aload_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            index_store=knowledge_graph,
            retrieve_strategy=RetrieverStrategy.GRAPH,
        )
        await assembler.apersist()

        # Retrieve the best-scoring chunk for the question (top_k=1).
        retriever = assembler.as_retriever(1)
        chunks = await retriever.aretrieve_with_scores(question, score_threshold=0.3)

        # Answer the question with the retrieved chunk as context.
        print(await ask_chunk(chunks[0], question))
    finally:
        # Clean up the test graph regardless of success.
        knowledge_graph.delete_vector_name(knowledge_graph.get_config().name)
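

if __name__ == "__main__":
    # A minimal sketch for running one of the examples outside pytest,
    # assuming the same `.env` prerequisites described in the docstring above.
    import asyncio

    asyncio.run(test_community_graph_rag())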