From 520972fd0fb94ce47e6bd8576b552bd36dd495f8 Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Thu, 25 Apr 2024 20:30:22 +0200 Subject: [PATCH] community[patch]: Support passing graph object to Neo4j integrations (#20876) For driver connection reusage, we introduce passing the graph object to neo4j integrations --- .../chat_message_histories/neo4j.py | 57 ++++++++++------ .../vectorstores/neo4j_vector.py | 66 ++++++++++--------- .../chat_message_histories/test_neo4j.py | 66 +++++++++++++++++++ .../vectorstores/test_neo4jvector.py | 18 +++++ 4 files changed, 156 insertions(+), 51 deletions(-) create mode 100644 libs/community/tests/integration_tests/chat_message_histories/test_neo4j.py diff --git a/libs/community/langchain_community/chat_message_histories/neo4j.py b/libs/community/langchain_community/chat_message_histories/neo4j.py index d64b1e5ed6e..972f284b042 100644 --- a/libs/community/langchain_community/chat_message_histories/neo4j.py +++ b/libs/community/langchain_community/chat_message_histories/neo4j.py @@ -2,7 +2,9 @@ from typing import List, Optional, Union from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage, messages_from_dict -from langchain_core.utils import get_from_env +from langchain_core.utils import get_from_dict_or_env + +from langchain_community.graphs import Neo4jGraph class Neo4jChatMessageHistory(BaseChatMessageHistory): @@ -17,6 +19,8 @@ class Neo4jChatMessageHistory(BaseChatMessageHistory): database: str = "neo4j", node_label: str = "Session", window: int = 3, + *, + graph: Optional[Neo4jGraph] = None, ): try: import neo4j @@ -30,30 +34,41 @@ class Neo4jChatMessageHistory(BaseChatMessageHistory): if not session_id: raise ValueError("Please ensure that the session_id parameter is provided") - url = get_from_env("url", "NEO4J_URI", url) - username = get_from_env("username", "NEO4J_USERNAME", username) - password = get_from_env("password", "NEO4J_PASSWORD", password) - database = get_from_env("database", "NEO4J_DATABASE", database) + # Graph object takes precedent over env or input params + if graph: + self._driver = graph._driver + self._database = graph._database + else: + # Handle if the credentials are environment variables + url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI") + username = get_from_dict_or_env( + {"username": username}, "username", "NEO4J_USERNAME" + ) + password = get_from_dict_or_env( + {"password": password}, "password", "NEO4J_PASSWORD" + ) + database = get_from_dict_or_env( + {"database": database}, "database", "NEO4J_DATABASE", "neo4j" + ) - self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) - self._database = database + self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) + self._database = database + # Verify connection + try: + self._driver.verify_connectivity() + except neo4j.exceptions.ServiceUnavailable: + raise ValueError( + "Could not connect to Neo4j database. " + "Please ensure that the url is correct" + ) + except neo4j.exceptions.AuthError: + raise ValueError( + "Could not connect to Neo4j database. " + "Please ensure that the username and password are correct" + ) self._session_id = session_id self._node_label = node_label self._window = window - - # Verify connection - try: - self._driver.verify_connectivity() - except neo4j.exceptions.ServiceUnavailable: - raise ValueError( - "Could not connect to Neo4j database. " - "Please ensure that the url is correct" - ) - except neo4j.exceptions.AuthError: - raise ValueError( - "Could not connect to Neo4j database. " - "Please ensure that the username and password are correct" - ) # Create session node self._driver.execute_query( f"MERGE (s:`{self._node_label}` {{id:$session_id}})", diff --git a/libs/community/langchain_community/vectorstores/neo4j_vector.py b/libs/community/langchain_community/vectorstores/neo4j_vector.py index 3a9c57de42a..d09145b1b7f 100644 --- a/libs/community/langchain_community/vectorstores/neo4j_vector.py +++ b/libs/community/langchain_community/vectorstores/neo4j_vector.py @@ -20,6 +20,7 @@ from langchain_core.embeddings import Embeddings from langchain_core.utils import get_from_dict_or_env from langchain_core.vectorstores import VectorStore +from langchain_community.graphs import Neo4jGraph from langchain_community.vectorstores.utils import DistanceStrategy DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE @@ -483,6 +484,7 @@ class Neo4jVector(VectorStore): retrieval_query: str = "", relevance_score_fn: Optional[Callable[[float], float]] = None, index_type: IndexType = DEFAULT_INDEX_TYPE, + graph: Optional[Neo4jGraph] = None, ) -> None: try: import neo4j @@ -501,40 +503,44 @@ class Neo4jVector(VectorStore): "distance_strategy must be either 'EUCLIDEAN_DISTANCE' or 'COSINE'" ) - # Handle if the credentials are environment variables + # Graph object takes precedent over env or input params + if graph: + self._driver = graph._driver + self._database = graph._database + else: + # Handle if the credentials are environment variables + # Support URL for backwards compatibility + if not url: + url = os.environ.get("NEO4J_URL") - # Support URL for backwards compatibility - if not url: - url = os.environ.get("NEO4J_URL") + url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI") + username = get_from_dict_or_env( + {"username": username}, "username", "NEO4J_USERNAME" + ) + password = get_from_dict_or_env( + {"password": password}, "password", "NEO4J_PASSWORD" + ) + database = get_from_dict_or_env( + {"database": database}, "database", "NEO4J_DATABASE", "neo4j" + ) - url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI") - username = get_from_dict_or_env( - {"username": username}, "username", "NEO4J_USERNAME" - ) - password = get_from_dict_or_env( - {"password": password}, "password", "NEO4J_PASSWORD" - ) - database = get_from_dict_or_env( - {"database": database}, "database", "NEO4J_DATABASE", "neo4j" - ) + self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) + self._database = database + # Verify connection + try: + self._driver.verify_connectivity() + except neo4j.exceptions.ServiceUnavailable: + raise ValueError( + "Could not connect to Neo4j database. " + "Please ensure that the url is correct" + ) + except neo4j.exceptions.AuthError: + raise ValueError( + "Could not connect to Neo4j database. " + "Please ensure that the username and password are correct" + ) - self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) - self._database = database self.schema = "" - # Verify connection - try: - self._driver.verify_connectivity() - except neo4j.exceptions.ServiceUnavailable: - raise ValueError( - "Could not connect to Neo4j database. " - "Please ensure that the url is correct" - ) - except neo4j.exceptions.AuthError: - raise ValueError( - "Could not connect to Neo4j database. " - "Please ensure that the username and password are correct" - ) - # Verify if the version support vector index self._is_enterprise = False self.verify_version() diff --git a/libs/community/tests/integration_tests/chat_message_histories/test_neo4j.py b/libs/community/tests/integration_tests/chat_message_histories/test_neo4j.py new file mode 100644 index 00000000000..5ab1af54607 --- /dev/null +++ b/libs/community/tests/integration_tests/chat_message_histories/test_neo4j.py @@ -0,0 +1,66 @@ +import os + +from langchain_core.messages import AIMessage, HumanMessage + +from langchain_community.chat_message_histories import Neo4jChatMessageHistory +from langchain_community.graphs import Neo4jGraph + + +def test_add_messages() -> None: + """Basic testing: adding messages to the Neo4jChatMessageHistory.""" + assert os.environ.get("NEO4J_URI") is not None + assert os.environ.get("NEO4J_USERNAME") is not None + assert os.environ.get("NEO4J_PASSWORD") is not None + message_store = Neo4jChatMessageHistory("23334") + message_store.clear() + assert len(message_store.messages) == 0 + message_store.add_user_message("Hello! Language Chain!") + message_store.add_ai_message("Hi Guys!") + + # create another message store to check if the messages are stored correctly + message_store_another = Neo4jChatMessageHistory("46666") + message_store_another.clear() + assert len(message_store_another.messages) == 0 + message_store_another.add_user_message("Hello! Bot!") + message_store_another.add_ai_message("Hi there!") + message_store_another.add_user_message("How's this pr going?") + + # Now check if the messages are stored in the database correctly + assert len(message_store.messages) == 2 + assert isinstance(message_store.messages[0], HumanMessage) + assert isinstance(message_store.messages[1], AIMessage) + assert message_store.messages[0].content == "Hello! Language Chain!" + assert message_store.messages[1].content == "Hi Guys!" + + assert len(message_store_another.messages) == 3 + assert isinstance(message_store_another.messages[0], HumanMessage) + assert isinstance(message_store_another.messages[1], AIMessage) + assert isinstance(message_store_another.messages[2], HumanMessage) + assert message_store_another.messages[0].content == "Hello! Bot!" + assert message_store_another.messages[1].content == "Hi there!" + assert message_store_another.messages[2].content == "How's this pr going?" + + # Now clear the first history + message_store.clear() + assert len(message_store.messages) == 0 + assert len(message_store_another.messages) == 3 + message_store_another.clear() + assert len(message_store.messages) == 0 + assert len(message_store_another.messages) == 0 + + +def test_add_messages_graph_object() -> None: + """Basic testing: Passing driver through graph object.""" + assert os.environ.get("NEO4J_URI") is not None + assert os.environ.get("NEO4J_USERNAME") is not None + assert os.environ.get("NEO4J_PASSWORD") is not None + graph = Neo4jGraph() + # rewrite env for testing + os.environ["NEO4J_USERNAME"] = "foo" + message_store = Neo4jChatMessageHistory("23334", graph=graph) + message_store.clear() + assert len(message_store.messages) == 0 + message_store.add_user_message("Hello! Language Chain!") + message_store.add_ai_message("Hi Guys!") + # Now check if the messages are stored in the database correctly + assert len(message_store.messages) == 2 diff --git a/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py b/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py index 87db0737705..61c5e1117c2 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py +++ b/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, cast from langchain_core.documents import Document +from langchain_community.graphs import Neo4jGraph from langchain_community.vectorstores.neo4j_vector import ( Neo4jVector, SearchType, @@ -902,3 +903,20 @@ OPTIONS {indexConfig: { assert output == [Document(page_content="foo-text", metadata={"foo": "bar"})] drop_vector_indexes(docsearch) + + +def test_neo4jvector_passing_graph_object() -> None: + """Test end to end construction and search with passing graph object.""" + graph = Neo4jGraph() + # Rewrite env vars to make sure it fails if env is used + os.environ["NEO4J_URI"] = "foo" + docsearch = Neo4jVector.from_texts( + texts=texts, + embedding=FakeEmbeddingsWithOsDimension(), + graph=graph, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + drop_vector_indexes(docsearch)