diff --git a/libs/langchain/langchain/graphs/__init__.py b/libs/langchain/langchain/graphs/__init__.py index e0b2d639eac..9699750db2f 100644 --- a/libs/langchain/langchain/graphs/__init__.py +++ b/libs/langchain/langchain/graphs/__init__.py @@ -3,6 +3,7 @@ from langchain.graphs.arangodb_graph import ArangoGraph from langchain.graphs.hugegraph import HugeGraph from langchain.graphs.kuzu_graph import KuzuGraph +from langchain.graphs.memgraph_graph import MemgraphGraph from langchain.graphs.nebula_graph import NebulaGraph from langchain.graphs.neo4j_graph import Neo4jGraph from langchain.graphs.neptune_graph import NeptuneGraph @@ -10,6 +11,7 @@ from langchain.graphs.networkx_graph import NetworkxEntityGraph from langchain.graphs.rdf_graph import RdfGraph __all__ = [ + "MemgraphGraph", "NetworkxEntityGraph", "Neo4jGraph", "NebulaGraph", diff --git a/libs/langchain/langchain/graphs/memgraph_graph.py b/libs/langchain/langchain/graphs/memgraph_graph.py new file mode 100644 index 00000000000..1a2412dd0eb --- /dev/null +++ b/libs/langchain/langchain/graphs/memgraph_graph.py @@ -0,0 +1,26 @@ +from langchain.graphs.neo4j_graph import Neo4jGraph + +SCHEMA_QUERY = """ +CALL llm_util.schema("prompt_ready") +YIELD * +RETURN * +""" + + +class MemgraphGraph(Neo4jGraph): + """Memgraph wrapper for graph operations.""" + + def __init__( + self, url: str, username: str, password: str, *, database: str = "memgraph" + ) -> None: + """Create a new Memgraph graph wrapper instance.""" + super().__init__(url, username, password, database=database) + + def refresh_schema(self) -> None: + """ + Refreshes the Memgraph graph schema information. + """ + + db_schema = self.query(SCHEMA_QUERY)[0].get("schema") + assert db_schema is not None + self.schema = db_schema diff --git a/libs/langchain/langchain/graphs/neo4j_graph.py b/libs/langchain/langchain/graphs/neo4j_graph.py index c4450043b1d..02572b2d1a1 100644 --- a/libs/langchain/langchain/graphs/neo4j_graph.py +++ b/libs/langchain/langchain/graphs/neo4j_graph.py @@ -81,7 +81,7 @@ class Neo4jGraph: data = session.run(query, params) return [r.data() for r in data] except CypherSyntaxError as e: - raise ValueError("Generated Cypher Statement is not valid\n" f"{e}") + raise ValueError(f"Generated Cypher Statement is not valid\n{e}") def refresh_schema(self) -> None: """