diff --git a/langchain/chains/graph_qa/cypher.py b/langchain/chains/graph_qa/cypher.py index 68966cd277d..3ccdcb93fd0 100644 --- a/langchain/chains/graph_qa/cypher.py +++ b/langchain/chains/graph_qa/cypher.py @@ -47,11 +47,16 @@ class GraphCypherQAChain(Chain): *, qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT, + qa_chain: LLMChain = None, + cypher_generation_chain: LLMChain = None, **kwargs: Any, ) -> GraphCypherQAChain: """Initialize from LLM.""" - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt) + if qa_chain is None: + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + + if cypher_generation_chain is None: + cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt) return cls( qa_chain=qa_chain, diff --git a/langchain/graphs/neo4j_graph.py b/langchain/graphs/neo4j_graph.py index 4433771a255..a077a0f469d 100644 --- a/langchain/graphs/neo4j_graph.py +++ b/langchain/graphs/neo4j_graph.py @@ -43,27 +43,7 @@ class Neo4jGraph: self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) self._database = database self.schema = "" - # Verify connection - try: - self._driver.verify_connectivity() - except neo4j.exceptions.ServiceUnavailable: - raise ValueError( - "Could not connect to Neo4j database. " - "Please ensure that the url is correct" - ) - except neo4j.exceptions.AuthError: - raise ValueError( - "Could not connect to Neo4j database. " - "Please ensure that the username and password are correct" - ) - # Set schema - try: - self.refresh_schema() - except neo4j.exceptions.ClientError: - raise ValueError( - "Could not use APOC procedures. " - "Please install the APOC plugin in Neo4j." - ) + self.refresh_schema() @property def get_schema(self) -> str: