mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 07:26:16 +00:00
community[patch]: Support passing graph object to Neo4j integrations (#20876)
For driver connection reusage, we introduce passing the graph object to neo4j integrations
This commit is contained in:
@@ -0,0 +1,66 @@
|
||||
import os
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from langchain_community.chat_message_histories import Neo4jChatMessageHistory
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
|
||||
|
||||
def test_add_messages() -> None:
|
||||
"""Basic testing: adding messages to the Neo4jChatMessageHistory."""
|
||||
assert os.environ.get("NEO4J_URI") is not None
|
||||
assert os.environ.get("NEO4J_USERNAME") is not None
|
||||
assert os.environ.get("NEO4J_PASSWORD") is not None
|
||||
message_store = Neo4jChatMessageHistory("23334")
|
||||
message_store.clear()
|
||||
assert len(message_store.messages) == 0
|
||||
message_store.add_user_message("Hello! Language Chain!")
|
||||
message_store.add_ai_message("Hi Guys!")
|
||||
|
||||
# create another message store to check if the messages are stored correctly
|
||||
message_store_another = Neo4jChatMessageHistory("46666")
|
||||
message_store_another.clear()
|
||||
assert len(message_store_another.messages) == 0
|
||||
message_store_another.add_user_message("Hello! Bot!")
|
||||
message_store_another.add_ai_message("Hi there!")
|
||||
message_store_another.add_user_message("How's this pr going?")
|
||||
|
||||
# Now check if the messages are stored in the database correctly
|
||||
assert len(message_store.messages) == 2
|
||||
assert isinstance(message_store.messages[0], HumanMessage)
|
||||
assert isinstance(message_store.messages[1], AIMessage)
|
||||
assert message_store.messages[0].content == "Hello! Language Chain!"
|
||||
assert message_store.messages[1].content == "Hi Guys!"
|
||||
|
||||
assert len(message_store_another.messages) == 3
|
||||
assert isinstance(message_store_another.messages[0], HumanMessage)
|
||||
assert isinstance(message_store_another.messages[1], AIMessage)
|
||||
assert isinstance(message_store_another.messages[2], HumanMessage)
|
||||
assert message_store_another.messages[0].content == "Hello! Bot!"
|
||||
assert message_store_another.messages[1].content == "Hi there!"
|
||||
assert message_store_another.messages[2].content == "How's this pr going?"
|
||||
|
||||
# Now clear the first history
|
||||
message_store.clear()
|
||||
assert len(message_store.messages) == 0
|
||||
assert len(message_store_another.messages) == 3
|
||||
message_store_another.clear()
|
||||
assert len(message_store.messages) == 0
|
||||
assert len(message_store_another.messages) == 0
|
||||
|
||||
|
||||
def test_add_messages_graph_object() -> None:
|
||||
"""Basic testing: Passing driver through graph object."""
|
||||
assert os.environ.get("NEO4J_URI") is not None
|
||||
assert os.environ.get("NEO4J_USERNAME") is not None
|
||||
assert os.environ.get("NEO4J_PASSWORD") is not None
|
||||
graph = Neo4jGraph()
|
||||
# rewrite env for testing
|
||||
os.environ["NEO4J_USERNAME"] = "foo"
|
||||
message_store = Neo4jChatMessageHistory("23334", graph=graph)
|
||||
message_store.clear()
|
||||
assert len(message_store.messages) == 0
|
||||
message_store.add_user_message("Hello! Language Chain!")
|
||||
message_store.add_ai_message("Hi Guys!")
|
||||
# Now check if the messages are stored in the database correctly
|
||||
assert len(message_store.messages) == 2
|
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, cast
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from langchain_community.vectorstores.neo4j_vector import (
|
||||
Neo4jVector,
|
||||
SearchType,
|
||||
@@ -902,3 +903,20 @@ OPTIONS {indexConfig: {
|
||||
assert output == [Document(page_content="foo-text", metadata={"foo": "bar"})]
|
||||
|
||||
drop_vector_indexes(docsearch)
|
||||
|
||||
|
||||
def test_neo4jvector_passing_graph_object() -> None:
|
||||
"""Test end to end construction and search with passing graph object."""
|
||||
graph = Neo4jGraph()
|
||||
# Rewrite env vars to make sure it fails if env is used
|
||||
os.environ["NEO4J_URI"] = "foo"
|
||||
docsearch = Neo4jVector.from_texts(
|
||||
texts=texts,
|
||||
embedding=FakeEmbeddingsWithOsDimension(),
|
||||
graph=graph,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
drop_vector_indexes(docsearch)
|
||||
|
Reference in New Issue
Block a user