mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 02:03:44 +00:00
Sanitize backticks from neo4j labels and types for import (#23367)
This commit is contained in:
parent
d2db561347
commit
aeeda370aa
@ -287,6 +287,10 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_backticks(text: str) -> str:
|
||||||
|
return text.replace("`", "")
|
||||||
|
|
||||||
|
|
||||||
class Neo4jGraph(GraphStore):
|
class Neo4jGraph(GraphStore):
|
||||||
"""Neo4j database wrapper for various graph operations.
|
"""Neo4j database wrapper for various graph operations.
|
||||||
|
|
||||||
@ -571,6 +575,9 @@ class Neo4jGraph(GraphStore):
|
|||||||
document.source.page_content.encode("utf-8")
|
document.source.page_content.encode("utf-8")
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
|
# Remove backticks from node types
|
||||||
|
for node in document.nodes:
|
||||||
|
node.type = _remove_backticks(node.type)
|
||||||
# Import nodes
|
# Import nodes
|
||||||
self.query(
|
self.query(
|
||||||
node_import_query,
|
node_import_query,
|
||||||
@ -586,10 +593,12 @@ class Neo4jGraph(GraphStore):
|
|||||||
"data": [
|
"data": [
|
||||||
{
|
{
|
||||||
"source": el.source.id,
|
"source": el.source.id,
|
||||||
"source_label": el.source.type,
|
"source_label": _remove_backticks(el.source.type),
|
||||||
"target": el.target.id,
|
"target": el.target.id,
|
||||||
"target_label": el.target.type,
|
"target_label": _remove_backticks(el.target.type),
|
||||||
"type": el.type.replace(" ", "_").upper(),
|
"type": _remove_backticks(
|
||||||
|
el.type.replace(" ", "_").upper()
|
||||||
|
),
|
||||||
"properties": el.properties,
|
"properties": el.properties,
|
||||||
}
|
}
|
||||||
for el in document.relationships
|
for el in document.relationships
|
||||||
|
@ -25,6 +25,20 @@ test_data = [
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
test_data_backticks = [
|
||||||
|
GraphDocument(
|
||||||
|
nodes=[Node(id="foo", type="foo`"), Node(id="bar", type="`bar")],
|
||||||
|
relationships=[
|
||||||
|
Relationship(
|
||||||
|
source=Node(id="foo", type="f`oo"),
|
||||||
|
target=Node(id="bar", type="ba`r"),
|
||||||
|
type="`REL`",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
source=Document(page_content="source document"),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_cypher_return_correct_schema() -> None:
|
def test_cypher_return_correct_schema() -> None:
|
||||||
"""Test that chain returns direct results."""
|
"""Test that chain returns direct results."""
|
||||||
@ -363,3 +377,24 @@ def test_enhanced_schema_exception() -> None:
|
|||||||
# remove metadata portion of schema
|
# remove metadata portion of schema
|
||||||
del graph.structured_schema["metadata"]
|
del graph.structured_schema["metadata"]
|
||||||
assert graph.structured_schema == expected_output
|
assert graph.structured_schema == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_backticks() -> None:
|
||||||
|
"""Test that backticks are correctly removed."""
|
||||||
|
url = os.environ.get("NEO4J_URI")
|
||||||
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
|
assert url is not None
|
||||||
|
assert username is not None
|
||||||
|
assert password is not None
|
||||||
|
|
||||||
|
graph = Neo4jGraph(url=url, username=username, password=password)
|
||||||
|
graph.query("MATCH (n) DETACH DELETE n")
|
||||||
|
graph.add_graph_documents(test_data_backticks)
|
||||||
|
nodes = graph.query("MATCH (n) RETURN labels(n) AS labels ORDER BY n.id")
|
||||||
|
rels = graph.query("MATCH ()-[r]->() RETURN type(r) AS type")
|
||||||
|
expected_nodes = [{"labels": ["bar"]}, {"labels": ["foo"]}]
|
||||||
|
expected_rels = [{"type": "REL"}]
|
||||||
|
|
||||||
|
assert nodes == expected_nodes
|
||||||
|
assert rels == expected_rels
|
||||||
|
Loading…
Reference in New Issue
Block a user