Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-22 02:50:31 +00:00
Update neo4j cypher templates to the function callback (#20515)
Update the Neo4j Cypher templates to use a function callback to pass context to the model, instead of passing it in the user prompt.

Co-authored-by: Erick Friis <erick@langchain.dev>
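The change swaps prompt interpolation for a message-level handoff: the Cypher results are attached as a ToolMessage that answers a synthetic tool call, and the response prompt pulls those messages in through a MessagesPlaceholder. Below is a minimal, self-contained sketch of that pattern only, not the template code itself; the helper name as_function_response, the TOOL_ID value, the system instruction, and the sample question/rows are illustrative assumptions, while the tool-call structure and the "GetInformation" name mirror the diff.

from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)

# Illustrative id only; the template hard-codes its own constant.
TOOL_ID = "call_example_0001"


def as_function_response(question: str, db_rows: list) -> list:
    """Wrap database rows as a synthetic tool call plus the ToolMessage that answers it."""
    return [
        AIMessage(
            content="",
            additional_kwargs={
                "tool_calls": [
                    {
                        "id": TOOL_ID,
                        "function": {
                            "arguments": '{"question":"' + question + '"}',
                            "name": "GetInformation",
                        },
                        "type": "function",
                    }
                ]
            },
        ),
        ToolMessage(content=str(db_rows), tool_call_id=TOOL_ID),
    ]


prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessage(content="Answer concisely, using only the tool output."),
        HumanMessagePromptTemplate.from_template("{question}"),
        MessagesPlaceholder(variable_name="function_response"),
    ]
)

question = "Which movies did Tom Hanks act in?"  # sample input, not from the template
messages = prompt.format_messages(
    question=question,
    function_response=as_function_response(question, [{"m.title": "Apollo 13"}]),
)
# -> ['SystemMessage', 'HumanMessage', 'AIMessage', 'ToolMessage']
print([type(m).__name__ for m in messages])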
@@ -1,13 +1,22 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union

 from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
 from langchain.memory import ChatMessageHistory
-from langchain_community.chat_models import ChatOpenAI
 from langchain_community.graphs import Neo4jGraph
+from langchain_core.messages import (
+    AIMessage,
+    SystemMessage,
+    ToolMessage,
+)
 from langchain_core.output_parsers import StrOutputParser
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.prompts import (
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    MessagesPlaceholder,
+)
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables import RunnablePassthrough
+from langchain_openai import ChatOpenAI

 # Connection to Neo4j
 graph = Neo4jGraph()
@@ -56,7 +65,9 @@ def get_history(input: Dict[str, Any]) -> ChatMessageHistory:


 def save_history(input):
-    input.pop("response")
+    print(input)
+    if input.get("function_response"):
+        input.pop("function_response")
     # store history to database
     graph.query(
         """MERGE (u:User {id: $user_id})
@@ -107,26 +118,51 @@ cypher_response = (
 )

 # Generate natural language response based on database results
-response_template = """Based on the the question, Cypher query, and Cypher response, write a natural language response:
-Question: {question}
-Cypher query: {query}
-Cypher Response: {response}"""  # noqa: E501
+response_system = """You are an assistant that helps to form nice and human
+understandable answers based on the provided information from tools.
+Do not add any other information that wasn't present in the tools, and use
+very concise style in interpreting results!
+"""

 response_prompt = ChatPromptTemplate.from_messages(
     [
-        (
-            "system",
-            "Given an input question and Cypher response, convert it to a "
-            "natural language answer. No pre-amble.",
-        ),
-        ("human", response_template),
+        SystemMessage(content=response_system),
+        HumanMessagePromptTemplate.from_template("{question}"),
+        MessagesPlaceholder(variable_name="function_response"),
     ]
 )
+
+
+def get_function_response(
+    query: str, question: str
+) -> List[Union[AIMessage, ToolMessage]]:
+    context = graph.query(cypher_validation(query))
+    TOOL_ID = "call_H7fABDuzEau48T10Qn0Lsh0D"
+    messages = [
+        AIMessage(
+            content="",
+            additional_kwargs={
+                "tool_calls": [
+                    {
+                        "id": TOOL_ID,
+                        "function": {
+                            "arguments": '{"question":"' + question + '"}',
+                            "name": "GetInformation",
+                        },
+                        "type": "function",
+                    }
+                ]
+            },
+        ),
+        ToolMessage(content=str(context), tool_call_id=TOOL_ID),
+    ]
+    return messages
+

 chain = (
     RunnablePassthrough.assign(query=cypher_response)
     | RunnablePassthrough.assign(
-        response=lambda x: graph.query(cypher_validation(x["query"])),
+        function_response=lambda x: get_function_response(x["query"], x["question"]),
     )
     | RunnablePassthrough.assign(
         output=response_prompt | qa_llm | StrOutputParser(),
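For readers less familiar with the chain composition above, here is a small, runnable sketch in which toy lambdas stand in for the real Cypher generation, Neo4j query, and LLM call (those stand-ins are assumptions for illustration). It shows how each RunnablePassthrough.assign step adds a key to the dict flowing through the chain, which is how function_response and output become available to later steps such as save_history.

from langchain_core.runnables import RunnablePassthrough

# Toy stand-ins for cypher_response, get_function_response, and the QA prompt/LLM.
pipeline = (
    RunnablePassthrough.assign(query=lambda x: f"MATCH ... // for: {x['question']}")
    | RunnablePassthrough.assign(
        function_response=lambda x: ["<tool call>", f"<rows for {x['query']}>"]
    )
    | RunnablePassthrough.assign(
        output=lambda x: f"answer built from {x['function_response'][1]}"
    )
)

# Each step sees the keys produced by the previous ones.
print(pipeline.invoke({"question": "Which movies did Tom Hanks act in?"}))
# -> {'question': ..., 'query': ..., 'function_response': [...], 'output': ...}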