import os

import pytest

from dbgpt.configs.model_config import ROOT_PATH
from dbgpt.core import Chunk, HumanPromptTemplate, ModelMessage, ModelRequest
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.retriever import RetrieverStrategy
from dbgpt_ext.rag import ChunkParameters
from dbgpt_ext.rag.assembler import EmbeddingAssembler
from dbgpt_ext.rag.knowledge import KnowledgeFactory
from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig
from dbgpt_ext.storage.knowledge_graph.community_summary import (
    CommunitySummaryKnowledgeGraph,
)
from dbgpt_ext.storage.knowledge_graph.knowledge_graph import (
    BuiltinKnowledgeGraph,
)

"""GraphRAG example.
|
|
```
|
|
# Set LLM config (url/sk) in `.env`.
|
|
# Install pytest utils: `pip install pytest pytest-asyncio`
|
|
GRAPH_STORE_TYPE=TuGraph
|
|
TUGRAPH_HOST=127.0.0.1
|
|
TUGRAPH_PORT=7687
|
|
TUGRAPH_USERNAME=admin
|
|
TUGRAPH_PASSWORD=73@TuGraph
|
|
```
|
|
Examples:
|
|
..code-block:: shell
|
|
pytest -s examples/rag/graph_rag_example.py
|
|
"""
|
|
|
|
llm_client = OpenAILLMClient()
model_name = "gpt-4o-mini"
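# NOTE: OpenAILLMClient() and DefaultEmbeddingFactory.openai() are expected to
# read the OpenAI API key/base URL from the environment (the `.env` settings
# mentioned above); if your deployment differs, pass api_key/api_base explicitly.

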
@pytest.mark.asyncio
async def test_naive_graph_rag():
    """Build a triplet graph from the sample file, then retrieve and ask."""
    await __run_graph_rag(
        knowledge_file="examples/test_files/graphrag-mini.md",
        chunk_strategy="CHUNK_BY_SIZE",
        knowledge_graph=__create_naive_kg_connector(),
        question="What's the relationship between TuGraph and DB-GPT?",
    )


@pytest.mark.asyncio
async def test_community_graph_rag():
    """Build a community-summary graph from the sample file, then retrieve and ask."""
    await __run_graph_rag(
        knowledge_file="examples/test_files/graphrag-mini.md",
        chunk_strategy="CHUNK_BY_MARKDOWN_HEADER",
        knowledge_graph=__create_community_kg_connector(),
        question="What's the relationship between TuGraph and DB-GPT?",
    )


def __create_naive_kg_connector():
    """Create a builtin (triplet-only) knowledge graph connector."""
    return BuiltinKnowledgeGraph(
        config=TuGraphStoreConfig(),
        name="naive_graph_rag_test",
        embedding_fn=None,
        llm_client=llm_client,
        llm_model=model_name,
    )


def __create_community_kg_connector():
    """Create a community-summary knowledge graph connector."""
    return CommunitySummaryKnowledgeGraph(
        config=TuGraphStoreConfig(),
        name="community_graph_rag_test",
        embedding_fn=DefaultEmbeddingFactory.openai(),
        llm_client=llm_client,
        llm_model=model_name,
    )


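# The two connectors differ in how they retrieve: BuiltinKnowledgeGraph stores
# LLM-extracted triplets and answers by graph lookup alone (hence no embedding
# function above), while CommunitySummaryKnowledgeGraph also summarizes graph
# communities and needs embeddings for similarity search over those summaries.

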
async def ask_chunk(chunk: Chunk, question: str) -> str:
    """Ask the LLM to answer `question` using the chunk content as context."""
    rag_template = (
        "Based on the following [Context] {context}, answer [Question] {question}."
    )
    template = HumanPromptTemplate.from_template(rag_template)
    messages = template.format_messages(context=chunk.content, question=question)
    model_messages = ModelMessage.from_base_messages(messages)
    request = ModelRequest(model=model_name, messages=model_messages)
    response = await llm_client.generate(request=request)

    if not response.success:
        code = str(response.error_code)
        reason = response.text
        raise Exception(f"request llm failed ({code}) {reason}")

    return response.text


async def __run_graph_rag(knowledge_file, chunk_strategy, knowledge_graph, question):
    """Load the knowledge file into the graph store, retrieve, and ask the LLM."""
    file_path = os.path.join(ROOT_PATH, knowledge_file)
    knowledge = KnowledgeFactory.from_file_path(file_path)
    try:
        chunk_parameters = ChunkParameters(chunk_strategy=chunk_strategy)

        # chunk the knowledge and persist it into the graph store
        assembler = await EmbeddingAssembler.aload_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            index_store=knowledge_graph,
            retrieve_strategy=RetrieverStrategy.GRAPH,
        )
        await assembler.apersist()

        # retrieve the top-1 chunk for the question
        retriever = assembler.as_retriever(1)
        chunks = await retriever.aretrieve_with_scores(question, score_threshold=0.3)

        # chat with the retrieved context
        print(await ask_chunk(chunks[0], question))

    finally:
        # always drop the test graph, even when the run fails
        knowledge_graph.delete_vector_name(knowledge_graph.get_config().name)
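

# Minimal manual-run sketch (assumes the same `.env` and TuGraph settings as the
# pytest flow above); lets the example run as a plain script.
if __name__ == "__main__":
    import asyncio

    asyncio.run(test_community_graph_rag())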