community[patch]: Add function response to graph cypher qa chain (#22690)

LLMs struggle with Graph RAG because, unlike vector RAG, you don't hand the model
the full retrieved context, only the database results, which it has to take on
faith. In practice that often fails. If you instead wrap the context as a
tool/function response, accuracy improves considerably.
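Rough intended usage (a minimal sketch; the connection details and model choice are placeholders, not part of this change):

```python
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_openai import ChatOpenAI

# Placeholder connection details; any GraphStore implementation works.
graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="password")
llm = ChatOpenAI(model="gpt-4o", temperature=0)

chain = GraphCypherQAChain.from_llm(
    llm=llm,
    graph=graph,
    # New flag: hand the Cypher results to the LLM as a tool/function
    # response instead of pasting them into the QA prompt.
    use_function_response=True,
)
print(chain.invoke({"query": "Who directed Casino?"})["result"])
```

With the flag off, nothing changes: the classic `LLMChain` QA path is used.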

Btw, `Union[LLMChain, Runnable]` is linting fun; that's why there are so many
`# type: ignore` comments.
Tomaz Bratanic
2024-06-10 13:52:17 -07:00
committed by GitHub
parent 34edfe4a16
commit 76a193decc
3 changed files with 212 additions and 40 deletions


@@ -2,14 +2,27 @@
 from __future__ import annotations
 import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.prompts import BasePromptTemplate
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    SystemMessage,
+    ToolMessage,
+)
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import (
+    BasePromptTemplate,
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    MessagesPlaceholder,
+)
 from langchain_core.pydantic_v1 import Field
+from langchain_core.runnables import Runnable
 from langchain_community.chains.graph_qa.cypher_utils import (
     CypherQueryCorrector,
@@ -23,6 +36,12 @@ from langchain_community.graphs.graph_store import GraphStore
 INTERMEDIATE_STEPS_KEY = "intermediate_steps"
+FUNCTION_RESPONSE_SYSTEM = """You are an assistant that helps to form nice and human
+understandable answers based on the provided information from tools.
+Do not add any other information that wasn't present in the tools, and use
+very concise style in interpreting results!
+"""
 def extract_cypher(text: str) -> str:
     """Extract Cypher code from a text.
@@ -104,6 +123,31 @@ def construct_schema(
     )
+def get_function_response(
+    question: str, context: List[Dict[str, Any]]
+) -> List[BaseMessage]:
+    TOOL_ID = "call_H7fABDuzEau48T10Qn0Lsh0D"
+    messages = [
+        AIMessage(
+            content="",
+            additional_kwargs={
+                "tool_calls": [
+                    {
+                        "id": TOOL_ID,
+                        "function": {
+                            "arguments": '{"question":"' + question + '"}',
+                            "name": "GetInformation",
+                        },
+                        "type": "function",
+                    }
+                ]
+            },
+        ),
+        ToolMessage(content=str(context), tool_call_id=TOOL_ID),
+    ]
+    return messages
 class GraphCypherQAChain(Chain):
     """Chain for question-answering against a graph by generating Cypher statements.
@@ -121,7 +165,7 @@ class GraphCypherQAChain(Chain):
     graph: GraphStore = Field(exclude=True)
     cypher_generation_chain: LLMChain
-    qa_chain: LLMChain
+    qa_chain: Union[LLMChain, Runnable]
     graph_schema: str
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
@@ -133,6 +177,8 @@ class GraphCypherQAChain(Chain):
     """Whether or not to return the result of querying the graph directly."""
     cypher_query_corrector: Optional[CypherQueryCorrector] = None
     """Optional cypher validation tool"""
+    use_function_response: bool = False
+    """Whether to wrap the database context as tool/function response"""
     @property
     def input_keys(self) -> List[str]:
@@ -163,12 +209,14 @@ class GraphCypherQAChain(Chain):
         qa_prompt: Optional[BasePromptTemplate] = None,
         cypher_prompt: Optional[BasePromptTemplate] = None,
         cypher_llm: Optional[BaseLanguageModel] = None,
-        qa_llm: Optional[BaseLanguageModel] = None,
+        qa_llm: Optional[Union[BaseLanguageModel, Any]] = None,
         exclude_types: List[str] = [],
         include_types: List[str] = [],
         validate_cypher: bool = False,
         qa_llm_kwargs: Optional[Dict[str, Any]] = None,
         cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
+        use_function_response: bool = False,
+        function_response_system: str = FUNCTION_RESPONSE_SYSTEM,
         **kwargs: Any,
     ) -> GraphCypherQAChain:
         """Initialize from LLM."""
@@ -205,7 +253,22 @@
             cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
         )
-        qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs)  # type: ignore[arg-type]
+        qa_llm = qa_llm or llm
+        if use_function_response:
+            try:
+                qa_llm.bind_tools({})  # type: ignore[union-attr]
+                response_prompt = ChatPromptTemplate.from_messages(
+                    [
+                        SystemMessage(content=function_response_system),
+                        HumanMessagePromptTemplate.from_template("{question}"),
+                        MessagesPlaceholder(variable_name="function_response"),
+                    ]
+                )
+                qa_chain = response_prompt | qa_llm | StrOutputParser()  # type: ignore
+            except (NotImplementedError, AttributeError):
+                raise ValueError("Provided LLM does not support native tools/functions")
+        else:
+            qa_chain = LLMChain(llm=qa_llm, **use_qa_llm_kwargs)  # type: ignore[arg-type]
     cypher_generation_chain = LLMChain(
         llm=cypher_llm or llm,  # type: ignore[arg-type]
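The `use_function_response` branch is equivalent to assembling this LCEL pipeline by hand; a minimal sketch, assuming an OpenAI chat model (any chat model with native tool/function support should work):

```python
from langchain_community.chains.graph_qa.cypher import (
    FUNCTION_RESPONSE_SYSTEM,
    get_function_response,
)
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain_openai import ChatOpenAI

qa_llm = ChatOpenAI(model="gpt-4o")

response_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessage(content=FUNCTION_RESPONSE_SYSTEM),
        HumanMessagePromptTemplate.from_template("{question}"),
        # Filled at runtime with the AIMessage/ToolMessage pair
        # built by get_function_response().
        MessagesPlaceholder(variable_name="function_response"),
    ]
)
qa_chain = response_prompt | qa_llm | StrOutputParser()

answer = qa_chain.invoke(
    {
        "question": "Who directed Casino?",
        "function_response": get_function_response(
            "Who directed Casino?", [{"d.name": "Martin Scorsese"}]
        ),
    }
)
```

The `bind_tools({})` call in `from_llm` is just a capability probe: models without native tool support raise, which surfaces as the `ValueError`.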
@@ -217,7 +280,6 @@
                 "Either `exclude_types` or `include_types` "
                 "can be provided, but not both"
             )
-
         graph_schema = construct_schema(
             kwargs["graph"].get_structured_schema, include_types, exclude_types
         )
@@ -235,6 +297,7 @@
             qa_chain=qa_chain,
             cypher_generation_chain=cypher_generation_chain,
             cypher_query_corrector=cypher_query_corrector,
+            use_function_response=use_function_response,
             **kwargs,
         )
@@ -284,12 +347,17 @@
             )
         intermediate_steps.append({"context": context})
-        result = self.qa_chain(
-            {"question": question, "context": context},
-            callbacks=callbacks,
-        )
-        final_result = result[self.qa_chain.output_key]
+        if self.use_function_response:
+            function_response = get_function_response(question, context)
+            final_result = self.qa_chain.invoke(  # type: ignore
+                {"question": question, "function_response": function_response},
+            )
+        else:
+            result = self.qa_chain.invoke(  # type: ignore
+                {"question": question, "context": context},
+                callbacks=callbacks,
+            )
+            final_result = result[self.qa_chain.output_key]  # type: ignore
         chain_result: Dict[str, Any] = {self.output_key: final_result}
         if self.return_intermediate_steps:


@@ -60,7 +60,7 @@ def test_graph_cypher_qa_chain_prompt_selection_1() -> None:
         qa_prompt=qa_prompt,
         cypher_prompt=cypher_prompt,
     )
-    assert chain.qa_chain.prompt == qa_prompt
+    assert chain.qa_chain.prompt == qa_prompt  # type: ignore[union-attr]
     assert chain.cypher_generation_chain.prompt == cypher_prompt
@@ -72,7 +72,7 @@ def test_graph_cypher_qa_chain_prompt_selection_2() -> None:
         verbose=True,
         return_intermediate_steps=False,
     )
-    assert chain.qa_chain.prompt == CYPHER_QA_PROMPT
+    assert chain.qa_chain.prompt == CYPHER_QA_PROMPT  # type: ignore[union-attr]
     assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT
@@ -88,7 +88,7 @@ def test_graph_cypher_qa_chain_prompt_selection_3() -> None:
         cypher_llm_kwargs={"memory": readonlymemory},
         qa_llm_kwargs={"memory": readonlymemory},
     )
-    assert chain.qa_chain.prompt == CYPHER_QA_PROMPT
+    assert chain.qa_chain.prompt == CYPHER_QA_PROMPT  # type: ignore[union-attr]
     assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT
@@ -108,7 +108,7 @@ def test_graph_cypher_qa_chain_prompt_selection_4() -> None:
         cypher_llm_kwargs={"prompt": cypher_prompt, "memory": readonlymemory},
         qa_llm_kwargs={"prompt": qa_prompt, "memory": readonlymemory},
     )
-    assert chain.qa_chain.prompt == qa_prompt
+    assert chain.qa_chain.prompt == qa_prompt  # type: ignore[union-attr]
     assert chain.cypher_generation_chain.prompt == cypher_prompt