mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 23:00:00 +00:00
Add neo4j graph environment variables (#12080)
This commit is contained in:
parent
d5400f6502
commit
82f4c0589c
@ -1,7 +1,8 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.graphs.graph_document import GraphDocument
|
from langchain.graphs.graph_document import GraphDocument
|
||||||
from langchain.graphs.graph_store import GraphStore
|
from langchain.graphs.graph_store import GraphStore
|
||||||
|
from langchain.utils import get_from_env
|
||||||
|
|
||||||
node_properties_query = """
|
node_properties_query = """
|
||||||
CALL apoc.meta.data()
|
CALL apoc.meta.data()
|
||||||
@ -45,7 +46,11 @@ class Neo4jGraph(GraphStore):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, url: str, username: str, password: str, database: str = "neo4j"
|
self,
|
||||||
|
url: Optional[str] = None,
|
||||||
|
username: Optional[str] = None,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
database: str = "neo4j",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a new Neo4j graph wrapper instance."""
|
"""Create a new Neo4j graph wrapper instance."""
|
||||||
try:
|
try:
|
||||||
@ -56,6 +61,11 @@ class Neo4jGraph(GraphStore):
|
|||||||
"Please install it with `pip install neo4j`."
|
"Please install it with `pip install neo4j`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
url = get_from_env("url", "NEO4J_URI", url)
|
||||||
|
username = get_from_env("username", "NEO4J_USERNAME", username)
|
||||||
|
password = get_from_env("password", "NEO4J_PASSWORD", password)
|
||||||
|
database = get_from_env("database", "NEO4J_DATABASE", database)
|
||||||
|
|
||||||
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
|
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
|
||||||
self._database = database
|
self._database = database
|
||||||
self.schema: str = ""
|
self.schema: str = ""
|
||||||
|
@ -14,7 +14,7 @@ from langchain.llms.openai import OpenAI
|
|||||||
|
|
||||||
def test_connect_neo4j() -> None:
|
def test_connect_neo4j() -> None:
|
||||||
"""Test that Neo4j database is correctly instantiated and connected."""
|
"""Test that Neo4j database is correctly instantiated and connected."""
|
||||||
url = os.environ.get("NEO4J_URL")
|
url = os.environ.get("NEO4J_URI")
|
||||||
username = os.environ.get("NEO4J_USERNAME")
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
password = os.environ.get("NEO4J_PASSWORD")
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
assert url is not None
|
assert url is not None
|
||||||
@ -36,9 +36,22 @@ def test_connect_neo4j() -> None:
|
|||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_neo4j_env() -> None:
|
||||||
|
"""Test that Neo4j database environment variables."""
|
||||||
|
graph = Neo4jGraph()
|
||||||
|
|
||||||
|
output = graph.query(
|
||||||
|
"""
|
||||||
|
RETURN "test" AS output
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
expected_output = [{"output": "test"}]
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test_cypher_generating_run() -> None:
|
def test_cypher_generating_run() -> None:
|
||||||
"""Test that Cypher statement is correctly generated and executed."""
|
"""Test that Cypher statement is correctly generated and executed."""
|
||||||
url = os.environ.get("NEO4J_URL")
|
url = os.environ.get("NEO4J_URI")
|
||||||
username = os.environ.get("NEO4J_USERNAME")
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
password = os.environ.get("NEO4J_PASSWORD")
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
assert url is not None
|
assert url is not None
|
||||||
@ -68,7 +81,7 @@ def test_cypher_generating_run() -> None:
|
|||||||
|
|
||||||
def test_cypher_top_k() -> None:
|
def test_cypher_top_k() -> None:
|
||||||
"""Test top_k parameter correctly limits the number of results in the context."""
|
"""Test top_k parameter correctly limits the number of results in the context."""
|
||||||
url = os.environ.get("NEO4J_URL")
|
url = os.environ.get("NEO4J_URI")
|
||||||
username = os.environ.get("NEO4J_USERNAME")
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
password = os.environ.get("NEO4J_PASSWORD")
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
assert url is not None
|
assert url is not None
|
||||||
@ -102,7 +115,7 @@ def test_cypher_top_k() -> None:
|
|||||||
|
|
||||||
def test_cypher_intermediate_steps() -> None:
|
def test_cypher_intermediate_steps() -> None:
|
||||||
"""Test the returning of the intermediate steps."""
|
"""Test the returning of the intermediate steps."""
|
||||||
url = os.environ.get("NEO4J_URL")
|
url = os.environ.get("NEO4J_URI")
|
||||||
username = os.environ.get("NEO4J_USERNAME")
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
password = os.environ.get("NEO4J_PASSWORD")
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
assert url is not None
|
assert url is not None
|
||||||
@ -146,7 +159,7 @@ def test_cypher_intermediate_steps() -> None:
|
|||||||
|
|
||||||
def test_cypher_return_direct() -> None:
|
def test_cypher_return_direct() -> None:
|
||||||
"""Test that chain returns direct results."""
|
"""Test that chain returns direct results."""
|
||||||
url = os.environ.get("NEO4J_URL")
|
url = os.environ.get("NEO4J_URI")
|
||||||
username = os.environ.get("NEO4J_USERNAME")
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
password = os.environ.get("NEO4J_PASSWORD")
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
assert url is not None
|
assert url is not None
|
||||||
@ -178,7 +191,7 @@ def test_cypher_return_direct() -> None:
|
|||||||
|
|
||||||
def test_cypher_return_correct_schema() -> None:
|
def test_cypher_return_correct_schema() -> None:
|
||||||
"""Test that chain returns direct results."""
|
"""Test that chain returns direct results."""
|
||||||
url = os.environ.get("NEO4J_URL")
|
url = os.environ.get("NEO4J_URI")
|
||||||
username = os.environ.get("NEO4J_USERNAME")
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
password = os.environ.get("NEO4J_PASSWORD")
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
assert url is not None
|
assert url is not None
|
||||||
@ -243,7 +256,7 @@ def test_cypher_save_load() -> None:
|
|||||||
"""Test saving and loading."""
|
"""Test saving and loading."""
|
||||||
|
|
||||||
FILE_PATH = "cypher.yaml"
|
FILE_PATH = "cypher.yaml"
|
||||||
url = os.environ.get("NEO4J_URL")
|
url = os.environ.get("NEO4J_URI")
|
||||||
username = os.environ.get("NEO4J_USERNAME")
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
password = os.environ.get("NEO4J_PASSWORD")
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
assert url is not None
|
assert url is not None
|
||||||
@ -267,7 +280,7 @@ def test_cypher_save_load() -> None:
|
|||||||
|
|
||||||
def test_exclude_types() -> None:
|
def test_exclude_types() -> None:
|
||||||
"""Test exclude types from schema."""
|
"""Test exclude types from schema."""
|
||||||
url = os.environ.get("NEO4J_URL")
|
url = os.environ.get("NEO4J_URI")
|
||||||
username = os.environ.get("NEO4J_USERNAME")
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
password = os.environ.get("NEO4J_PASSWORD")
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
assert url is not None
|
assert url is not None
|
||||||
@ -307,7 +320,7 @@ def test_exclude_types() -> None:
|
|||||||
|
|
||||||
def test_include_types() -> None:
|
def test_include_types() -> None:
|
||||||
"""Test include types from schema."""
|
"""Test include types from schema."""
|
||||||
url = os.environ.get("NEO4J_URL")
|
url = os.environ.get("NEO4J_URI")
|
||||||
username = os.environ.get("NEO4J_USERNAME")
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
password = os.environ.get("NEO4J_PASSWORD")
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
assert url is not None
|
assert url is not None
|
||||||
@ -347,7 +360,7 @@ def test_include_types() -> None:
|
|||||||
|
|
||||||
def test_include_types2() -> None:
|
def test_include_types2() -> None:
|
||||||
"""Test include types from schema."""
|
"""Test include types from schema."""
|
||||||
url = os.environ.get("NEO4J_URL")
|
url = os.environ.get("NEO4J_URI")
|
||||||
username = os.environ.get("NEO4J_USERNAME")
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
password = os.environ.get("NEO4J_PASSWORD")
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
assert url is not None
|
assert url is not None
|
||||||
|
Loading…
Reference in New Issue
Block a user