mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 08:03:39 +00:00
community[patch]: Support passing graph object to Neo4j integrations (#20876)
For driver connection reusage, we introduce passing the graph object to neo4j integrations
This commit is contained in:
parent
748a6ae609
commit
520972fd0f
@ -2,7 +2,9 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
from langchain_core.chat_history import BaseChatMessageHistory
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||||||
from langchain_core.messages import BaseMessage, messages_from_dict
|
from langchain_core.messages import BaseMessage, messages_from_dict
|
||||||
from langchain_core.utils import get_from_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
from langchain_community.graphs import Neo4jGraph
|
||||||
|
|
||||||
|
|
||||||
class Neo4jChatMessageHistory(BaseChatMessageHistory):
|
class Neo4jChatMessageHistory(BaseChatMessageHistory):
|
||||||
@ -17,6 +19,8 @@ class Neo4jChatMessageHistory(BaseChatMessageHistory):
|
|||||||
database: str = "neo4j",
|
database: str = "neo4j",
|
||||||
node_label: str = "Session",
|
node_label: str = "Session",
|
||||||
window: int = 3,
|
window: int = 3,
|
||||||
|
*,
|
||||||
|
graph: Optional[Neo4jGraph] = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
import neo4j
|
import neo4j
|
||||||
@ -30,30 +34,41 @@ class Neo4jChatMessageHistory(BaseChatMessageHistory):
|
|||||||
if not session_id:
|
if not session_id:
|
||||||
raise ValueError("Please ensure that the session_id parameter is provided")
|
raise ValueError("Please ensure that the session_id parameter is provided")
|
||||||
|
|
||||||
url = get_from_env("url", "NEO4J_URI", url)
|
# Graph object takes precedent over env or input params
|
||||||
username = get_from_env("username", "NEO4J_USERNAME", username)
|
if graph:
|
||||||
password = get_from_env("password", "NEO4J_PASSWORD", password)
|
self._driver = graph._driver
|
||||||
database = get_from_env("database", "NEO4J_DATABASE", database)
|
self._database = graph._database
|
||||||
|
else:
|
||||||
|
# Handle if the credentials are environment variables
|
||||||
|
url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI")
|
||||||
|
username = get_from_dict_or_env(
|
||||||
|
{"username": username}, "username", "NEO4J_USERNAME"
|
||||||
|
)
|
||||||
|
password = get_from_dict_or_env(
|
||||||
|
{"password": password}, "password", "NEO4J_PASSWORD"
|
||||||
|
)
|
||||||
|
database = get_from_dict_or_env(
|
||||||
|
{"database": database}, "database", "NEO4J_DATABASE", "neo4j"
|
||||||
|
)
|
||||||
|
|
||||||
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
|
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
|
||||||
self._database = database
|
self._database = database
|
||||||
|
# Verify connection
|
||||||
|
try:
|
||||||
|
self._driver.verify_connectivity()
|
||||||
|
except neo4j.exceptions.ServiceUnavailable:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not connect to Neo4j database. "
|
||||||
|
"Please ensure that the url is correct"
|
||||||
|
)
|
||||||
|
except neo4j.exceptions.AuthError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not connect to Neo4j database. "
|
||||||
|
"Please ensure that the username and password are correct"
|
||||||
|
)
|
||||||
self._session_id = session_id
|
self._session_id = session_id
|
||||||
self._node_label = node_label
|
self._node_label = node_label
|
||||||
self._window = window
|
self._window = window
|
||||||
|
|
||||||
# Verify connection
|
|
||||||
try:
|
|
||||||
self._driver.verify_connectivity()
|
|
||||||
except neo4j.exceptions.ServiceUnavailable:
|
|
||||||
raise ValueError(
|
|
||||||
"Could not connect to Neo4j database. "
|
|
||||||
"Please ensure that the url is correct"
|
|
||||||
)
|
|
||||||
except neo4j.exceptions.AuthError:
|
|
||||||
raise ValueError(
|
|
||||||
"Could not connect to Neo4j database. "
|
|
||||||
"Please ensure that the username and password are correct"
|
|
||||||
)
|
|
||||||
# Create session node
|
# Create session node
|
||||||
self._driver.execute_query(
|
self._driver.execute_query(
|
||||||
f"MERGE (s:`{self._node_label}` {{id:$session_id}})",
|
f"MERGE (s:`{self._node_label}` {{id:$session_id}})",
|
||||||
|
@ -20,6 +20,7 @@ from langchain_core.embeddings import Embeddings
|
|||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
|
||||||
|
from langchain_community.graphs import Neo4jGraph
|
||||||
from langchain_community.vectorstores.utils import DistanceStrategy
|
from langchain_community.vectorstores.utils import DistanceStrategy
|
||||||
|
|
||||||
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
|
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
|
||||||
@ -483,6 +484,7 @@ class Neo4jVector(VectorStore):
|
|||||||
retrieval_query: str = "",
|
retrieval_query: str = "",
|
||||||
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
||||||
index_type: IndexType = DEFAULT_INDEX_TYPE,
|
index_type: IndexType = DEFAULT_INDEX_TYPE,
|
||||||
|
graph: Optional[Neo4jGraph] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
import neo4j
|
import neo4j
|
||||||
@ -501,40 +503,44 @@ class Neo4jVector(VectorStore):
|
|||||||
"distance_strategy must be either 'EUCLIDEAN_DISTANCE' or 'COSINE'"
|
"distance_strategy must be either 'EUCLIDEAN_DISTANCE' or 'COSINE'"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle if the credentials are environment variables
|
# Graph object takes precedent over env or input params
|
||||||
|
if graph:
|
||||||
|
self._driver = graph._driver
|
||||||
|
self._database = graph._database
|
||||||
|
else:
|
||||||
|
# Handle if the credentials are environment variables
|
||||||
|
# Support URL for backwards compatibility
|
||||||
|
if not url:
|
||||||
|
url = os.environ.get("NEO4J_URL")
|
||||||
|
|
||||||
# Support URL for backwards compatibility
|
url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI")
|
||||||
if not url:
|
username = get_from_dict_or_env(
|
||||||
url = os.environ.get("NEO4J_URL")
|
{"username": username}, "username", "NEO4J_USERNAME"
|
||||||
|
)
|
||||||
|
password = get_from_dict_or_env(
|
||||||
|
{"password": password}, "password", "NEO4J_PASSWORD"
|
||||||
|
)
|
||||||
|
database = get_from_dict_or_env(
|
||||||
|
{"database": database}, "database", "NEO4J_DATABASE", "neo4j"
|
||||||
|
)
|
||||||
|
|
||||||
url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI")
|
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
|
||||||
username = get_from_dict_or_env(
|
self._database = database
|
||||||
{"username": username}, "username", "NEO4J_USERNAME"
|
# Verify connection
|
||||||
)
|
try:
|
||||||
password = get_from_dict_or_env(
|
self._driver.verify_connectivity()
|
||||||
{"password": password}, "password", "NEO4J_PASSWORD"
|
except neo4j.exceptions.ServiceUnavailable:
|
||||||
)
|
raise ValueError(
|
||||||
database = get_from_dict_or_env(
|
"Could not connect to Neo4j database. "
|
||||||
{"database": database}, "database", "NEO4J_DATABASE", "neo4j"
|
"Please ensure that the url is correct"
|
||||||
)
|
)
|
||||||
|
except neo4j.exceptions.AuthError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not connect to Neo4j database. "
|
||||||
|
"Please ensure that the username and password are correct"
|
||||||
|
)
|
||||||
|
|
||||||
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
|
|
||||||
self._database = database
|
|
||||||
self.schema = ""
|
self.schema = ""
|
||||||
# Verify connection
|
|
||||||
try:
|
|
||||||
self._driver.verify_connectivity()
|
|
||||||
except neo4j.exceptions.ServiceUnavailable:
|
|
||||||
raise ValueError(
|
|
||||||
"Could not connect to Neo4j database. "
|
|
||||||
"Please ensure that the url is correct"
|
|
||||||
)
|
|
||||||
except neo4j.exceptions.AuthError:
|
|
||||||
raise ValueError(
|
|
||||||
"Could not connect to Neo4j database. "
|
|
||||||
"Please ensure that the username and password are correct"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify if the version support vector index
|
# Verify if the version support vector index
|
||||||
self._is_enterprise = False
|
self._is_enterprise = False
|
||||||
self.verify_version()
|
self.verify_version()
|
||||||
|
@ -0,0 +1,66 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
from langchain_community.chat_message_histories import Neo4jChatMessageHistory
|
||||||
|
from langchain_community.graphs import Neo4jGraph
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_messages() -> None:
|
||||||
|
"""Basic testing: adding messages to the Neo4jChatMessageHistory."""
|
||||||
|
assert os.environ.get("NEO4J_URI") is not None
|
||||||
|
assert os.environ.get("NEO4J_USERNAME") is not None
|
||||||
|
assert os.environ.get("NEO4J_PASSWORD") is not None
|
||||||
|
message_store = Neo4jChatMessageHistory("23334")
|
||||||
|
message_store.clear()
|
||||||
|
assert len(message_store.messages) == 0
|
||||||
|
message_store.add_user_message("Hello! Language Chain!")
|
||||||
|
message_store.add_ai_message("Hi Guys!")
|
||||||
|
|
||||||
|
# create another message store to check if the messages are stored correctly
|
||||||
|
message_store_another = Neo4jChatMessageHistory("46666")
|
||||||
|
message_store_another.clear()
|
||||||
|
assert len(message_store_another.messages) == 0
|
||||||
|
message_store_another.add_user_message("Hello! Bot!")
|
||||||
|
message_store_another.add_ai_message("Hi there!")
|
||||||
|
message_store_another.add_user_message("How's this pr going?")
|
||||||
|
|
||||||
|
# Now check if the messages are stored in the database correctly
|
||||||
|
assert len(message_store.messages) == 2
|
||||||
|
assert isinstance(message_store.messages[0], HumanMessage)
|
||||||
|
assert isinstance(message_store.messages[1], AIMessage)
|
||||||
|
assert message_store.messages[0].content == "Hello! Language Chain!"
|
||||||
|
assert message_store.messages[1].content == "Hi Guys!"
|
||||||
|
|
||||||
|
assert len(message_store_another.messages) == 3
|
||||||
|
assert isinstance(message_store_another.messages[0], HumanMessage)
|
||||||
|
assert isinstance(message_store_another.messages[1], AIMessage)
|
||||||
|
assert isinstance(message_store_another.messages[2], HumanMessage)
|
||||||
|
assert message_store_another.messages[0].content == "Hello! Bot!"
|
||||||
|
assert message_store_another.messages[1].content == "Hi there!"
|
||||||
|
assert message_store_another.messages[2].content == "How's this pr going?"
|
||||||
|
|
||||||
|
# Now clear the first history
|
||||||
|
message_store.clear()
|
||||||
|
assert len(message_store.messages) == 0
|
||||||
|
assert len(message_store_another.messages) == 3
|
||||||
|
message_store_another.clear()
|
||||||
|
assert len(message_store.messages) == 0
|
||||||
|
assert len(message_store_another.messages) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_messages_graph_object() -> None:
|
||||||
|
"""Basic testing: Passing driver through graph object."""
|
||||||
|
assert os.environ.get("NEO4J_URI") is not None
|
||||||
|
assert os.environ.get("NEO4J_USERNAME") is not None
|
||||||
|
assert os.environ.get("NEO4J_PASSWORD") is not None
|
||||||
|
graph = Neo4jGraph()
|
||||||
|
# rewrite env for testing
|
||||||
|
os.environ["NEO4J_USERNAME"] = "foo"
|
||||||
|
message_store = Neo4jChatMessageHistory("23334", graph=graph)
|
||||||
|
message_store.clear()
|
||||||
|
assert len(message_store.messages) == 0
|
||||||
|
message_store.add_user_message("Hello! Language Chain!")
|
||||||
|
message_store.add_ai_message("Hi Guys!")
|
||||||
|
# Now check if the messages are stored in the database correctly
|
||||||
|
assert len(message_store.messages) == 2
|
@ -4,6 +4,7 @@ from typing import Any, Dict, List, cast
|
|||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
from langchain_community.graphs import Neo4jGraph
|
||||||
from langchain_community.vectorstores.neo4j_vector import (
|
from langchain_community.vectorstores.neo4j_vector import (
|
||||||
Neo4jVector,
|
Neo4jVector,
|
||||||
SearchType,
|
SearchType,
|
||||||
@ -902,3 +903,20 @@ OPTIONS {indexConfig: {
|
|||||||
assert output == [Document(page_content="foo-text", metadata={"foo": "bar"})]
|
assert output == [Document(page_content="foo-text", metadata={"foo": "bar"})]
|
||||||
|
|
||||||
drop_vector_indexes(docsearch)
|
drop_vector_indexes(docsearch)
|
||||||
|
|
||||||
|
|
||||||
|
def test_neo4jvector_passing_graph_object() -> None:
|
||||||
|
"""Test end to end construction and search with passing graph object."""
|
||||||
|
graph = Neo4jGraph()
|
||||||
|
# Rewrite env vars to make sure it fails if env is used
|
||||||
|
os.environ["NEO4J_URI"] = "foo"
|
||||||
|
docsearch = Neo4jVector.from_texts(
|
||||||
|
texts=texts,
|
||||||
|
embedding=FakeEmbeddingsWithOsDimension(),
|
||||||
|
graph=graph,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search("foo", k=1)
|
||||||
|
assert output == [Document(page_content="foo")]
|
||||||
|
|
||||||
|
drop_vector_indexes(docsearch)
|
||||||
|
Loading…
Reference in New Issue
Block a user