mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-25 03:20:41 +00:00
feat: add GraphRAG framework and integrate TuGraph (#1506)
Co-authored-by: KingSkyLi <15566300566@163.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
0
tests/intetration_tests/graph_store/__init__.py
Normal file
0
tests/intetration_tests/graph_store/__init__.py
Normal file
41
tests/intetration_tests/graph_store/test_memgraph_store.py
Normal file
41
tests/intetration_tests/graph_store/test_memgraph_store.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
|
||||
from dbgpt.storage.graph_store.memgraph_store import (
|
||||
MemoryGraphStore,
|
||||
MemoryGraphStoreConfig,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph_store():
|
||||
yield MemoryGraphStore(MemoryGraphStoreConfig())
|
||||
|
||||
|
||||
def test_graph_store(graph_store):
|
||||
graph_store.insert_triplet("A", "0", "A")
|
||||
graph_store.insert_triplet("A", "1", "A")
|
||||
graph_store.insert_triplet("A", "2", "B")
|
||||
graph_store.insert_triplet("B", "3", "C")
|
||||
graph_store.insert_triplet("B", "4", "D")
|
||||
graph_store.insert_triplet("C", "5", "D")
|
||||
graph_store.insert_triplet("B", "6", "E")
|
||||
graph_store.insert_triplet("F", "7", "E")
|
||||
graph_store.insert_triplet("E", "8", "F")
|
||||
|
||||
subgraph = graph_store.explore(["A"])
|
||||
print(f"\n{subgraph.graphviz()}")
|
||||
assert subgraph.edge_count == 9
|
||||
|
||||
graph_store.delete_triplet("A", "0", "A")
|
||||
graph_store.delete_triplet("B", "4", "D")
|
||||
subgraph = graph_store.explore(["A"])
|
||||
print(f"\n{subgraph.graphviz()}")
|
||||
assert subgraph.edge_count == 7
|
||||
|
||||
triplets = graph_store.get_triplets("B")
|
||||
print(f"\nTriplets of B: {triplets}")
|
||||
assert len(triplets) == 2
|
||||
|
||||
schema = graph_store.get_schema()
|
||||
print(f"\nSchema: {schema}")
|
||||
assert len(schema) == 138
|
67
tests/intetration_tests/graph_store/test_tugraph_store.py
Normal file
67
tests/intetration_tests/graph_store/test_tugraph_store.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# test_tugraph_store.py
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
|
||||
|
||||
|
||||
class TuGraphStoreConfig:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def store():
|
||||
config = TuGraphStoreConfig(name="TestGraph")
|
||||
store = TuGraphStore(config=config)
|
||||
yield store
|
||||
store.conn.close()
|
||||
|
||||
|
||||
def test_insert_and_get_triplets(store):
|
||||
store.insert_triplet("A", "0", "A")
|
||||
store.insert_triplet("A", "1", "A")
|
||||
store.insert_triplet("A", "2", "B")
|
||||
store.insert_triplet("B", "3", "C")
|
||||
store.insert_triplet("B", "4", "D")
|
||||
store.insert_triplet("C", "5", "D")
|
||||
store.insert_triplet("B", "6", "E")
|
||||
store.insert_triplet("F", "7", "E")
|
||||
store.insert_triplet("E", "8", "F")
|
||||
triplets = store.get_triplets("A")
|
||||
assert len(triplets) == 3
|
||||
triplets = store.get_triplets("B")
|
||||
assert len(triplets) == 3
|
||||
triplets = store.get_triplets("C")
|
||||
assert len(triplets) == 1
|
||||
triplets = store.get_triplets("D")
|
||||
assert len(triplets) == 0
|
||||
triplets = store.get_triplets("E")
|
||||
assert len(triplets) == 1
|
||||
triplets = store.get_triplets("F")
|
||||
assert len(triplets) == 1
|
||||
|
||||
|
||||
def test_query(store):
|
||||
query = "MATCH (n)-[r]->(n1) return n,n1,r limit 3"
|
||||
result = store.query(query)
|
||||
v_c = result.vertex_count
|
||||
e_c = result.edge_count
|
||||
assert v_c == 2 and e_c == 3
|
||||
|
||||
|
||||
def test_explore(store):
|
||||
subs = ["A", "B"]
|
||||
result = store.explore(subs, depth=2, fan=None, limit=10)
|
||||
v_c = result.vertex_count
|
||||
e_c = result.edge_count
|
||||
assert v_c == 2 and e_c == 3
|
||||
|
||||
|
||||
# def test_delete_triplet(store):
|
||||
# subj = "A"
|
||||
# rel = "0"
|
||||
# obj = "B"
|
||||
# store.delete_triplet(subj, rel, obj)
|
||||
# triplets = store.get_triplets(subj)
|
||||
# assert len(triplets) == 0
|
42
tests/intetration_tests/transformer/test_extactor.py
Normal file
42
tests/intetration_tests/transformer/test_extactor.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
|
||||
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
|
||||
from dbgpt.rag.transformer.triplet_extractor import TripletExtractor
|
||||
|
||||
model_name = "gpt-3.5-turbo"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm():
|
||||
yield OpenAILLMClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def triplet_extractor(llm):
|
||||
yield TripletExtractor(llm, model_name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def keyword_extractor(llm):
|
||||
yield KeywordExtractor(llm, model_name)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_triplet(triplet_extractor):
|
||||
triplets = await triplet_extractor.extract(
|
||||
"Alice is Bob and Cherry's mother and lives in New York.", 10
|
||||
)
|
||||
print(json.dumps(triplets))
|
||||
assert len(triplets) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_keyword(keyword_extractor):
|
||||
keywords = await keyword_extractor.extract(
|
||||
"Alice is Bob and Cherry's mother and lives in New York.",
|
||||
)
|
||||
print(json.dumps(keywords))
|
||||
assert len(keywords) > 0
|
Reference in New Issue
Block a user