mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 19:18:53 +00:00
Filtering graph schema for Cypher generation (#10577)
Sometimes you don't want the LLM to be aware of the whole graph schema, and want it to ignore parts of the graph when it is constructing Cypher statements.
This commit is contained in:
parent
89ef440c14
commit
0625ab7a9e
@ -135,7 +135,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"print(graph.get_schema)"
|
"print(graph.schema)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -510,13 +510,54 @@
|
|||||||
"chain.run(\"Who played in Top Gun?\")"
|
"chain.run(\"Who played in Top Gun?\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "eefea16b-508f-4552-8942-9d5063ed7d37",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Ignore specified node and relationship types\n",
|
||||||
|
"You can use `include_types` or `exclude_types` to ignore parts of the graph schema when generating Cypher statements."
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 18,
|
||||||
"id": "48ff7cf8-18a3-43d7-8cb1-c1b91744608d",
|
"id": "a20fa21e-fb85-41c4-aac0-53fb25e34604",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": [
|
||||||
|
"chain = GraphCypherQAChain.from_llm(\n",
|
||||||
|
" graph=graph,\n",
|
||||||
|
" cypher_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"),\n",
|
||||||
|
" qa_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-16k\"),\n",
|
||||||
|
" verbose=True,\n",
|
||||||
|
" exclude_types=['Movie']\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 19,
|
||||||
|
"id": "3ad7f6b8-543e-46e4-a3b2-40fa3e66e895",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Node properties are the following: \n",
|
||||||
|
" {'Actor': [{'property': 'name', 'type': 'STRING'}]}\n",
|
||||||
|
"Relationships properties are the following: \n",
|
||||||
|
" {}\n",
|
||||||
|
"Relationships are: \n",
|
||||||
|
"[]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Inspect graph schema\n",
|
||||||
|
"print(chain.graph_schema)"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@ -187,7 +187,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"print(graph.get_schema)"
|
"print(graph.schema)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -687,7 +687,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.9.13"
|
"version": "3.8.8"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -34,12 +34,54 @@ def extract_cypher(text: str) -> str:
|
|||||||
return matches[0] if matches else text
|
return matches[0] if matches else text
|
||||||
|
|
||||||
|
|
||||||
|
def construct_schema(
    structured_schema: Dict[str, Any],
    include_types: Optional[List[str]] = None,
    exclude_types: Optional[List[str]] = None,
) -> str:
    """Render the graph schema as prompt text, restricted to the wanted types.

    Args:
        structured_schema: Mapping with the optional keys ``node_props``
            (node label -> property list), ``rel_props`` (relationship type ->
            property list) and ``relationships`` (dicts carrying ``start``,
            ``type`` and ``end``). Missing keys are treated as empty.
        include_types: Node labels / relationship types to keep. When
            non-empty it takes precedence and ``exclude_types`` is ignored.
        exclude_types: Node labels / relationship types to drop. Only
            consulted when ``include_types`` is empty or ``None``.

    Returns:
        The filtered schema formatted as a multi-line string.
    """
    # Build the lookup sets once so every membership test below is O(1);
    # ``None`` (or any empty iterable) simply means "no filter of this kind".
    included = frozenset(include_types or ())
    excluded = frozenset(exclude_types or ())

    def filter_func(x: str) -> bool:
        return x in included if included else x not in excluded

    node_props = {
        label: props
        for label, props in structured_schema.get("node_props", {}).items()
        if filter_func(label)
    }
    rel_props = {
        rel_type: props
        for rel_type, props in structured_schema.get("rel_props", {}).items()
        if filter_func(rel_type)
    }
    # A relationship survives only if its start label, end label AND type
    # all pass the filter.
    relationships = [
        rel
        for rel in structured_schema.get("relationships", [])
        if all(filter_func(rel[part]) for part in ("start", "end", "type"))
    ]

    rendered_relationships = [
        f"(:{rel['start']})-[:{rel['type']}]->(:{rel['end']})"
        for rel in relationships
    ]
    return (
        f"Node properties are the following: \n {node_props}\n"
        f"Relationships properties are the following: \n {rel_props}"
        "\nRelationships are: \n" + str(rendered_relationships)
    )
|
||||||
|
|
||||||
|
|
||||||
class GraphCypherQAChain(Chain):
|
class GraphCypherQAChain(Chain):
|
||||||
"""Chain for question-answering against a graph by generating Cypher statements."""
|
"""Chain for question-answering against a graph by generating Cypher statements."""
|
||||||
|
|
||||||
graph: Neo4jGraph = Field(exclude=True)
|
graph: Neo4jGraph = Field(exclude=True)
|
||||||
cypher_generation_chain: LLMChain
|
cypher_generation_chain: LLMChain
|
||||||
qa_chain: LLMChain
|
qa_chain: LLMChain
|
||||||
|
graph_schema: str
|
||||||
input_key: str = "query" #: :meta private:
|
input_key: str = "query" #: :meta private:
|
||||||
output_key: str = "result" #: :meta private:
|
output_key: str = "result" #: :meta private:
|
||||||
top_k: int = 10
|
top_k: int = 10
|
||||||
@ -79,6 +121,8 @@ class GraphCypherQAChain(Chain):
|
|||||||
cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
|
cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
|
||||||
cypher_llm: Optional[BaseLanguageModel] = None,
|
cypher_llm: Optional[BaseLanguageModel] = None,
|
||||||
qa_llm: Optional[BaseLanguageModel] = None,
|
qa_llm: Optional[BaseLanguageModel] = None,
|
||||||
|
exclude_types: List[str] = [],
|
||||||
|
include_types: List[str] = [],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> GraphCypherQAChain:
|
) -> GraphCypherQAChain:
|
||||||
"""Initialize from LLM."""
|
"""Initialize from LLM."""
|
||||||
@ -96,7 +140,18 @@ class GraphCypherQAChain(Chain):
|
|||||||
qa_chain = LLMChain(llm=qa_llm or llm, prompt=qa_prompt)
|
qa_chain = LLMChain(llm=qa_llm or llm, prompt=qa_prompt)
|
||||||
cypher_generation_chain = LLMChain(llm=cypher_llm or llm, prompt=cypher_prompt)
|
cypher_generation_chain = LLMChain(llm=cypher_llm or llm, prompt=cypher_prompt)
|
||||||
|
|
||||||
|
if exclude_types and include_types:
|
||||||
|
raise ValueError(
|
||||||
|
"Either `exclude_types` or `include_types` "
|
||||||
|
"can be provided, but not both"
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_schema = construct_schema(
|
||||||
|
kwargs["graph"].structured_schema, include_types, exclude_types
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
|
graph_schema=graph_schema,
|
||||||
qa_chain=qa_chain,
|
qa_chain=qa_chain,
|
||||||
cypher_generation_chain=cypher_generation_chain,
|
cypher_generation_chain=cypher_generation_chain,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -115,7 +170,7 @@ class GraphCypherQAChain(Chain):
|
|||||||
intermediate_steps: List = []
|
intermediate_steps: List = []
|
||||||
|
|
||||||
generated_cypher = self.cypher_generation_chain.run(
|
generated_cypher = self.cypher_generation_chain.run(
|
||||||
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
|
{"question": question, "schema": self.graph_schema}, callbacks=callbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract Cypher code if it is wrapped in backticks
|
# Extract Cypher code if it is wrapped in backticks
|
||||||
|
@ -6,6 +6,12 @@ YIELD *
|
|||||||
RETURN *
|
RETURN *
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Cypher text that calls `llm_util.schema("raw")`; MemgraphGraph runs it and
# stores the returned "schema" field as its `structured_schema`.
RAW_SCHEMA_QUERY = """
|
||||||
|
CALL llm_util.schema("raw")
|
||||||
|
YIELD *
|
||||||
|
RETURN *
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class MemgraphGraph(Neo4jGraph):
|
class MemgraphGraph(Neo4jGraph):
|
||||||
"""Memgraph wrapper for graph operations."""
|
"""Memgraph wrapper for graph operations."""
|
||||||
@ -24,3 +30,7 @@ class MemgraphGraph(Neo4jGraph):
|
|||||||
db_schema = self.query(SCHEMA_QUERY)[0].get("schema")
|
db_schema = self.query(SCHEMA_QUERY)[0].get("schema")
|
||||||
assert db_schema is not None
|
assert db_schema is not None
|
||||||
self.schema = db_schema
|
self.schema = db_schema
|
||||||
|
|
||||||
|
db_structured_schema = self.query(RAW_SCHEMA_QUERY)[0].get("schema")
|
||||||
|
assert db_structured_schema is not None
|
||||||
|
self.structured_schema = db_structured_schema
|
||||||
|
@ -24,7 +24,7 @@ CALL apoc.meta.data()
|
|||||||
YIELD label, other, elementType, type, property
|
YIELD label, other, elementType, type, property
|
||||||
WHERE type = "RELATIONSHIP" AND elementType = "node"
|
WHERE type = "RELATIONSHIP" AND elementType = "node"
|
||||||
UNWIND other AS other_node
|
UNWIND other AS other_node
|
||||||
RETURN "(:" + label + ")-[:" + property + "]->(:" + toString(other_node) + ")" AS output
|
RETURN {start: label, type: property, end: toString(other_node)} AS output
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -45,7 +45,8 @@ class Neo4jGraph:
|
|||||||
|
|
||||||
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
|
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
|
||||||
self._database = database
|
self._database = database
|
||||||
self.schema = ""
|
self.schema: str = ""
|
||||||
|
self.structured_schema: Dict[str, Any] = {}
|
||||||
# Verify connection
|
# Verify connection
|
||||||
try:
|
try:
|
||||||
self._driver.verify_connectivity()
|
self._driver.verify_connectivity()
|
||||||
@ -69,11 +70,6 @@ class Neo4jGraph:
|
|||||||
"'apoc.meta.data()' is allowed in Neo4j configuration "
|
"'apoc.meta.data()' is allowed in Neo4j configuration "
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def get_schema(self) -> str:
|
|
||||||
"""Returns the schema of the Neo4j database"""
|
|
||||||
return self.schema
|
|
||||||
|
|
||||||
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
|
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
|
||||||
"""Query Neo4j database."""
|
"""Query Neo4j database."""
|
||||||
from neo4j.exceptions import CypherSyntaxError
|
from neo4j.exceptions import CypherSyntaxError
|
||||||
@ -89,17 +85,22 @@ class Neo4jGraph:
|
|||||||
"""
|
"""
|
||||||
Refreshes the Neo4j graph schema information.
|
Refreshes the Neo4j graph schema information.
|
||||||
"""
|
"""
|
||||||
node_properties = self.query(node_properties_query)
|
node_properties = [el["output"] for el in self.query(node_properties_query)]
|
||||||
relationships_properties = self.query(rel_properties_query)
|
rel_properties = [el["output"] for el in self.query(rel_properties_query)]
|
||||||
relationships = self.query(rel_query)
|
relationships = [el["output"] for el in self.query(rel_query)]
|
||||||
|
|
||||||
|
self.structured_schema = {
|
||||||
|
"node_props": {el["labels"]: el["properties"] for el in node_properties},
|
||||||
|
"rel_props": {el["type"]: el["properties"] for el in rel_properties},
|
||||||
|
"relationships": relationships,
|
||||||
|
}
|
||||||
self.schema = f"""
|
self.schema = f"""
|
||||||
Node properties are the following:
|
Node properties are the following:
|
||||||
{[el['output'] for el in node_properties]}
|
{node_properties}
|
||||||
Relationship properties are the following:
|
Relationship properties are the following:
|
||||||
{[el['output'] for el in relationships_properties]}
|
{rel_properties}
|
||||||
The relationships are the following:
|
The relationships are the following:
|
||||||
{[el['output'] for el in relationships]}
|
{[f"(:{el['start']})-[:{el['type']}]->(:{el['end']})" for el in relationships]}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def add_graph_documents(
|
def add_graph_documents(
|
||||||
|
@ -211,21 +211,32 @@ def test_cypher_return_correct_schema() -> None:
|
|||||||
|
|
||||||
expected_node_properties = [
|
expected_node_properties = [
|
||||||
{
|
{
|
||||||
"properties": [{"property": "property_a", "type": "STRING"}],
|
"output": {
|
||||||
"labels": "LabelA",
|
"properties": [{"property": "property_a", "type": "STRING"}],
|
||||||
|
"labels": "LabelA",
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
expected_relationships_properties = [
|
expected_relationships_properties = [
|
||||||
{"type": "REL_TYPE", "properties": [{"property": "rel_prop", "type": "STRING"}]}
|
{
|
||||||
|
"output": {
|
||||||
|
"type": "REL_TYPE",
|
||||||
|
"properties": [{"property": "rel_prop", "type": "STRING"}],
|
||||||
|
}
|
||||||
|
}
|
||||||
]
|
]
|
||||||
expected_relationships = [
|
expected_relationships = [
|
||||||
"(:LabelA)-[:REL_TYPE]->(:LabelB)",
|
{"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"}},
|
||||||
"(:LabelA)-[:REL_TYPE]->(:LabelC)",
|
{"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"}},
|
||||||
]
|
]
|
||||||
|
|
||||||
assert node_properties == expected_node_properties
|
assert node_properties == expected_node_properties
|
||||||
assert relationships_properties == expected_relationships_properties
|
assert relationships_properties == expected_relationships_properties
|
||||||
assert relationships == expected_relationships
|
# Order is not guaranteed with Neo4j returns
|
||||||
|
assert (
|
||||||
|
sorted(relationships, key=lambda x: x["output"]["end"])
|
||||||
|
== expected_relationships
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_cypher_save_load() -> None:
|
def test_cypher_save_load() -> None:
|
||||||
@ -252,3 +263,122 @@ def test_cypher_save_load() -> None:
|
|||||||
qa_loaded = load_chain(FILE_PATH, graph=graph)
|
qa_loaded = load_chain(FILE_PATH, graph=graph)
|
||||||
|
|
||||||
assert qa_loaded == chain
|
assert qa_loaded == chain
|
||||||
|
|
||||||
|
|
||||||
|
def test_exclude_types() -> None:
    """Types listed in `exclude_types` must be absent from the chain's schema."""
    url = os.environ.get("NEO4J_URL")
    username = os.environ.get("NEO4J_USERNAME")
    password = os.environ.get("NEO4J_PASSWORD")
    for setting in (url, username, password):
        assert setting is not None

    neo4j_graph = Neo4jGraph(url=url, username=username, password=password)

    # Start from an empty database.
    neo4j_graph.query("MATCH (n) DETACH DELETE n")
    # Seed three nodes joined by two relationships:
    # (Actor)-[:ACTED_IN]->(Movie)<-[:DIRECTED]-(Person)
    neo4j_graph.query(
        "CREATE (a:Actor {name:'Bruce Willis'})"
        "-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
        "<-[:DIRECTED]-(p:Person {name:'John'})"
    )
    # Pick up the freshly created labels and relationship types.
    neo4j_graph.refresh_schema()

    qa_chain = GraphCypherQAChain.from_llm(
        OpenAI(temperature=0),
        graph=neo4j_graph,
        exclude_types=["Person", "DIRECTED"],
    )

    want = (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
    assert qa_chain.graph_schema == want
|
||||||
|
|
||||||
|
|
||||||
|
def test_include_types() -> None:
    """Only the types listed in `include_types` may appear in the chain's schema."""
    url = os.environ.get("NEO4J_URL")
    username = os.environ.get("NEO4J_USERNAME")
    password = os.environ.get("NEO4J_PASSWORD")
    for setting in (url, username, password):
        assert setting is not None

    neo4j_graph = Neo4jGraph(url=url, username=username, password=password)

    # Start from an empty database.
    neo4j_graph.query("MATCH (n) DETACH DELETE n")
    # Seed three nodes joined by two relationships:
    # (Actor)-[:ACTED_IN]->(Movie)<-[:DIRECTED]-(Person)
    neo4j_graph.query(
        "CREATE (a:Actor {name:'Bruce Willis'})"
        "-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
        "<-[:DIRECTED]-(p:Person {name:'John'})"
    )
    # Pick up the freshly created labels and relationship types.
    neo4j_graph.refresh_schema()

    qa_chain = GraphCypherQAChain.from_llm(
        OpenAI(temperature=0),
        graph=neo4j_graph,
        include_types=["Movie", "Actor", "ACTED_IN"],
    )

    want = (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
    assert qa_chain.graph_schema == want
|
||||||
|
|
||||||
|
|
||||||
|
def test_include_types2() -> None:
    """A relationship is dropped when one of its endpoint labels is not included."""
    url = os.environ.get("NEO4J_URL")
    username = os.environ.get("NEO4J_USERNAME")
    password = os.environ.get("NEO4J_PASSWORD")
    for setting in (url, username, password):
        assert setting is not None

    neo4j_graph = Neo4jGraph(url=url, username=username, password=password)

    # Start from an empty database.
    neo4j_graph.query("MATCH (n) DETACH DELETE n")
    # Seed three nodes joined by two relationships:
    # (Actor)-[:ACTED_IN]->(Movie)<-[:DIRECTED]-(Person)
    neo4j_graph.query(
        "CREATE (a:Actor {name:'Bruce Willis'})"
        "-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
        "<-[:DIRECTED]-(p:Person {name:'John'})"
    )
    # Pick up the freshly created labels and relationship types.
    neo4j_graph.refresh_schema()

    # "Actor" is deliberately left out, so ACTED_IN loses its start label.
    qa_chain = GraphCypherQAChain.from_llm(
        OpenAI(temperature=0),
        graph=neo4j_graph,
        include_types=["Movie", "ACTED_IN"],
    )

    want = (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "[]"
    )
    assert qa_chain.graph_schema == want
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from langchain.chains.graph_qa.cypher import extract_cypher
|
from langchain.chains.graph_qa.cypher import construct_schema, extract_cypher
|
||||||
|
|
||||||
|
|
||||||
def test_no_backticks() -> None:
|
def test_no_backticks() -> None:
|
||||||
@ -13,3 +13,107 @@ def test_backticks() -> None:
|
|||||||
query = "You can use the following query: ```MATCH (n) RETURN n```"
|
query = "You can use the following query: ```MATCH (n) RETURN n```"
|
||||||
output = extract_cypher(query)
|
output = extract_cypher(query)
|
||||||
assert output == "MATCH (n) RETURN n"
|
assert output == "MATCH (n) RETURN n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_exclude_types() -> None:
    """Excluded labels and relationship types vanish from the rendered schema."""
    name_prop = [{"property": "name", "type": "STRING"}]
    structured_schema = {
        "node_props": {
            "Movie": [{"property": "title", "type": "STRING"}],
            "Actor": name_prop,
            "Person": name_prop,
        },
        "rel_props": {},
        "relationships": [
            {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
            {"start": "Person", "end": "Movie", "type": "DIRECTED"},
        ],
    }

    # No include list; drop the Person label and the DIRECTED relationship.
    rendered = construct_schema(structured_schema, [], ["Person", "DIRECTED"])

    expected = (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
    assert rendered == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_include_types() -> None:
    """Only included labels and relationship types appear in the rendered schema."""
    name_prop = [{"property": "name", "type": "STRING"}]
    structured_schema = {
        "node_props": {
            "Movie": [{"property": "title", "type": "STRING"}],
            "Actor": name_prop,
            "Person": name_prop,
        },
        "rel_props": {},
        "relationships": [
            {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
            {"start": "Person", "end": "Movie", "type": "DIRECTED"},
        ],
    }

    # Both endpoints and the type of ACTED_IN are included, so it survives.
    rendered = construct_schema(
        structured_schema, ["Movie", "Actor", "ACTED_IN"], []
    )

    expected = (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
    assert rendered == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_include_types2() -> None:
    """Including only node labels leaves no relationship in the rendered schema."""
    name_prop = [{"property": "name", "type": "STRING"}]
    structured_schema = {
        "node_props": {
            "Movie": [{"property": "title", "type": "STRING"}],
            "Actor": name_prop,
            "Person": name_prop,
        },
        "rel_props": {},
        "relationships": [
            {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
            {"start": "Person", "end": "Movie", "type": "DIRECTED"},
        ],
    }

    # ACTED_IN itself is not in the include list, so the relationship is dropped
    # even though both of its endpoint labels are kept.
    rendered = construct_schema(structured_schema, ["Movie", "Actor"], [])

    expected = (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "[]"
    )
    assert rendered == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_include_types3() -> None:
    """Included labels plus ACTED_IN keep exactly one relationship.

    NOTE(review): the schema, include list and expectation here are the same
    as in `test_include_types`; consider varying the input.
    """
    name_prop = [{"property": "name", "type": "STRING"}]
    structured_schema = {
        "node_props": {
            "Movie": [{"property": "title", "type": "STRING"}],
            "Actor": name_prop,
            "Person": name_prop,
        },
        "rel_props": {},
        "relationships": [
            {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
            {"start": "Person", "end": "Movie", "type": "DIRECTED"},
        ],
    }

    rendered = construct_schema(
        structured_schema, ["Movie", "Actor", "ACTED_IN"], []
    )

    expected = (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
    assert rendered == expected
|
||||||
|
Loading…
Reference in New Issue
Block a user