mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-10 23:41:28 +00:00)
community[patch]: Add function response to graph cypher qa chain (#22690)
LLMs struggle with Graph RAG because it differs from vector RAG: you don't provide the whole context, only the query results, and the LLM has to take them on faith. That often doesn't work well in practice. However, if you wrap the context as a function response, accuracy is much better. btw... `Union[LLMChain, Runnable]` is linting fun, which is why there are so many type ignores.
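For reference, a minimal usage sketch of the new flag (the Neo4j connection details, model name, and example question are placeholders, not part of this commit):

from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_openai import ChatOpenAI

# Hypothetical connection settings -- substitute your own.
graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="password")
llm = ChatOpenAI(model="gpt-4o", temperature=0)

# use_function_response=True makes the chain feed the Cypher query results back
# to the LLM as a tool/function response instead of plain prompt text.
chain = GraphCypherQAChain.from_llm(llm, graph=graph, use_function_response=True)
print(chain.invoke({"query": "Who directed Casino?"})["result"])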
@@ -2,14 +2,27 @@
 from __future__ import annotations
 
 import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.prompts import BasePromptTemplate
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    SystemMessage,
+    ToolMessage,
+)
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import (
+    BasePromptTemplate,
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    MessagesPlaceholder,
+)
 from langchain_core.pydantic_v1 import Field
+from langchain_core.runnables import Runnable
 
 from langchain_community.chains.graph_qa.cypher_utils import (
     CypherQueryCorrector,
@@ -23,6 +36,12 @@ from langchain_community.graphs.graph_store import GraphStore
 
 INTERMEDIATE_STEPS_KEY = "intermediate_steps"
 
+FUNCTION_RESPONSE_SYSTEM = """You are an assistant that helps to form nice and human
+understandable answers based on the provided information from tools.
+Do not add any other information that wasn't present in the tools, and use
+very concise style in interpreting results!
+"""
+
 
 def extract_cypher(text: str) -> str:
     """Extract Cypher code from a text.
@@ -104,6 +123,31 @@ def construct_schema(
     )
 
 
+def get_function_response(
+    question: str, context: List[Dict[str, Any]]
+) -> List[BaseMessage]:
+    TOOL_ID = "call_H7fABDuzEau48T10Qn0Lsh0D"
+    messages = [
+        AIMessage(
+            content="",
+            additional_kwargs={
+                "tool_calls": [
+                    {
+                        "id": TOOL_ID,
+                        "function": {
+                            "arguments": '{"question":"' + question + '"}',
+                            "name": "GetInformation",
+                        },
+                        "type": "function",
+                    }
+                ]
+            },
+        ),
+        ToolMessage(content=str(context), tool_call_id=TOOL_ID),
+    ]
+    return messages
+
+
 class GraphCypherQAChain(Chain):
     """Chain for question-answering against a graph by generating Cypher statements.
 
@@ -121,7 +165,7 @@ class GraphCypherQAChain(Chain):
 
     graph: GraphStore = Field(exclude=True)
     cypher_generation_chain: LLMChain
-    qa_chain: LLMChain
+    qa_chain: Union[LLMChain, Runnable]
     graph_schema: str
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
@@ -133,6 +177,8 @@ class GraphCypherQAChain(Chain):
     """Whether or not to return the result of querying the graph directly."""
     cypher_query_corrector: Optional[CypherQueryCorrector] = None
     """Optional cypher validation tool"""
+    use_function_response: bool = False
+    """Whether to wrap the database context as tool/function response"""
 
     @property
     def input_keys(self) -> List[str]:
@@ -163,12 +209,14 @@ class GraphCypherQAChain(Chain):
         qa_prompt: Optional[BasePromptTemplate] = None,
         cypher_prompt: Optional[BasePromptTemplate] = None,
         cypher_llm: Optional[BaseLanguageModel] = None,
-        qa_llm: Optional[BaseLanguageModel] = None,
+        qa_llm: Optional[Union[BaseLanguageModel, Any]] = None,
         exclude_types: List[str] = [],
         include_types: List[str] = [],
         validate_cypher: bool = False,
         qa_llm_kwargs: Optional[Dict[str, Any]] = None,
         cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
+        use_function_response: bool = False,
+        function_response_system: str = FUNCTION_RESPONSE_SYSTEM,
         **kwargs: Any,
     ) -> GraphCypherQAChain:
         """Initialize from LLM."""
@@ -205,7 +253,22 @@ class GraphCypherQAChain(Chain):
             cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
         )
 
-        qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs)  # type: ignore[arg-type]
+        qa_llm = qa_llm or llm
+        if use_function_response:
+            try:
+                qa_llm.bind_tools({})  # type: ignore[union-attr]
+                response_prompt = ChatPromptTemplate.from_messages(
+                    [
+                        SystemMessage(content=function_response_system),
+                        HumanMessagePromptTemplate.from_template("{question}"),
+                        MessagesPlaceholder(variable_name="function_response"),
+                    ]
+                )
+                qa_chain = response_prompt | qa_llm | StrOutputParser()  # type: ignore
+            except (NotImplementedError, AttributeError):
+                raise ValueError("Provided LLM does not support native tools/functions")
+        else:
+            qa_chain = LLMChain(llm=qa_llm, **use_qa_llm_kwargs)  # type: ignore[arg-type]
 
         cypher_generation_chain = LLMChain(
             llm=cypher_llm or llm,  # type: ignore[arg-type]
@@ -217,7 +280,6 @@ class GraphCypherQAChain(Chain):
                 "Either `exclude_types` or `include_types` "
                 "can be provided, but not both"
             )
-
         graph_schema = construct_schema(
             kwargs["graph"].get_structured_schema, include_types, exclude_types
         )
@@ -235,6 +297,7 @@ class GraphCypherQAChain(Chain):
             qa_chain=qa_chain,
             cypher_generation_chain=cypher_generation_chain,
             cypher_query_corrector=cypher_query_corrector,
+            use_function_response=use_function_response,
             **kwargs,
         )
 
@@ -284,12 +347,17 @@ class GraphCypherQAChain(Chain):
             )
 
         intermediate_steps.append({"context": context})
 
-        result = self.qa_chain(
-            {"question": question, "context": context},
-            callbacks=callbacks,
-        )
-        final_result = result[self.qa_chain.output_key]
+        if self.use_function_response:
+            function_response = get_function_response(question, context)
+            final_result = self.qa_chain.invoke(  # type: ignore
+                {"question": question, "function_response": function_response},
+            )
+        else:
+            result = self.qa_chain.invoke(  # type: ignore
+                {"question": question, "context": context},
+                callbacks=callbacks,
+            )
+            final_result = result[self.qa_chain.output_key]  # type: ignore
 
         chain_result: Dict[str, Any] = {self.output_key: final_result}
         if self.return_intermediate_steps:
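To make the mechanism concrete, here is roughly what the QA step sees when the flag is on (an illustrative sketch; the question and context rows are invented):

from langchain_community.chains.graph_qa.cypher import get_function_response

question = "Who directed Casino?"
context = [{"d.name": "Martin Scorsese"}]  # rows returned by the generated Cypher query

# Builds an AIMessage carrying a synthetic "GetInformation" tool call plus a
# ToolMessage holding the query results, so the model treats the context as
# the output of a tool it called rather than text it is merely told to trust.
messages = get_function_response(question, context)
# messages[0] is the AIMessage with the tool call; messages[1] is the ToolMessage.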