community[patch]: Add function response to graph cypher qa chain (#22690)
LLMs struggle with Graph RAG because it differs from vector RAG: the LLM isn't given the whole context, only the database results, and it has to take them on faith. That often doesn't work well. However, if you wrap the context as a tool/function response, accuracy is much better. Btw, `Union[LLMChain, Runnable]` is linting fun, which is why there are so many type ignores.
This commit is contained in: parent 34edfe4a16 · commit 76a193decc
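To make "wrap the context as function response" concrete before the diff: the change fakes an assistant tool call and then feeds the Cypher results back as that tool's output. Below is a minimal sketch of the idea, mirroring the `get_function_response` helper added in this patch; the tool-call id and the `GetInformation` name are arbitrary placeholders, just as they are in the patch.

from typing import Any, Dict, List

from langchain_core.messages import AIMessage, BaseMessage, ToolMessage


def wrap_context_as_function_response(
    question: str, context: List[Dict[str, Any]]
) -> List[BaseMessage]:
    # Pretend the assistant already called a "GetInformation" tool
    # for this question (the id is an arbitrary placeholder)...
    tool_id = "call_0123456789abcdef"
    return [
        AIMessage(
            content="",
            additional_kwargs={
                "tool_calls": [
                    {
                        "id": tool_id,
                        "function": {
                            "arguments": '{"question":"' + question + '"}',
                            "name": "GetInformation",
                        },
                        "type": "function",
                    }
                ]
            },
        ),
        # ...and hand the database rows back as that tool's output, so the
        # model treats them as function results instead of untrusted prose.
        ToolMessage(content=str(context), tool_call_id=tool_id),
    ]

Appending these two messages after the user's question is what makes the model follow the retrieved context.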
@@ -164,10 +164,10 @@
 "text": [
 "Node properties:\n",
 "- **Movie**\n",
-" - `runtime: INTEGER` Min: 120, Max: 120\n",
-" - `name: STRING` Available options: ['Top Gun']\n",
+" - `runtime`: INTEGER Min: 120, Max: 120\n",
+" - `name`: STRING Available options: ['Top Gun']\n",
 "- **Actor**\n",
-" - `name: STRING` Available options: ['Tom Cruise', 'Val Kilmer', 'Anthony Edwards', 'Meg Ryan']\n",
+" - `name`: STRING Available options: ['Tom Cruise', 'Val Kilmer', 'Anthony Edwards', 'Meg Ryan']\n",
 "Relationship properties:\n",
 "\n",
 "The relationships:\n",
@@ -225,7 +225,7 @@
 "WHERE m.name = 'Top Gun'\n",
 "RETURN a.name\u001b[0m\n",
 "Full Context:\n",
-"\u001b[32;1m\u001b[1;3m[{'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Tom Cruise'}]\u001b[0m\n",
+"\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n",
 "\n",
 "\u001b[1m> Finished chain.\u001b[0m\n"
 ]
@@ -234,7 +234,7 @@
 "data": {
 "text/plain": [
 "{'query': 'Who played in Top Gun?',\n",
-" 'result': 'Anthony Edwards, Meg Ryan, Val Kilmer, Tom Cruise played in Top Gun.'}"
+" 'result': 'Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan played in Top Gun.'}"
 ]
 },
 "execution_count": 8,
@@ -286,7 +286,7 @@
 "WHERE m.name = 'Top Gun'\n",
 "RETURN a.name\u001b[0m\n",
 "Full Context:\n",
-"\u001b[32;1m\u001b[1;3m[{'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n",
+"\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}]\u001b[0m\n",
 "\n",
 "\u001b[1m> Finished chain.\u001b[0m\n"
 ]
@@ -295,7 +295,7 @@
 "data": {
 "text/plain": [
 "{'query': 'Who played in Top Gun?',\n",
-" 'result': 'Anthony Edwards, Meg Ryan played in Top Gun.'}"
+" 'result': 'Tom Cruise, Val Kilmer played in Top Gun.'}"
 ]
 },
 "execution_count": 10,
@@ -346,11 +346,11 @@
 "WHERE m.name = 'Top Gun'\n",
 "RETURN a.name\u001b[0m\n",
 "Full Context:\n",
-"\u001b[32;1m\u001b[1;3m[{'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Tom Cruise'}]\u001b[0m\n",
+"\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n",
 "\n",
 "\u001b[1m> Finished chain.\u001b[0m\n",
-"Intermediate steps: [{'query': \"MATCH (a:Actor)-[:ACTED_IN]->(m:Movie)\\nWHERE m.name = 'Top Gun'\\nRETURN a.name\"}, {'context': [{'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Tom Cruise'}]}]\n",
-"Final answer: Anthony Edwards, Meg Ryan, Val Kilmer, Tom Cruise played in Top Gun.\n"
+"Intermediate steps: [{'query': \"MATCH (a:Actor)-[:ACTED_IN]->(m:Movie)\\nWHERE m.name = 'Top Gun'\\nRETURN a.name\"}, {'context': [{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]}]\n",
+"Final answer: Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan played in Top Gun.\n"
 ]
 }
 ],
@@ -406,10 +406,10 @@
 "data": {
 "text/plain": [
 "{'query': 'Who played in Top Gun?',\n",
-" 'result': [{'a.name': 'Anthony Edwards'},\n",
-" {'a.name': 'Meg Ryan'},\n",
-" {'a.name': 'Val Kilmer'},\n",
-" {'a.name': 'Tom Cruise'}]}"
+" 'result': [{'a.name': 'Tom Cruise'},\n",
+" {'a.name': 'Val Kilmer'},\n",
+" {'a.name': 'Anthony Edwards'},\n",
+" {'a.name': 'Meg Ryan'}]}"
 ]
 },
 "execution_count": 14,
@@ -482,7 +482,7 @@
 "\n",
 "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
 "Generated Cypher:\n",
-"\u001b[32;1m\u001b[1;3mMATCH (:Movie {name:\"Top Gun\"})<-[:ACTED_IN]-()\n",
+"\u001b[32;1m\u001b[1;3mMATCH (m:Movie {name:\"Top Gun\"})<-[:ACTED_IN]-()\n",
 "RETURN count(*) AS numberOfActors\u001b[0m\n",
 "Full Context:\n",
 "\u001b[32;1m\u001b[1;3m[{'numberOfActors': 4}]\u001b[0m\n",
@@ -494,7 +494,7 @@
 "data": {
 "text/plain": [
 "{'query': 'How many people played in Top Gun?',\n",
-" 'result': 'There were 4 actors who played in Top Gun.'}"
+" 'result': 'There were 4 actors in Top Gun.'}"
 ]
 },
 "execution_count": 16,
@@ -548,7 +548,7 @@
 "WHERE m.name = 'Top Gun'\n",
 "RETURN a.name\u001b[0m\n",
 "Full Context:\n",
-"\u001b[32;1m\u001b[1;3m[{'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Tom Cruise'}]\u001b[0m\n",
+"\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n",
 "\n",
 "\u001b[1m> Finished chain.\u001b[0m\n"
 ]
@@ -557,7 +557,7 @@
 "data": {
 "text/plain": [
 "{'query': 'Who played in Top Gun?',\n",
-" 'result': 'Anthony Edwards, Meg Ryan, Val Kilmer, and Tom Cruise played in Top Gun.'}"
+" 'result': 'Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan played in Top Gun.'}"
 ]
 },
 "execution_count": 18,
@@ -661,7 +661,7 @@
 "WHERE m.name = 'Top Gun'\n",
 "RETURN a.name\u001b[0m\n",
 "Full Context:\n",
-"\u001b[32;1m\u001b[1;3m[{'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Tom Cruise'}]\u001b[0m\n",
+"\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n",
 "\n",
 "\u001b[1m> Finished chain.\u001b[0m\n"
 ]
@@ -670,7 +670,7 @@
 "data": {
 "text/plain": [
 "{'query': 'Who played in Top Gun?',\n",
-" 'result': 'Anthony Edwards, Meg Ryan, Val Kilmer, Tom Cruise played in Top Gun.'}"
+" 'result': 'Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan played in Top Gun.'}"
 ]
 },
 "execution_count": 22,
@@ -683,12 +683,116 @@
 ]
 },
 {
-"cell_type": "code",
-"execution_count": null,
-"id": "3fa3f3d5-f7e7-4ca9-8f07-ca22b897f192",
+"cell_type": "markdown",
+"id": "81093062-eb7f-4d96-b1fd-c36b8f1b9474",
 "metadata": {},
-"outputs": [],
-"source": []
+"source": [
+"## Provide context from database results as tool/function output\n",
+"\n",
+"You can use the `use_function_response` parameter to pass context from database results to an LLM as tool/function output. This improves response accuracy and relevance, because the LLM follows the provided context more closely.\n",
+"_You will need to use an LLM with native function calling support to use this feature_."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 23,
+"id": "2be8f51c-e80a-4a60-ab1c-266450fc17cd",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\n",
+"\n",
+"\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+"Generated Cypher:\n",
+"\u001b[32;1m\u001b[1;3mMATCH (a:Actor)-[:ACTED_IN]->(m:Movie)\n",
+"WHERE m.name = 'Top Gun'\n",
+"RETURN a.name\u001b[0m\n",
+"Full Context:\n",
+"\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n",
+"\n",
+"\u001b[1m> Finished chain.\u001b[0m\n"
+]
+},
+{
+"data": {
+"text/plain": [
+"{'query': 'Who played in Top Gun?',\n",
+" 'result': 'The main actors in Top Gun are Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan.'}"
+]
+},
+"execution_count": 23,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"chain = GraphCypherQAChain.from_llm(\n",
+"    llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"),\n",
+"    graph=graph,\n",
+"    verbose=True,\n",
+"    use_function_response=True,\n",
+")\n",
+"chain.invoke({\"query\": \"Who played in Top Gun?\"})"
+]
+},
+{
+"cell_type": "markdown",
+"id": "48a75785-5bc9-49a7-a41b-88bf3ef9d312",
+"metadata": {},
+"source": [
+"You can provide a custom system message when using the function response feature by passing `function_response_system`, which instructs the model on how to generate answers.\n",
+"\n",
+"_Note that `qa_prompt` will have no effect when using `use_function_response`._"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 24,
+"id": "ddf0a61e-f104-4dbb-abbf-e65f3f57dd9a",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\n",
+"\n",
+"\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+"Generated Cypher:\n",
+"\u001b[32;1m\u001b[1;3mMATCH (a:Actor)-[:ACTED_IN]->(m:Movie)\n",
+"WHERE m.name = 'Top Gun'\n",
+"RETURN a.name\u001b[0m\n",
+"Full Context:\n",
+"\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n",
+"\n",
+"\u001b[1m> Finished chain.\u001b[0m\n"
+]
+},
+{
+"data": {
+"text/plain": [
+"{'query': 'Who played in Top Gun?',\n",
+" 'result': \"Arrr matey! In the film Top Gun, ye be seein' Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan sailin' the high seas of the sky! Aye, they be a fine crew of actors, they be!\"}"
+]
+},
+"execution_count": 24,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"chain = GraphCypherQAChain.from_llm(\n",
+"    llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"),\n",
+"    graph=graph,\n",
+"    verbose=True,\n",
+"    use_function_response=True,\n",
+"    function_response_system=\"Respond as a pirate!\",\n",
+")\n",
+"chain.invoke({\"query\": \"Who played in Top Gun?\"})"
+]
+},
 }
 ],
 "metadata": {
@@ -2,14 +2,27 @@
 from __future__ import annotations
 
 import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.prompts import BasePromptTemplate
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    SystemMessage,
+    ToolMessage,
+)
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import (
+    BasePromptTemplate,
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    MessagesPlaceholder,
+)
 from langchain_core.pydantic_v1 import Field
+from langchain_core.runnables import Runnable
 
 from langchain_community.chains.graph_qa.cypher_utils import (
     CypherQueryCorrector,
@@ -23,6 +36,12 @@ from langchain_community.graphs.graph_store import GraphStore
 
 INTERMEDIATE_STEPS_KEY = "intermediate_steps"
 
+FUNCTION_RESPONSE_SYSTEM = """You are an assistant that helps to form nice and human
+understandable answers based on the provided information from tools.
+Do not add any other information that wasn't present in the tools, and use
+very concise style in interpreting results!
+"""
+
 
 def extract_cypher(text: str) -> str:
     """Extract Cypher code from a text.
@@ -104,6 +123,31 @@ def construct_schema(
     )
 
 
+def get_function_response(
+    question: str, context: List[Dict[str, Any]]
+) -> List[BaseMessage]:
+    TOOL_ID = "call_H7fABDuzEau48T10Qn0Lsh0D"
+    messages = [
+        AIMessage(
+            content="",
+            additional_kwargs={
+                "tool_calls": [
+                    {
+                        "id": TOOL_ID,
+                        "function": {
+                            "arguments": '{"question":"' + question + '"}',
+                            "name": "GetInformation",
+                        },
+                        "type": "function",
+                    }
+                ]
+            },
+        ),
+        ToolMessage(content=str(context), tool_call_id=TOOL_ID),
+    ]
+    return messages
+
+
 class GraphCypherQAChain(Chain):
     """Chain for question-answering against a graph by generating Cypher statements.
 
@@ -121,7 +165,7 @@ class GraphCypherQAChain(Chain):
 
     graph: GraphStore = Field(exclude=True)
     cypher_generation_chain: LLMChain
-    qa_chain: LLMChain
+    qa_chain: Union[LLMChain, Runnable]
     graph_schema: str
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
@@ -133,6 +177,8 @@ class GraphCypherQAChain(Chain):
     """Whether or not to return the result of querying the graph directly."""
     cypher_query_corrector: Optional[CypherQueryCorrector] = None
    """Optional cypher validation tool"""
+    use_function_response: bool = False
+    """Whether to wrap the database context as tool/function response"""
 
     @property
     def input_keys(self) -> List[str]:
@@ -163,12 +209,14 @@ class GraphCypherQAChain(Chain):
         qa_prompt: Optional[BasePromptTemplate] = None,
         cypher_prompt: Optional[BasePromptTemplate] = None,
         cypher_llm: Optional[BaseLanguageModel] = None,
-        qa_llm: Optional[BaseLanguageModel] = None,
+        qa_llm: Optional[Union[BaseLanguageModel, Any]] = None,
         exclude_types: List[str] = [],
         include_types: List[str] = [],
         validate_cypher: bool = False,
         qa_llm_kwargs: Optional[Dict[str, Any]] = None,
         cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
+        use_function_response: bool = False,
+        function_response_system: str = FUNCTION_RESPONSE_SYSTEM,
         **kwargs: Any,
     ) -> GraphCypherQAChain:
         """Initialize from LLM."""
@@ -205,7 +253,22 @@ class GraphCypherQAChain(Chain):
             cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
         )
 
-        qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs)  # type: ignore[arg-type]
+        qa_llm = qa_llm or llm
+        if use_function_response:
+            try:
+                qa_llm.bind_tools({})  # type: ignore[union-attr]
+                response_prompt = ChatPromptTemplate.from_messages(
+                    [
+                        SystemMessage(content=function_response_system),
+                        HumanMessagePromptTemplate.from_template("{question}"),
+                        MessagesPlaceholder(variable_name="function_response"),
+                    ]
+                )
+                qa_chain = response_prompt | qa_llm | StrOutputParser()  # type: ignore
+            except (NotImplementedError, AttributeError):
+                raise ValueError("Provided LLM does not support native tools/functions")
+        else:
+            qa_chain = LLMChain(llm=qa_llm, **use_qa_llm_kwargs)  # type: ignore[arg-type]
 
         cypher_generation_chain = LLMChain(
             llm=cypher_llm or llm,  # type: ignore[arg-type]
@@ -217,7 +280,6 @@ class GraphCypherQAChain(Chain):
                 "Either `exclude_types` or `include_types` "
                 "can be provided, but not both"
             )
-
         graph_schema = construct_schema(
             kwargs["graph"].get_structured_schema, include_types, exclude_types
         )
@@ -235,6 +297,7 @@ class GraphCypherQAChain(Chain):
             qa_chain=qa_chain,
             cypher_generation_chain=cypher_generation_chain,
             cypher_query_corrector=cypher_query_corrector,
+            use_function_response=use_function_response,
             **kwargs,
         )
 
@@ -284,12 +347,17 @@ class GraphCypherQAChain(Chain):
             )
 
             intermediate_steps.append({"context": context})
-            result = self.qa_chain(
-                {"question": question, "context": context},
-                callbacks=callbacks,
-            )
-            final_result = result[self.qa_chain.output_key]
+            if self.use_function_response:
+                function_response = get_function_response(question, context)
+                final_result = self.qa_chain.invoke(  # type: ignore
+                    {"question": question, "function_response": function_response},
+                )
+            else:
+                result = self.qa_chain.invoke(  # type: ignore
+                    {"question": question, "context": context},
+                    callbacks=callbacks,
+                )
+                final_result = result[self.qa_chain.output_key]  # type: ignore
 
         chain_result: Dict[str, Any] = {self.output_key: final_result}
         if self.return_intermediate_steps:
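For reference, here is a hedged sketch of what the `use_function_response=True` branch of `from_llm` effectively wires up, and how it gets invoked. It assumes a tool-calling chat model via `langchain_openai.ChatOpenAI`, and that this module lands as `langchain_community.chains.graph_qa.cypher` (the import path is inferred from the diff's sibling imports, not stated in it):

from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain_openai import ChatOpenAI

from langchain_community.chains.graph_qa.cypher import (
    FUNCTION_RESPONSE_SYSTEM,
    get_function_response,
)

llm = ChatOpenAI(temperature=0)  # must support native tool calling

# Same prompt shape as the one assembled in from_llm(): system
# instructions, the user's question, then the synthetic
# tool-call / tool-response pair carrying the database rows.
response_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessage(content=FUNCTION_RESPONSE_SYSTEM),
        HumanMessagePromptTemplate.from_template("{question}"),
        MessagesPlaceholder(variable_name="function_response"),
    ]
)
qa_chain = response_prompt | llm | StrOutputParser()

# Inside the chain, `context` comes from graph.query(generated_cypher);
# here it is stubbed with illustrative rows.
context = [{"a.name": "Tom Cruise"}, {"a.name": "Val Kilmer"}]
answer = qa_chain.invoke(
    {
        "question": "Who played in Top Gun?",
        "function_response": get_function_response("Who played in Top Gun?", context),
    }
)
print(answer)

Passing `function_response_system` to `from_llm` simply swaps your own instructions in for `FUNCTION_RESPONSE_SYSTEM`, which is what the notebook's pirate example exercises.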
@@ -60,7 +60,7 @@ def test_graph_cypher_qa_chain_prompt_selection_1() -> None:
         qa_prompt=qa_prompt,
         cypher_prompt=cypher_prompt,
     )
-    assert chain.qa_chain.prompt == qa_prompt
+    assert chain.qa_chain.prompt == qa_prompt  # type: ignore[union-attr]
     assert chain.cypher_generation_chain.prompt == cypher_prompt
 
 
@@ -72,7 +72,7 @@ def test_graph_cypher_qa_chain_prompt_selection_2() -> None:
         verbose=True,
         return_intermediate_steps=False,
     )
-    assert chain.qa_chain.prompt == CYPHER_QA_PROMPT
+    assert chain.qa_chain.prompt == CYPHER_QA_PROMPT  # type: ignore[union-attr]
     assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT
 
 
@@ -88,7 +88,7 @@ def test_graph_cypher_qa_chain_prompt_selection_3() -> None:
         cypher_llm_kwargs={"memory": readonlymemory},
         qa_llm_kwargs={"memory": readonlymemory},
     )
-    assert chain.qa_chain.prompt == CYPHER_QA_PROMPT
+    assert chain.qa_chain.prompt == CYPHER_QA_PROMPT  # type: ignore[union-attr]
     assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT
 
 
@@ -108,7 +108,7 @@ def test_graph_cypher_qa_chain_prompt_selection_4() -> None:
         cypher_llm_kwargs={"prompt": cypher_prompt, "memory": readonlymemory},
         qa_llm_kwargs={"prompt": qa_prompt, "memory": readonlymemory},
     )
-    assert chain.qa_chain.prompt == qa_prompt
+    assert chain.qa_chain.prompt == qa_prompt  # type: ignore[union-attr]
     assert chain.cypher_generation_chain.prompt == cypher_prompt