Filtering graph schema for Cypher generation (#10577)

Sometimes you don't want the LLM to be aware of the whole graph schema,
and want it to ignore parts of the graph when it is constructing Cypher
statements.
This commit is contained in:
Tomaz Bratanic 2023-09-25 23:14:15 +02:00 committed by GitHub
parent 89ef440c14
commit 0625ab7a9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 368 additions and 27 deletions

View File

@ -135,7 +135,7 @@
}
],
"source": [
"print(graph.get_schema)"
"print(graph.schema)"
]
},
{
@ -510,13 +510,54 @@
"chain.run(\"Who played in Top Gun?\")"
]
},
{
"cell_type": "markdown",
"id": "eefea16b-508f-4552-8942-9d5063ed7d37",
"metadata": {},
"source": [
"# Ignore specified node and relationship types\n",
"You can use `include_types` or `exclude_types` to ignore parts of the graph schema when generating Cypher statements."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48ff7cf8-18a3-43d7-8cb1-c1b91744608d",
"execution_count": 18,
"id": "a20fa21e-fb85-41c4-aac0-53fb25e34604",
"metadata": {},
"outputs": [],
"source": []
"source": [
"chain = GraphCypherQAChain.from_llm(\n",
" graph=graph,\n",
" cypher_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"),\n",
" qa_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-16k\"),\n",
" verbose=True,\n",
" exclude_types=['Movie']\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "3ad7f6b8-543e-46e4-a3b2-40fa3e66e895",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Node properties are the following: \n",
" {'Actor': [{'property': 'name', 'type': 'STRING'}]}\n",
"Relationships properties are the following: \n",
" {}\n",
"Relationships are: \n",
"[]\n"
]
}
],
"source": [
"# Inspect graph schema\n",
"print(chain.graph_schema)"
]
}
],
"metadata": {

View File

@ -187,7 +187,7 @@
"metadata": {},
"outputs": [],
"source": [
"print(graph.get_schema)"
"print(graph.schema)"
]
},
{
@ -687,7 +687,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.8.8"
}
},
"nbformat": 4,

View File

@ -34,12 +34,54 @@ def extract_cypher(text: str) -> str:
return matches[0] if matches else text
def construct_schema(
    structured_schema: Dict[str, Any],
    include_types: List[str],
    exclude_types: List[str],
) -> str:
    """Render a textual graph schema restricted to the allowed types.

    Args:
        structured_schema: Mapping that may contain ``node_props`` and
            ``rel_props`` (dicts keyed by type name) and ``relationships``
            (a list of dicts with ``start``, ``type`` and ``end`` keys).
            Missing sections are treated as empty.
        include_types: When non-empty, only these names are kept and
            ``exclude_types`` is not consulted.
        exclude_types: Names to drop; used only when ``include_types`` is empty.

    Returns:
        A string listing the kept node properties, the kept relationship
        properties and the kept relationships as ``(:A)-[:REL]->(:B)`` patterns.
    """

    def is_allowed(name: str) -> bool:
        # A non-empty include list takes precedence over the exclude list.
        if include_types:
            return name in include_types
        return name not in exclude_types

    node_props: Dict[str, Any] = {}
    for label, props in structured_schema.get("node_props", {}).items():
        if is_allowed(label):
            node_props[label] = props

    rel_props: Dict[str, Any] = {}
    for rel_type, props in structured_schema.get("rel_props", {}).items():
        if is_allowed(rel_type):
            rel_props[rel_type] = props

    # A relationship survives only if its start label, end label and type
    # are all allowed (checked in that order).
    patterns: List[str] = []
    for rel in structured_schema.get("relationships", []):
        if not is_allowed(rel["start"]):
            continue
        if not is_allowed(rel["end"]):
            continue
        if not is_allowed(rel["type"]):
            continue
        patterns.append(f"(:{rel['start']})-[:{rel['type']}]->(:{rel['end']})")

    header = (
        f"Node properties are the following: \n {node_props}\n"
        f"Relationships properties are the following: \n {rel_props}"
        "\nRelationships are: \n"
    )
    return header + str(patterns)
class GraphCypherQAChain(Chain):
"""Chain for question-answering against a graph by generating Cypher statements."""
graph: Neo4jGraph = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
graph_schema: str
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
@ -79,6 +121,8 @@ class GraphCypherQAChain(Chain):
cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
cypher_llm: Optional[BaseLanguageModel] = None,
qa_llm: Optional[BaseLanguageModel] = None,
exclude_types: List[str] = [],
include_types: List[str] = [],
**kwargs: Any,
) -> GraphCypherQAChain:
"""Initialize from LLM."""
@ -96,7 +140,18 @@ class GraphCypherQAChain(Chain):
qa_chain = LLMChain(llm=qa_llm or llm, prompt=qa_prompt)
cypher_generation_chain = LLMChain(llm=cypher_llm or llm, prompt=cypher_prompt)
if exclude_types and include_types:
raise ValueError(
"Either `exclude_types` or `include_types` "
"can be provided, but not both"
)
graph_schema = construct_schema(
kwargs["graph"].structured_schema, include_types, exclude_types
)
return cls(
graph_schema=graph_schema,
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
**kwargs,
@ -115,7 +170,7 @@ class GraphCypherQAChain(Chain):
intermediate_steps: List = []
generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
{"question": question, "schema": self.graph_schema}, callbacks=callbacks
)
# Extract Cypher code if it is wrapped in backticks

View File

@ -6,6 +6,12 @@ YIELD *
RETURN *
"""
# Cypher query that calls `llm_util.schema` with the "raw" argument and returns
# every yielded field. MemgraphGraph reads the `schema` field of the first
# result row and stores it as `structured_schema`.
RAW_SCHEMA_QUERY = """
CALL llm_util.schema("raw")
YIELD *
RETURN *
"""
class MemgraphGraph(Neo4jGraph):
"""Memgraph wrapper for graph operations."""
@ -24,3 +30,7 @@ class MemgraphGraph(Neo4jGraph):
db_schema = self.query(SCHEMA_QUERY)[0].get("schema")
assert db_schema is not None
self.schema = db_schema
db_structured_schema = self.query(RAW_SCHEMA_QUERY)[0].get("schema")
assert db_structured_schema is not None
self.structured_schema = db_structured_schema

View File

@ -24,7 +24,7 @@ CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE type = "RELATIONSHIP" AND elementType = "node"
UNWIND other AS other_node
RETURN "(:" + label + ")-[:" + property + "]->(:" + toString(other_node) + ")" AS output
RETURN {start: label, type: property, end: toString(other_node)} AS output
"""
@ -45,7 +45,8 @@ class Neo4jGraph:
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
self._database = database
self.schema = ""
self.schema: str = ""
self.structured_schema: Dict[str, Any] = {}
# Verify connection
try:
self._driver.verify_connectivity()
@ -69,11 +70,6 @@ class Neo4jGraph:
"'apoc.meta.data()' is allowed in Neo4j configuration "
)
@property
def get_schema(self) -> str:
"""Returns the schema of the Neo4j database"""
return self.schema
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
"""Query Neo4j database."""
from neo4j.exceptions import CypherSyntaxError
@ -89,17 +85,22 @@ class Neo4jGraph:
"""
Refreshes the Neo4j graph schema information.
"""
node_properties = self.query(node_properties_query)
relationships_properties = self.query(rel_properties_query)
relationships = self.query(rel_query)
node_properties = [el["output"] for el in self.query(node_properties_query)]
rel_properties = [el["output"] for el in self.query(rel_properties_query)]
relationships = [el["output"] for el in self.query(rel_query)]
self.structured_schema = {
"node_props": {el["labels"]: el["properties"] for el in node_properties},
"rel_props": {el["type"]: el["properties"] for el in rel_properties},
"relationships": relationships,
}
self.schema = f"""
Node properties are the following:
{[el['output'] for el in node_properties]}
{node_properties}
Relationship properties are the following:
{[el['output'] for el in relationships_properties]}
{rel_properties}
The relationships are the following:
{[el['output'] for el in relationships]}
{[f"(:{el['start']})-[:{el['type']}]->(:{el['end']})" for el in relationships]}
"""
def add_graph_documents(

View File

@ -211,21 +211,32 @@ def test_cypher_return_correct_schema() -> None:
expected_node_properties = [
{
"output": {
"properties": [{"property": "property_a", "type": "STRING"}],
"labels": "LabelA",
}
}
]
expected_relationships_properties = [
{"type": "REL_TYPE", "properties": [{"property": "rel_prop", "type": "STRING"}]}
{
"output": {
"type": "REL_TYPE",
"properties": [{"property": "rel_prop", "type": "STRING"}],
}
}
]
expected_relationships = [
"(:LabelA)-[:REL_TYPE]->(:LabelB)",
"(:LabelA)-[:REL_TYPE]->(:LabelC)",
{"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"}},
{"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"}},
]
assert node_properties == expected_node_properties
assert relationships_properties == expected_relationships_properties
assert relationships == expected_relationships
# Order is not guaranteed with Neo4j returns
assert (
sorted(relationships, key=lambda x: x["output"]["end"])
== expected_relationships
)
def test_cypher_save_load() -> None:
@ -252,3 +263,122 @@ def test_cypher_save_load() -> None:
qa_loaded = load_chain(FILE_PATH, graph=graph)
assert qa_loaded == chain
def test_exclude_types() -> None:
    """Check that `exclude_types` removes the listed types from the chain schema.

    Connection settings are read from the NEO4J_URL, NEO4J_USERNAME and
    NEO4J_PASSWORD environment variables; the test fails if any is unset.
    """
    neo4j_url = os.environ.get("NEO4J_URL")
    neo4j_user = os.environ.get("NEO4J_USERNAME")
    neo4j_password = os.environ.get("NEO4J_PASSWORD")
    for setting in (neo4j_url, neo4j_user, neo4j_password):
        assert setting is not None

    db = Neo4jGraph(url=neo4j_url, username=neo4j_user, password=neo4j_password)

    # Wipe the database so the schema only reflects the fixture below.
    db.query("MATCH (n) DETACH DELETE n")
    # Fixture: Actor -ACTED_IN-> Movie <-DIRECTED- Person
    db.query(
        "CREATE (a:Actor {name:'Bruce Willis'})"
        "-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
        "<-[:DIRECTED]-(p:Person {name:'John'})"
    )
    # Pick up the freshly created labels and relationship types.
    db.refresh_schema()

    qa_chain = GraphCypherQAChain.from_llm(
        OpenAI(temperature=0),
        graph=db,
        exclude_types=["Person", "DIRECTED"],
    )

    # NOTE(review): the expected text assumes node labels are reported as
    # Movie then Actor - confirm the database returns them in a stable order.
    assert qa_chain.graph_schema == (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
def test_include_types() -> None:
    """Check that `include_types` keeps only the listed types in the chain schema.

    Connection settings are read from the NEO4J_URL, NEO4J_USERNAME and
    NEO4J_PASSWORD environment variables; the test fails if any is unset.
    """
    neo4j_url = os.environ.get("NEO4J_URL")
    neo4j_user = os.environ.get("NEO4J_USERNAME")
    neo4j_password = os.environ.get("NEO4J_PASSWORD")
    for setting in (neo4j_url, neo4j_user, neo4j_password):
        assert setting is not None

    db = Neo4jGraph(url=neo4j_url, username=neo4j_user, password=neo4j_password)

    # Wipe the database so the schema only reflects the fixture below.
    db.query("MATCH (n) DETACH DELETE n")
    # Fixture: Actor -ACTED_IN-> Movie <-DIRECTED- Person
    db.query(
        "CREATE (a:Actor {name:'Bruce Willis'})"
        "-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
        "<-[:DIRECTED]-(p:Person {name:'John'})"
    )
    # Pick up the freshly created labels and relationship types.
    db.refresh_schema()

    qa_chain = GraphCypherQAChain.from_llm(
        OpenAI(temperature=0),
        graph=db,
        include_types=["Movie", "Actor", "ACTED_IN"],
    )

    # NOTE(review): the expected text assumes node labels are reported as
    # Movie then Actor - confirm the database returns them in a stable order.
    assert qa_chain.graph_schema == (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
def test_include_types2() -> None:
    """Check that a relationship is dropped when one endpoint is not included.

    ``ACTED_IN`` is listed in `include_types` but its start label ``Actor`` is
    not, so the expected schema contains no relationships.

    Connection settings are read from the NEO4J_URL, NEO4J_USERNAME and
    NEO4J_PASSWORD environment variables; the test fails if any is unset.
    """
    neo4j_url = os.environ.get("NEO4J_URL")
    neo4j_user = os.environ.get("NEO4J_USERNAME")
    neo4j_password = os.environ.get("NEO4J_PASSWORD")
    for setting in (neo4j_url, neo4j_user, neo4j_password):
        assert setting is not None

    db = Neo4jGraph(url=neo4j_url, username=neo4j_user, password=neo4j_password)

    # Wipe the database so the schema only reflects the fixture below.
    db.query("MATCH (n) DETACH DELETE n")
    # Fixture: Actor -ACTED_IN-> Movie <-DIRECTED- Person
    db.query(
        "CREATE (a:Actor {name:'Bruce Willis'})"
        "-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
        "<-[:DIRECTED]-(p:Person {name:'John'})"
    )
    # Pick up the freshly created labels and relationship types.
    db.refresh_schema()

    qa_chain = GraphCypherQAChain.from_llm(
        OpenAI(temperature=0),
        graph=db,
        include_types=["Movie", "ACTED_IN"],
    )

    assert qa_chain.graph_schema == (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "[]"
    )

View File

@ -1,4 +1,4 @@
from langchain.chains.graph_qa.cypher import extract_cypher
from langchain.chains.graph_qa.cypher import construct_schema, extract_cypher
def test_no_backticks() -> None:
@ -13,3 +13,107 @@ def test_backticks() -> None:
query = "You can use the following query: ```MATCH (n) RETURN n```"
output = extract_cypher(query)
assert output == "MATCH (n) RETURN n"
def test_exclude_types() -> None:
    """Excluding ``Person`` and ``DIRECTED`` leaves only the Actor/Movie part."""
    movie_props = [{"property": "title", "type": "STRING"}]
    name_props = [{"property": "name", "type": "STRING"}]
    schema = {
        "node_props": {
            "Movie": movie_props,
            "Actor": name_props,
            "Person": [dict(p) for p in name_props],
        },
        "rel_props": {},
        "relationships": [
            {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
            {"start": "Person", "end": "Movie", "type": "DIRECTED"},
        ],
    }

    # Empty include list, so only the exclude list drives the filtering.
    rendered = construct_schema(schema, [], ["Person", "DIRECTED"])

    assert rendered == (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
def test_include_types() -> None:
    """Including both labels and the relationship type keeps the full pattern."""
    movie_props = [{"property": "title", "type": "STRING"}]
    name_props = [{"property": "name", "type": "STRING"}]
    schema = {
        "node_props": {
            "Movie": movie_props,
            "Actor": name_props,
            "Person": [dict(p) for p in name_props],
        },
        "rel_props": {},
        "relationships": [
            {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
            {"start": "Person", "end": "Movie", "type": "DIRECTED"},
        ],
    }

    # Empty exclude list, so only the include list drives the filtering.
    rendered = construct_schema(schema, ["Movie", "Actor", "ACTED_IN"], [])

    assert rendered == (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
def test_include_types2() -> None:
    """Including only node labels drops every relationship.

    ``ACTED_IN`` is not in the include list, so the Actor-to-Movie relationship
    is filtered out even though both of its endpoints are kept.
    """
    movie_props = [{"property": "title", "type": "STRING"}]
    name_props = [{"property": "name", "type": "STRING"}]
    schema = {
        "node_props": {
            "Movie": movie_props,
            "Actor": name_props,
            "Person": [dict(p) for p in name_props],
        },
        "rel_props": {},
        "relationships": [
            {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
            {"start": "Person", "end": "Movie", "type": "DIRECTED"},
        ],
    }

    rendered = construct_schema(schema, ["Movie", "Actor"], [])

    assert rendered == (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "[]"
    )
def test_include_types3() -> None:
    """A relationship is dropped when its type is included but an endpoint is not.

    ``ACTED_IN`` is in the include list, but its start label ``Actor`` is not,
    so the relationship must be filtered out and only ``Movie`` remains.
    This case was previously not covered by the unit tests: the test body was
    an exact copy of ``test_include_types``.
    """
    structured_schema = {
        "node_props": {
            "Movie": [{"property": "title", "type": "STRING"}],
            "Actor": [{"property": "name", "type": "STRING"}],
            "Person": [{"property": "name", "type": "STRING"}],
        },
        "rel_props": {},
        "relationships": [
            {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
            {"start": "Person", "end": "Movie", "type": "DIRECTED"},
        ],
    }
    # Relationship type listed, start label ("Actor") deliberately omitted.
    include_types = ["Movie", "ACTED_IN"]
    output = construct_schema(structured_schema, include_types, [])
    expected_schema = (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "[]"
    )
    assert output == expected_schema