Add neo4j vector memory template (#12993)

This commit is contained in:
Tomaz Bratanic
2023-11-07 22:00:49 +01:00
committed by GitHub
parent 5ac2fc5bb2
commit 13bd83bd61
10 changed files with 2136 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
# Package entry point: re-export the runnable chain so consumers can do
# `from neo4j_vector_memory import chain`.
from neo4j_vector_memory.chain import chain

__all__ = ["chain"]

View File

@@ -0,0 +1,69 @@
from operator import itemgetter
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.vectorstores import Neo4jVector
from neo4j_vector_memory.history import get_history, save_history
# Define vector retrieval over the pre-existing "dune" index.
# The retrieval query also surfaces the node's elementId as metadata so each
# retrieved document can later be linked back to its Question node in
# save_history (see neo4j_vector_memory/history.py).
retrieval_query = "RETURN node.text AS text, score, {id:elementId(node)} AS metadata"
vectorstore = Neo4jVector.from_existing_index(
    OpenAIEmbeddings(), index_name="dune", retrieval_query=retrieval_query
)
retriever = vectorstore.as_retriever()

# Define LLM — used for both question condensation and answer synthesis.
llm = ChatOpenAI()

# Condense a chat history and follow-up question into a standalone question,
# so retrieval works even when the follow-up only makes sense in context.
condense_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Make sure to include all the relevant information.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""  # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(condense_template)

# RAG answer synthesis prompt: system message carries the retrieved context,
# the prior turns are injected as real chat messages, then the user question.
answer_template = """Answer the question based only on the following context:
<context>
{context}
</context>"""
ANSWER_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", answer_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{question}"),
    ]
)
# Full pipeline: load history -> condense the follow-up into a standalone
# question -> retrieve context for it -> synthesize an answer -> persist turn.
chain = (
    # Pull prior turns for this user/session out of Neo4j.
    RunnablePassthrough.assign(chat_history=get_history)
    # NOTE(review): CONDENSE_QUESTION_PROMPT is a plain PromptTemplate, so the
    # chat_history message list is str()-rendered into it — presumably
    # acceptable for condensation; confirm the formatting is what's intended.
    | RunnablePassthrough.assign(
        rephrased_question=CONDENSE_QUESTION_PROMPT | llm | StrOutputParser()
    )
    # Retrieve documents for the rephrased (standalone) question.
    | RunnablePassthrough.assign(
        context=itemgetter("rephrased_question") | retriever,
    )
    # NOTE(review): ANSWER_PROMPT consumes the ORIGINAL {question} (plus
    # chat_history and context), not the rephrased one — verify intentional.
    | RunnablePassthrough.assign(
        output=ANSWER_PROMPT | llm | StrOutputParser(),
    )
    # save_history writes the turn to Neo4j and returns the answer string,
    # making it the chain's final output.
    | save_history
)
# Add typing for input so the serving layer can validate/document requests.
class Question(BaseModel):
    # The user's (possibly follow-up) question.
    question: str
    # Identifies the User node whose session history is read and extended.
    user_id: str
    # Identifies the Session node within that user's sessions.
    session_id: str


chain = chain.with_types(input_type=Question)

View File

@@ -0,0 +1,79 @@
from typing import Any, Dict, List, Union
from langchain.graphs import Neo4jGraph
from langchain.memory import ChatMessageHistory
from langchain.schema import AIMessage, HumanMessage
# Shared Neo4j connection used by all history helpers below
# (connection details presumably come from environment variables — confirm).
graph = Neo4jGraph()
def convert_messages(input: List[Dict[str, Any]]) -> ChatMessageHistory:
    """Convert Neo4j result rows into a ``ChatMessageHistory``.

    Each row is expected to carry a ``result`` dict with ``question`` and
    ``answer`` keys; they are appended in order as alternating human/AI
    messages.
    """
    history = ChatMessageHistory()
    for row in input:
        turn = row["result"]
        history.add_user_message(turn["question"])
        history.add_ai_message(turn["answer"])
    return history
def get_history(input: Dict[str, Any]) -> List[Union[HumanMessage, AIMessage]]:
    """Load the most recent Q&A turns of a session from Neo4j.

    ``input`` must contain ``user_id`` and ``session_id``; it is passed
    straight through as Cypher parameters. Returns the conversation as a flat
    message list, oldest turn first.
    """
    # Lookback conversation window: number of past turns to include.
    lookback = 3
    # Walk back from the session's LAST_MESSAGE pointer up to `lookback`
    # NEXT hops, keep the longest such path, then emit turns oldest-first.
    cypher = (
        """
        MATCH (u:User {id:$user_id})-[:HAS_SESSION]->(s:Session {id:$session_id}),
              (s)-[:LAST_MESSAGE]->(last_message)
        MATCH p=(last_message)<-[:NEXT*0.."""
        + str(lookback)
        + """]-()
        WITH p, length(p) AS length
        ORDER BY length DESC LIMIT 1
        UNWIND reverse(nodes(p)) AS node
        MATCH (node)-[:HAS_ANSWER]->(answer)
        RETURN {question:node.text, answer:answer.text} AS result
        """
    )
    records = graph.query(cypher, params=input)
    return convert_messages(records).messages
def save_history(input: Dict[str, Any]) -> str:
    """Persist the latest Q&A turn (and its retrieved documents) to Neo4j.

    Mutates ``input``: replaces ``context`` documents with their element ids
    and pops ``chat_history``. Writes a Question/Answer pair, advances the
    session's LAST_MESSAGE pointer, and links the question to the retrieved
    nodes. Returns the LLM answer so it becomes the chain's output.
    """
    # Keep only element ids of the retrieved documents for RETRIEVED links.
    input["context"] = [doc.metadata["id"] for doc in input["context"]]
    # A non-empty history means the User/Session chain already exists.
    session_exists = bool(input.pop("chat_history"))

    # Append a turn to an existing session, moving the LAST_MESSAGE pointer.
    append_turn = """
        MATCH (u:User {id: $user_id})-[:HAS_SESSION]->(s:Session{id: $session_id}),
              (s)-[l:LAST_MESSAGE]->(last_message)
        CREATE (last_message)-[:NEXT]->(q:Question
                {text:$question, rephrased:$rephrased_question, date:datetime()}),
               (q)-[:HAS_ANSWER]->(:Answer {text:$output}),
               (s)-[:LAST_MESSAGE]->(q)
        DELETE l
        WITH q
        UNWIND $context AS c
        MATCH (n) WHERE elementId(n) = c
        MERGE (q)-[:RETRIEVED]->(n)
        """

    # First turn: create the session under the (possibly new) user.
    start_session = """MERGE (u:User {id: $user_id})
        CREATE (u)-[:HAS_SESSION]->(s1:Session {id:$session_id}),
               (s1)-[:LAST_MESSAGE]->(q:Question
                    {text:$question, rephrased:$rephrased_question, date:datetime()}),
               (q)-[:HAS_ANSWER]->(:Answer {text:$output})
        WITH q
        UNWIND $context AS c
        MATCH (n) WHERE elementId(n) = c
        MERGE (q)-[:RETRIEVED]->(n)
        """

    graph.query(append_turn if session_exists else start_session, params=input)
    # Return LLM response to the chain
    return input["output"]