mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 09:48:04 +00:00
Sanitize backticks from neo4j labels and types for import (#23367)
This commit is contained in:
parent
d2db561347
commit
aeeda370aa
@ -287,6 +287,10 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _remove_backticks(text: str) -> str:
|
||||
return text.replace("`", "")
|
||||
|
||||
|
||||
class Neo4jGraph(GraphStore):
|
||||
"""Neo4j database wrapper for various graph operations.
|
||||
|
||||
@ -571,6 +575,9 @@ class Neo4jGraph(GraphStore):
|
||||
document.source.page_content.encode("utf-8")
|
||||
).hexdigest()
|
||||
|
||||
# Remove backticks from node types
|
||||
for node in document.nodes:
|
||||
node.type = _remove_backticks(node.type)
|
||||
# Import nodes
|
||||
self.query(
|
||||
node_import_query,
|
||||
@ -586,10 +593,12 @@ class Neo4jGraph(GraphStore):
|
||||
"data": [
|
||||
{
|
||||
"source": el.source.id,
|
||||
"source_label": el.source.type,
|
||||
"source_label": _remove_backticks(el.source.type),
|
||||
"target": el.target.id,
|
||||
"target_label": el.target.type,
|
||||
"type": el.type.replace(" ", "_").upper(),
|
||||
"target_label": _remove_backticks(el.target.type),
|
||||
"type": _remove_backticks(
|
||||
el.type.replace(" ", "_").upper()
|
||||
),
|
||||
"properties": el.properties,
|
||||
}
|
||||
for el in document.relationships
|
||||
|
@ -25,6 +25,20 @@ test_data = [
|
||||
)
|
||||
]
|
||||
|
||||
test_data_backticks = [
|
||||
GraphDocument(
|
||||
nodes=[Node(id="foo", type="foo`"), Node(id="bar", type="`bar")],
|
||||
relationships=[
|
||||
Relationship(
|
||||
source=Node(id="foo", type="f`oo"),
|
||||
target=Node(id="bar", type="ba`r"),
|
||||
type="`REL`",
|
||||
)
|
||||
],
|
||||
source=Document(page_content="source document"),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_cypher_return_correct_schema() -> None:
|
||||
"""Test that chain returns direct results."""
|
||||
@ -363,3 +377,24 @@ def test_enhanced_schema_exception() -> None:
|
||||
# remove metadata portion of schema
|
||||
del graph.structured_schema["metadata"]
|
||||
assert graph.structured_schema == expected_output
|
||||
|
||||
|
||||
def test_backticks() -> None:
|
||||
"""Test that backticks are correctly removed."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(url=url, username=username, password=password)
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
graph.add_graph_documents(test_data_backticks)
|
||||
nodes = graph.query("MATCH (n) RETURN labels(n) AS labels ORDER BY n.id")
|
||||
rels = graph.query("MATCH ()-[r]->() RETURN type(r) AS type")
|
||||
expected_nodes = [{"labels": ["bar"]}, {"labels": ["foo"]}]
|
||||
expected_rels = [{"type": "REL"}]
|
||||
|
||||
assert nodes == expected_nodes
|
||||
assert rels == expected_rels
|
||||
|
Loading…
Reference in New Issue
Block a user