mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 18:33:40 +00:00
Add additional parameters to Graph Cypher Chain (#5979)
Inspired by the SQL chain, this commit adds the following three parameters to the Graph Cypher Chain: - top_k: Limit the number of results from the database that are used as context - return_direct: Return database results without transforming them to natural language - return_intermediate_steps: Return intermediate steps
This commit is contained in:
parent
0ca37e613c
commit
d5819a7ca7
@ -177,7 +177,7 @@
|
|||||||
"\u001b[32;1m\u001b[1;3mMATCH (a:Actor)-[:ACTED_IN]->(m:Movie {name: 'Top Gun'})\n",
|
"\u001b[32;1m\u001b[1;3mMATCH (a:Actor)-[:ACTED_IN]->(m:Movie {name: 'Top Gun'})\n",
|
||||||
"RETURN a.name\u001b[0m\n",
|
"RETURN a.name\u001b[0m\n",
|
||||||
"Full Context:\n",
|
"Full Context:\n",
|
||||||
"\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n",
|
"\u001b[32;1m\u001b[1;3m[{'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}, {'a.name': 'Tom Cruise'}]\u001b[0m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
]
|
]
|
||||||
@ -185,7 +185,7 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan played in Top Gun.'"
|
"'Val Kilmer, Anthony Edwards, Meg Ryan, and Tom Cruise played in Top Gun.'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 7,
|
"execution_count": 7,
|
||||||
@ -197,10 +197,180 @@
|
|||||||
"chain.run(\"Who played in Top Gun?\")"
|
"chain.run(\"Who played in Top Gun?\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "2d28c4df",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Limit the number of results\n",
|
||||||
|
"You can limit the number of results from the Cypher QA Chain using the `top_k` parameter.\n",
|
||||||
|
"The default is 10."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "df230946",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chain = GraphCypherQAChain.from_llm(\n",
|
||||||
|
" ChatOpenAI(temperature=0), graph=graph, verbose=True, top_k=2\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "3f1600ee",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
|
||||||
|
"Generated Cypher:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mMATCH (a:Actor)-[:ACTED_IN]->(m:Movie {name: 'Top Gun'})\n",
|
||||||
|
"RETURN a.name\u001b[0m\n",
|
||||||
|
"Full Context:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3m[{'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}]\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'Val Kilmer and Anthony Edwards played in Top Gun.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chain.run(\"Who played in Top Gun?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "88c16206",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Return intermediate results\n",
|
||||||
|
"You can return intermediate steps from the Cypher QA Chain using the `return_intermediate_steps` parameter"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "e412f36b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chain = GraphCypherQAChain.from_llm(\n",
|
||||||
|
" ChatOpenAI(temperature=0), graph=graph, verbose=True, return_intermediate_steps=True\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"id": "4f4699dc",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
|
||||||
|
"Generated Cypher:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mMATCH (a:Actor)-[:ACTED_IN]->(m:Movie {name: 'Top Gun'})\n",
|
||||||
|
"RETURN a.name\u001b[0m\n",
|
||||||
|
"Full Context:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3m[{'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}, {'a.name': 'Tom Cruise'}]\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||||
|
"Intermediate steps: [{'query': \"MATCH (a:Actor)-[:ACTED_IN]->(m:Movie {name: 'Top Gun'})\\nRETURN a.name\"}, {'context': [{'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}, {'a.name': 'Tom Cruise'}]}]\n",
|
||||||
|
"Final answer: Val Kilmer, Anthony Edwards, Meg Ryan, and Tom Cruise played in Top Gun.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"result = chain(\"Who played in Top Gun?\")\n",
|
||||||
|
"print(f\"Intermediate steps: {result['intermediate_steps']}\")\n",
|
||||||
|
"print(f\"Final answer: {result['result']}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "d6e1b054",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Return direct results\n",
|
||||||
|
"You can return direct results from the Cypher QA Chain using the `return_direct` parameter"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"id": "2d3acf10",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chain = GraphCypherQAChain.from_llm(\n",
|
||||||
|
" ChatOpenAI(temperature=0), graph=graph, verbose=True, return_direct=True\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"id": "b0a9d143",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
|
||||||
|
"Generated Cypher:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mMATCH (a:Actor)-[:ACTED_IN]->(m:Movie {name: 'Top Gun'})\n",
|
||||||
|
"RETURN a.name\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[{'a.name': 'Val Kilmer'},\n",
|
||||||
|
" {'a.name': 'Anthony Edwards'},\n",
|
||||||
|
" {'a.name': 'Meg Ryan'},\n",
|
||||||
|
" {'a.name': 'Tom Cruise'}]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chain.run(\"Who played in Top Gun?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "b4825316",
|
"id": "74d0a36f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
@ -222,7 +392,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.9.1"
|
"version": "3.8.8"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -14,6 +14,8 @@ from langchain.chains.llm import LLMChain
|
|||||||
from langchain.graphs.neo4j_graph import Neo4jGraph
|
from langchain.graphs.neo4j_graph import Neo4jGraph
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
|
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||||
|
|
||||||
|
|
||||||
def extract_cypher(text: str) -> str:
|
def extract_cypher(text: str) -> str:
|
||||||
# The pattern to find Cypher code enclosed in triple backticks
|
# The pattern to find Cypher code enclosed in triple backticks
|
||||||
@ -33,6 +35,12 @@ class GraphCypherQAChain(Chain):
|
|||||||
qa_chain: LLMChain
|
qa_chain: LLMChain
|
||||||
input_key: str = "query" #: :meta private:
|
input_key: str = "query" #: :meta private:
|
||||||
output_key: str = "result" #: :meta private:
|
output_key: str = "result" #: :meta private:
|
||||||
|
top_k: int = 10
|
||||||
|
"""Number of results to return from the query"""
|
||||||
|
return_intermediate_steps: bool = False
|
||||||
|
"""Whether or not to return the intermediate steps along with the final answer."""
|
||||||
|
return_direct: bool = False
|
||||||
|
"""Whether or not to return the result of querying the graph directly."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
@ -74,12 +82,14 @@ class GraphCypherQAChain(Chain):
|
|||||||
self,
|
self,
|
||||||
inputs: Dict[str, Any],
|
inputs: Dict[str, Any],
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, Any]:
|
||||||
"""Generate Cypher statement, use it to look up in db and answer question."""
|
"""Generate Cypher statement, use it to look up in db and answer question."""
|
||||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||||
callbacks = _run_manager.get_child()
|
callbacks = _run_manager.get_child()
|
||||||
question = inputs[self.input_key]
|
question = inputs[self.input_key]
|
||||||
|
|
||||||
|
intermediate_steps: List = []
|
||||||
|
|
||||||
generated_cypher = self.cypher_generation_chain.run(
|
generated_cypher = self.cypher_generation_chain.run(
|
||||||
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
|
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
|
||||||
)
|
)
|
||||||
@ -91,14 +101,30 @@ class GraphCypherQAChain(Chain):
|
|||||||
_run_manager.on_text(
|
_run_manager.on_text(
|
||||||
generated_cypher, color="green", end="\n", verbose=self.verbose
|
generated_cypher, color="green", end="\n", verbose=self.verbose
|
||||||
)
|
)
|
||||||
context = self.graph.query(generated_cypher)
|
|
||||||
|
|
||||||
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
|
intermediate_steps.append({"query": generated_cypher})
|
||||||
_run_manager.on_text(
|
|
||||||
str(context), color="green", end="\n", verbose=self.verbose
|
# Retrieve and limit the number of results
|
||||||
)
|
context = self.graph.query(generated_cypher)[: self.top_k]
|
||||||
result = self.qa_chain(
|
|
||||||
{"question": question, "context": context},
|
if self.return_direct:
|
||||||
callbacks=callbacks,
|
final_result = context
|
||||||
)
|
else:
|
||||||
return {self.output_key: result[self.qa_chain.output_key]}
|
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
|
||||||
|
_run_manager.on_text(
|
||||||
|
str(context), color="green", end="\n", verbose=self.verbose
|
||||||
|
)
|
||||||
|
|
||||||
|
intermediate_steps.append({"context": context})
|
||||||
|
|
||||||
|
result = self.qa_chain(
|
||||||
|
{"question": question, "context": context},
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
|
final_result = result[self.qa_chain.output_key]
|
||||||
|
|
||||||
|
chain_result: Dict[str, Any] = {self.output_key: final_result}
|
||||||
|
if self.return_intermediate_steps:
|
||||||
|
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
|
||||||
|
|
||||||
|
return chain_result
|
||||||
|
@ -78,8 +78,7 @@ class Neo4jGraph:
|
|||||||
with self._driver.session(database=self._database) as session:
|
with self._driver.session(database=self._database) as session:
|
||||||
try:
|
try:
|
||||||
data = session.run(query, params)
|
data = session.run(query, params)
|
||||||
# Hard limit of 50 results
|
return [r.data() for r in data]
|
||||||
return [r.data() for r in data][:50]
|
|
||||||
except CypherSyntaxError as e:
|
except CypherSyntaxError as e:
|
||||||
raise ValueError("Generated Cypher Statement is not valid\n" f"{e}")
|
raise ValueError("Generated Cypher Statement is not valid\n" f"{e}")
|
||||||
|
|
||||||
|
@ -58,3 +58,113 @@ def test_cypher_generating_run() -> None:
|
|||||||
output = chain.run("Who played in Pulp Fiction?")
|
output = chain.run("Who played in Pulp Fiction?")
|
||||||
expected_output = " Bruce Willis played in Pulp Fiction."
|
expected_output = " Bruce Willis played in Pulp Fiction."
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cypher_top_k() -> None:
|
||||||
|
"""Test top_k parameter correctly limits the number of results in the context."""
|
||||||
|
url = os.environ.get("NEO4J_URL")
|
||||||
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
|
assert url is not None
|
||||||
|
assert username is not None
|
||||||
|
assert password is not None
|
||||||
|
|
||||||
|
TOP_K = 1
|
||||||
|
|
||||||
|
graph = Neo4jGraph(
|
||||||
|
url=url,
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
)
|
||||||
|
# Delete all nodes in the graph
|
||||||
|
graph.query("MATCH (n) DETACH DELETE n")
|
||||||
|
# Create three nodes (two actors, one movie) and two relationships
|
||||||
|
graph.query(
|
||||||
|
"CREATE (a:Actor {name:'Bruce Willis'})"
|
||||||
|
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
|
||||||
|
"<-[:ACTED_IN]-(:Actor {name:'Foo'})"
|
||||||
|
)
|
||||||
|
# Refresh schema information
|
||||||
|
graph.refresh_schema()
|
||||||
|
|
||||||
|
chain = GraphCypherQAChain.from_llm(
|
||||||
|
OpenAI(temperature=0), graph=graph, return_direct=True, top_k=TOP_K
|
||||||
|
)
|
||||||
|
output = chain.run("Who played in Pulp Fiction?")
|
||||||
|
assert len(output) == TOP_K
|
||||||
|
|
||||||
|
|
||||||
|
def test_cypher_intermediate_steps() -> None:
|
||||||
|
"""Test the returning of the intermediate steps."""
|
||||||
|
url = os.environ.get("NEO4J_URL")
|
||||||
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
|
assert url is not None
|
||||||
|
assert username is not None
|
||||||
|
assert password is not None
|
||||||
|
|
||||||
|
graph = Neo4jGraph(
|
||||||
|
url=url,
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
)
|
||||||
|
# Delete all nodes in the graph
|
||||||
|
graph.query("MATCH (n) DETACH DELETE n")
|
||||||
|
# Create two nodes and a relationship
|
||||||
|
graph.query(
|
||||||
|
"CREATE (a:Actor {name:'Bruce Willis'})"
|
||||||
|
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
|
||||||
|
)
|
||||||
|
# Refresh schema information
|
||||||
|
graph.refresh_schema()
|
||||||
|
|
||||||
|
chain = GraphCypherQAChain.from_llm(
|
||||||
|
OpenAI(temperature=0), graph=graph, return_intermediate_steps=True
|
||||||
|
)
|
||||||
|
output = chain("Who played in Pulp Fiction?")
|
||||||
|
|
||||||
|
expected_output = " Bruce Willis played in Pulp Fiction."
|
||||||
|
assert output["result"] == expected_output
|
||||||
|
|
||||||
|
query = output["intermediate_steps"][0]["query"]
|
||||||
|
expected_query = (
|
||||||
|
"\n\nMATCH (a:Actor)-[:ACTED_IN]->"
|
||||||
|
"(m:Movie {title: 'Pulp Fiction'}) RETURN a.name"
|
||||||
|
)
|
||||||
|
assert query == expected_query
|
||||||
|
|
||||||
|
context = output["intermediate_steps"][1]["context"]
|
||||||
|
expected_context = [{"a.name": "Bruce Willis"}]
|
||||||
|
assert context == expected_context
|
||||||
|
|
||||||
|
|
||||||
|
def test_cypher_return_direct() -> None:
|
||||||
|
"""Test that chain returns direct results."""
|
||||||
|
url = os.environ.get("NEO4J_URL")
|
||||||
|
username = os.environ.get("NEO4J_USERNAME")
|
||||||
|
password = os.environ.get("NEO4J_PASSWORD")
|
||||||
|
assert url is not None
|
||||||
|
assert username is not None
|
||||||
|
assert password is not None
|
||||||
|
|
||||||
|
graph = Neo4jGraph(
|
||||||
|
url=url,
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
)
|
||||||
|
# Delete all nodes in the graph
|
||||||
|
graph.query("MATCH (n) DETACH DELETE n")
|
||||||
|
# Create two nodes and a relationship
|
||||||
|
graph.query(
|
||||||
|
"CREATE (a:Actor {name:'Bruce Willis'})"
|
||||||
|
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
|
||||||
|
)
|
||||||
|
# Refresh schema information
|
||||||
|
graph.refresh_schema()
|
||||||
|
|
||||||
|
chain = GraphCypherQAChain.from_llm(
|
||||||
|
OpenAI(temperature=0), graph=graph, return_direct=True
|
||||||
|
)
|
||||||
|
output = chain.run("Who played in Pulp Fiction?")
|
||||||
|
expected_output = [{"a.name": "Bruce Willis"}]
|
||||||
|
assert output == expected_output
|
||||||
|
Loading…
Reference in New Issue
Block a user