mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
Add neo4j graph environment variables (#12080)
This commit is contained in:
parent
d5400f6502
commit
82f4c0589c
@ -1,7 +1,8 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.graphs.graph_document import GraphDocument
|
||||
from langchain.graphs.graph_store import GraphStore
|
||||
from langchain.utils import get_from_env
|
||||
|
||||
node_properties_query = """
|
||||
CALL apoc.meta.data()
|
||||
@ -45,7 +46,11 @@ class Neo4jGraph(GraphStore):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, url: str, username: str, password: str, database: str = "neo4j"
|
||||
self,
|
||||
url: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
database: str = "neo4j",
|
||||
) -> None:
|
||||
"""Create a new Neo4j graph wrapper instance."""
|
||||
try:
|
||||
@ -56,6 +61,11 @@ class Neo4jGraph(GraphStore):
|
||||
"Please install it with `pip install neo4j`."
|
||||
)
|
||||
|
||||
url = get_from_env("url", "NEO4J_URI", url)
|
||||
username = get_from_env("username", "NEO4J_USERNAME", username)
|
||||
password = get_from_env("password", "NEO4J_PASSWORD", password)
|
||||
database = get_from_env("database", "NEO4J_DATABASE", database)
|
||||
|
||||
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
|
||||
self._database = database
|
||||
self.schema: str = ""
|
||||
|
@ -14,7 +14,7 @@ from langchain.llms.openai import OpenAI
|
||||
|
||||
def test_connect_neo4j() -> None:
|
||||
"""Test that Neo4j database is correctly instantiated and connected."""
|
||||
url = os.environ.get("NEO4J_URL")
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
@ -36,9 +36,22 @@ def test_connect_neo4j() -> None:
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_connect_neo4j_env() -> None:
|
||||
"""Test that Neo4j database environment variables."""
|
||||
graph = Neo4jGraph()
|
||||
|
||||
output = graph.query(
|
||||
"""
|
||||
RETURN "test" AS output
|
||||
"""
|
||||
)
|
||||
expected_output = [{"output": "test"}]
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_cypher_generating_run() -> None:
|
||||
"""Test that Cypher statement is correctly generated and executed."""
|
||||
url = os.environ.get("NEO4J_URL")
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
@ -68,7 +81,7 @@ def test_cypher_generating_run() -> None:
|
||||
|
||||
def test_cypher_top_k() -> None:
|
||||
"""Test top_k parameter correctly limits the number of results in the context."""
|
||||
url = os.environ.get("NEO4J_URL")
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
@ -102,7 +115,7 @@ def test_cypher_top_k() -> None:
|
||||
|
||||
def test_cypher_intermediate_steps() -> None:
|
||||
"""Test the returning of the intermediate steps."""
|
||||
url = os.environ.get("NEO4J_URL")
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
@ -146,7 +159,7 @@ def test_cypher_intermediate_steps() -> None:
|
||||
|
||||
def test_cypher_return_direct() -> None:
|
||||
"""Test that chain returns direct results."""
|
||||
url = os.environ.get("NEO4J_URL")
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
@ -178,7 +191,7 @@ def test_cypher_return_direct() -> None:
|
||||
|
||||
def test_cypher_return_correct_schema() -> None:
|
||||
"""Test that chain returns direct results."""
|
||||
url = os.environ.get("NEO4J_URL")
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
@ -243,7 +256,7 @@ def test_cypher_save_load() -> None:
|
||||
"""Test saving and loading."""
|
||||
|
||||
FILE_PATH = "cypher.yaml"
|
||||
url = os.environ.get("NEO4J_URL")
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
@ -267,7 +280,7 @@ def test_cypher_save_load() -> None:
|
||||
|
||||
def test_exclude_types() -> None:
|
||||
"""Test exclude types from schema."""
|
||||
url = os.environ.get("NEO4J_URL")
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
@ -307,7 +320,7 @@ def test_exclude_types() -> None:
|
||||
|
||||
def test_include_types() -> None:
|
||||
"""Test include types from schema."""
|
||||
url = os.environ.get("NEO4J_URL")
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
@ -347,7 +360,7 @@ def test_include_types() -> None:
|
||||
|
||||
def test_include_types2() -> None:
|
||||
"""Test include types from schema."""
|
||||
url = os.environ.get("NEO4J_URL")
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
|
Loading…
Reference in New Issue
Block a user