Update neo4j cypher templates to the function callback (#20515)

Update Neo4j Cypher templates to use function callback to pass context
instead of passing it in user prompt.

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Tomaz Bratanic
2024-04-19 20:33:32 +02:00
committed by GitHub
parent 3d9b26fc28
commit e4b38e2822
9 changed files with 699 additions and 88 deletions

View File

@@ -1,13 +1,22 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain.memory import ChatMessageHistory
from langchain_community.chat_models import ChatOpenAI
from langchain_community.graphs import Neo4jGraph
from langchain_core.messages import (
AIMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
# Connection to Neo4j
graph = Neo4jGraph()
@@ -56,7 +65,9 @@ def get_history(input: Dict[str, Any]) -> ChatMessageHistory:
def save_history(input):
input.pop("response")
print(input)
if input.get("function_response"):
input.pop("function_response")
# store history to database
graph.query(
"""MERGE (u:User {id: $user_id})
@@ -107,26 +118,51 @@ cypher_response = (
)
# Generate natural language response based on database results
response_template = """Based on the the question, Cypher query, and Cypher response, write a natural language response:
Question: {question}
Cypher query: {query}
Cypher Response: {response}""" # noqa: E501
response_system = """You are an assistant that helps to form nice and human
understandable answers based on the provided information from tools.
Do not add any other information that wasn't present in the tools, and use
very concise style in interpreting results!
"""
response_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"Given an input question and Cypher response, convert it to a "
"natural language answer. No pre-amble.",
),
("human", response_template),
SystemMessage(content=response_system),
HumanMessagePromptTemplate.from_template("{question}"),
MessagesPlaceholder(variable_name="function_response"),
]
)
def get_function_response(
query: str, question: str
) -> List[Union[AIMessage, ToolMessage]]:
context = graph.query(cypher_validation(query))
TOOL_ID = "call_H7fABDuzEau48T10Qn0Lsh0D"
messages = [
AIMessage(
content="",
additional_kwargs={
"tool_calls": [
{
"id": TOOL_ID,
"function": {
"arguments": '{"question":"' + question + '"}',
"name": "GetInformation",
},
"type": "function",
}
]
},
),
ToolMessage(content=str(context), tool_call_id=TOOL_ID),
]
return messages
chain = (
RunnablePassthrough.assign(query=cypher_response)
| RunnablePassthrough.assign(
response=lambda x: graph.query(cypher_validation(x["query"])),
function_response=lambda x: get_function_response(x["query"], x["question"]),
)
| RunnablePassthrough.assign(
output=response_prompt | qa_llm | StrOutputParser(),