Format Templates (#12396)

Commit 4b16601d33 by Erick Friis, 2023-10-26 19:44:30 -07:00, committed by GitHub. Parent: 25c98dbba9.
59 changed files with 800 additions and 441 deletions

View File

@ -1,2 +1,8 @@
lint lint_diff:
	poetry run poe lint

test:
	poetry run poe test

format:
	poetry run poe format

View File

@ -1,7 +1,7 @@
from langchain.schema.runnable import ConfigurableField

from .chain import chain
from .retriever_agent import executor

final_chain = chain.configurable_alternatives(
    ConfigurableField(id="chain"),

View File

@ -1,5 +1,5 @@
from langchain.chat_models import ChatAnthropic
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

from .prompts import answer_prompt

View File

@ -1,6 +1,7 @@
import re

from langchain.schema.agent import AgentAction, AgentFinish

from .agent_scratchpad import _format_docs
@ -14,18 +15,23 @@ def extract_between_tags(tag: str, string: str, strip: bool = True) -> str:
    # Only return the first one
    return ext_list[0]


def parse_output(outputs):
    partial_completion = outputs["partial_completion"]
    steps = outputs["intermediate_steps"]
    search_query = extract_between_tags(
        "search_query", partial_completion + "</search_query>"
    )
    if search_query is None:
        docs = []
        str_output = ""
        for action, observation in steps:
            docs.extend(observation)
            str_output += action.log
            str_output += "</search_query>" + _format_docs(observation)
        str_output += partial_completion
        return AgentFinish({"docs": docs, "output": str_output}, log=partial_completion)
    else:
        return AgentAction(
            tool="search", tool_input=search_query, log=partial_completion
        )
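
For reference, the body of extract_between_tags is only partially visible in this hunk. A minimal sketch of what such a helper presumably looks like, assuming a regex-based implementation (the exact body is not part of this diff; note that parse_output treats a missing tag as None):

import re
from typing import Optional

def extract_between_tags(tag: str, string: str, strip: bool = True) -> Optional[str]:
    # Collect every <tag>...</tag> span; DOTALL lets the content span newlines.
    ext_list = re.findall(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL)
    if strip:
        ext_list = [e.strip() for e in ext_list]
    # Only return the first one
    return ext_list[0] if ext_list else None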

View File

@ -2,6 +2,6 @@ retrieval_prompt = """{retriever_description} Before beginning to research the u
After each call to the Search Engine Tool, reflect briefly inside <search_quality></search_quality> tags about whether you now have enough information to answer, or whether more information is needed. If you have all the relevant information, write it in <information></information> tags, WITHOUT actually answering the question. Otherwise, issue a new search.
Here is the user's question: <question>{query}</question> Remind yourself to make short queries in your scratchpad as you plan out your strategy."""  # noqa: E501

answer_prompt = "Here is a user query: <query>{query}</query>. Here is some relevant information: <information>{information}</information>. Please answer the question using the relevant information."  # noqa: E501

View File

@ -3,13 +3,14 @@ from langchain.tools import tool
# This is used to tell the model how to best use the retriever.
retriever_description = """You will be asked a question by a human user. You have access to the following tool to help answer the question. <tool_description> Search Engine Tool * The search engine will exclusively search over Wikipedia for pages similar to your query. It returns for each page its title and full page content. Use this tool if you want to get up-to-date and comprehensive information on a topic to help answer queries. Queries should be as atomic as possible -- they only need to address one part of the user's question. For example, if the user's query is "what is the color of a basketball?", your search query should be "basketball". Here's another example: if the user's question is "Who created the first neural network?", your first query should be "neural network". As you can see, these queries are quite short. Think keywords, not phrases. * At any time, you can make a call to the search engine using the following syntax: <search_query>query_word</search_query>. * You'll then get results back in <search_result> tags.</tool_description>"""  # noqa: E501
retriever = WikipediaRetriever()
# This should be the same as the function name below
RETRIEVER_TOOL_NAME = "search"
@tool
def search(query):
"""Search with the retriever."""

View File

@ -1,13 +1,13 @@
from langchain.agents import AgentExecutor
from langchain.chat_models import ChatAnthropic
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableMap, RunnablePassthrough

from .agent_scratchpad import format_agent_scratchpad
from .output_parser import parse_output
from .prompts import retrieval_prompt
from .retriever import retriever_description, search

prompt = ChatPromptTemplate.from_messages([
    ("user", retrieval_prompt),

View File

@ -1,6 +1,12 @@
from anthropic_iterative_search import final_chain

if __name__ == "__main__":
    query = (
        "Which movie came out first: Oppenheimer, or "
        "Are You There God It's Me Margaret?"
    )
    print(
        final_chain.with_config(configurable={"chain": "retrieve"}).invoke(
            {"query": query}
        )
    )

View File

@ -1,14 +1,12 @@
import os

import cassio
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.vectorstores import Cassandra

use_cassandra = int(os.environ.get("USE_CASSANDRA_CLUSTER", "0"))
if use_cassandra:

View File

@ -1,13 +1,13 @@
import os

from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster


def get_cassandra_connection():
    contact_points = [
        cp.strip()
        for cp in os.environ.get("CASSANDRA_CONTACT_POINTS", "").split(",")
        if cp.strip()
    ]
    CASSANDRA_KEYSPACE = os.environ["CASSANDRA_KEYSPACE"]
@ -22,6 +22,8 @@ def get_cassandra_connection():
    else:
        auth_provider = None

    c_cluster = Cluster(
        contact_points if contact_points else None, auth_provider=auth_provider
    )
    session = c_cluster.connect()
    return (session, CASSANDRA_KEYSPACE)

View File

@ -1,14 +1,13 @@
import os

import cassio
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Cassandra

use_cassandra = int(os.environ.get("USE_CASSANDRA_CLUSTER", "0"))
if use_cassandra:
    from cassandra_entomology_rag.cassandra_cluster_init import get_cassandra_connection

    session, keyspace = get_cassandra_connection()
    cassio.init(
        session=session,
@ -22,7 +21,7 @@ else:
    )

if __name__ == "__main__":
    embeddings = OpenAIEmbeddings()
    vector_store = Cassandra(
        session=None,
@ -32,16 +31,13 @@ if __name__ == '__main__':
    )

    #
    lines = [
        line.strip()
        for line in open("sources.txt").readlines()
        if line.strip()
        if line[0] != "#"
    ]
    # deterministic IDs to prevent duplicates on multiple runs
    ids = ["_".join(line.split(" ")[:2]).lower().replace(":", "") for line in lines]
    #
    vector_store.add_texts(texts=lines, ids=ids)
    print(f"Done ({len(lines)} lines inserted).")

View File

@ -1,13 +1,12 @@
import os

import cassio
import langchain
from langchain.cache import CassandraCache
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import BaseMessage
from langchain.schema.runnable import RunnableLambda

use_cassandra = int(os.environ.get("USE_CASSANDRA_CLUSTER", "0"))
if use_cassandra:

View File

@ -1,13 +1,13 @@
import os

from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster


def get_cassandra_connection():
    contact_points = [
        cp.strip()
        for cp in os.environ.get("CASSANDRA_CONTACT_POINTS", "").split(",")
        if cp.strip()
    ]
    CASSANDRA_KEYSPACE = os.environ["CASSANDRA_KEYSPACE"]
@ -22,6 +22,8 @@ def get_cassandra_connection():
    else:
        auth_provider = None

    c_cluster = Cluster(
        contact_points if contact_points else None, auth_provider=auth_provider
    )
    session = c_cluster.connect()
    return (session, CASSANDRA_KEYSPACE)

View File

@ -1,24 +1,25 @@
from pathlib import Path

import pandas as pd
from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools.retriever import create_retriever_tool
from langchain.vectorstores import FAISS
from langchain_experimental.tools import PythonAstREPLTool
from pydantic import BaseModel, Field

MAIN_DIR = Path(__file__).parents[1]

pd.set_option("display.max_rows", 20)
pd.set_option("display.max_columns", 20)

embedding_model = OpenAIEmbeddings()
vectorstore = FAISS.load_local(MAIN_DIR / "titanic_data", embedding_model)
retriever_tool = create_retriever_tool(
    vectorstore.as_retriever(), "person_name_search", "Search for a person by name"
)
TEMPLATE = """You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
@ -41,8 +42,7 @@ For example:
<question>Who has id 320</question>
<logic>Use `python_repl` since even though the question is about a person, you don't know their name so you can't include it.</logic>
"""
""" # noqa: E501
class PythonInputs(BaseModel):
@ -52,15 +52,24 @@ class PythonInputs(BaseModel):
df = pd.read_csv("titanic.csv")
template = TEMPLATE.format(dhead=df.head().to_markdown())

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", template),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
        ("human", "{input}"),
    ]
)

repl = PythonAstREPLTool(
    locals={"df": df},
    name="python_repl",
    description="Runs code and returns the output of the final line",
    args_schema=PythonInputs,
)
tools = [repl, retriever_tool]
agent = OpenAIFunctionsAgent(
    llm=ChatOpenAI(temperature=0, model="gpt-4"), prompt=prompt, tools=tools
)
agent_executor = AgentExecutor(
    agent=agent, tools=tools, max_iterations=5, early_stopping_method="generate"
)
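
A minimal invocation sketch for the executor above (the question is a hypothetical example; AgentExecutor exposes the standard Runnable invoke interface):

result = agent_executor.invoke(
    {"input": "How many people named Allen are in the data?"}
)
print(result["output"])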

View File

@ -1,5 +1,4 @@
from langchain.document_loaders import CSVLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.tools.retriever import create_retriever_tool
from langchain.vectorstores import FAISS

View File

@ -1,11 +1,12 @@
import os
from pathlib import Path

from elasticsearch import Elasticsearch
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.json import SimpleJsonOutputParser

from .elastic_index_info import get_indices_infos
from .prompts import DSL_PROMPT

es_host = os.environ["ELASTIC_SEARCH_SERVER"]
es_password = os.environ["ELASTIC_PASSWORD"]

View File

@ -1,5 +1,6 @@
from typing import List


def _list_indices(database, include_indices=None, ignore_indices=None) -> List[str]:
    all_indices = [
        index["index"] for index in database.cat.indices(format="json")

View File

@ -16,6 +16,6 @@ Use the following format:
Question: Question here
ESQuery: Elasticsearch Query formatted as json
"""
""" # noqa: E501
DSL_PROMPT = PromptTemplate.from_template(DEFAULT_DSL_TEMPLATE + PROMPT_SUFFIX)

View File

@ -1,4 +1,5 @@
import os

from elasticsearch import Elasticsearch

es_host = os.environ["ELASTIC_SEARCH_SERVER"]

View File

@ -1,5 +1,4 @@
from elastic_query_generator.chain import chain

if __name__ == "__main__":
    print(chain.invoke({"input": "how many customers named Carol"}))

View File

@ -1,40 +1,46 @@
import json
from typing import List, Optional

from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.utils.openai_functions import convert_pydantic_to_openai_function

template = """An article will be passed to you. Extract from it all papers that are mentioned by this article.

Do not extract the name of the article itself. If no papers are mentioned that's fine - you don't need to extract any! Just return an empty list.

Do not make up or guess ANY extra information. Only extract what exactly is in the text."""  # noqa: E501

prompt = ChatPromptTemplate.from_messages([("system", template), ("human", "{input}")])


# Function output schema
class Paper(BaseModel):
    """Information about papers mentioned."""

    title: str
    author: Optional[str]


class Info(BaseModel):
    """Information to extract"""

    papers: List[Paper]


# Function definition
model = ChatOpenAI()
function = [convert_pydantic_to_openai_function(Info)]
chain = (
    prompt
    | model.bind(functions=function, function_call={"name": "Info"})
    | (
        lambda x: json.loads(x.additional_kwargs["function_call"]["arguments"])[
            "papers"
        ]
    )
)
# chain = prompt | model.bind(
# functions=function, function_call={"name": "Info"}

View File

@ -1,14 +1,15 @@
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Chroma

from hyde.prompts import hyde_prompt

# Example for document loading (from url), splitting, and creating vectorstore
"""
# Load
from langchain.document_loaders import WebBaseLoader
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
@ -25,13 +26,13 @@ vectorstore = Chroma.from_documents(documents=all_splits,
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
"""

# Embed a single document as a test
vectorstore = Chroma.from_texts(
    ["harrison worked at kensho"],
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
@ -48,11 +49,18 @@ model = ChatOpenAI()
# RAG chain
chain = (
    RunnableParallel(
        {
            # Configure the input, pass it the prompt, pass that to the model,
            # and then the result to the retriever
            "context": {"input": RunnablePassthrough()}
            | hyde_prompt
            | model
            | StrOutputParser()
            | retriever,
            "question": RunnablePassthrough(),
        }
    )
    | prompt
    | model
    | StrOutputParser()

View File

@ -7,7 +7,7 @@ Question: {input}
Passage:"""
sci_fact_template = """Please write a scientific paper passage to support/refute the claim
Claim: {input}
Passage:"""
Passage:""" # noqa: E501
fiqa_template = """Please write a financial article passage to answer the question
Question: {input}
Passage:"""

View File

@ -1,13 +1,13 @@
from typing import List, Optional

from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain.chains.openai_functions import create_structured_output_chain
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

try:
    from pydantic.v1.main import BaseModel, Field
except ImportError:
@ -27,15 +27,18 @@ cypher_validation = CypherQueryCorrector(corrector_schema)
cypher_llm = ChatOpenAI(model_name="gpt-4", temperature=0.0)
qa_llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.0)


# Extract entities from text
class Entities(BaseModel):
    """Identifying information about entities."""

    names: List[str] = Field(
        ...,
        description="All the person, organization, or business entities that "
        "appear in the text",
    )


prompt = ChatPromptTemplate.from_messages(
    [
        (
@ -44,11 +47,13 @@ prompt = ChatPromptTemplate.from_messages(
        ),
        (
            "human",
            "Use the given format to extract information from the following "
            "input: {question}",
        ),
    ]
)


# Fulltext index query
def map_to_database(entities: Entities) -> Optional[str]:
    result = ""
@ -56,16 +61,16 @@ def map_to_database(entities: Entities) -> Optional[str]:
        response = graph.query(
            "CALL db.index.fulltext.queryNodes('entity', $entity + '*', {limit:1})"
            " YIELD node,score RETURN node.name AS result",
            {"entity": entity},
        )
        try:
            result += f"{entity} maps to {response[0]['result']} in database\n"
        except IndexError:
            pass
    return result


entity_chain = create_structured_output_chain(Entities, qa_llm, prompt)
# Generate Cypher statement based on natural language input
cypher_template = """Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question:
@ -73,7 +78,7 @@ cypher_template = """Based on the Neo4j graph schema below, write a Cypher query
Entities in the question map to the following database values:
{entities_list}
Question: {question}
Cypher query:"""  # noqa: E501
cypher_prompt = ChatPromptTemplate.from_messages(
[
@ -88,7 +93,7 @@ cypher_prompt = ChatPromptTemplate.from_messages(
cypher_response = (
    RunnablePassthrough.assign(names=entity_chain)
    | RunnablePassthrough.assign(
        entities_list=lambda x: map_to_database(x["names"]["function"]),
        schema=lambda _: graph.get_schema,
    )
    | cypher_prompt
@ -100,13 +105,14 @@ cypher_response = (
response_template = """Based on the the question, Cypher query, and Cypher response, write a natural language response:
Question: {question}
Cypher query: {query}
Cypher Response: {response}"""
Cypher Response: {response}""" # noqa: E501
response_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"Given an input question and Cypher response, convert it to a natural language answer. No pre-amble.",
"Given an input question and Cypher response, convert it to a natural"
" language answer. No pre-amble.",
),
("human", response_template),
]

View File

@ -1,9 +1,9 @@
from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

# Connection to Neo4j
graph = Neo4jGraph()
@ -24,7 +24,7 @@ cypher_template = """Based on the Neo4j graph schema below, write a Cypher query
{schema}
Question: {question}
Cypher query:"""  # noqa: E501
cypher_prompt = ChatPromptTemplate.from_messages(
[
@ -49,13 +49,14 @@ cypher_response = (
response_template = """Based on the the question, Cypher query, and Cypher response, write a natural language response:
Question: {question}
Cypher query: {query}
Cypher Response: {response}"""
Cypher Response: {response}""" # noqa: E501
response_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"Given an input question and Cypher response, convert it to a natural language answer. No pre-amble.",
"Given an input question and Cypher response, convert it to a "
"natural language answer. No pre-amble.",
),
("human", response_template),
]

View File

@ -1,6 +1,5 @@
from neo4j_generation.chain import chain

if __name__ == "__main__":
    text = "Harrison works at LangChain, which is located in San Francisco"
    allowed_nodes = ["Person", "Organization", "Location"]

View File

@ -1,11 +1,12 @@
from typing import List, Optional

from langchain.chains.openai_functions import (
    create_structured_output_chain,
)
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.graphs.graph_document import GraphDocument
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document

from neo4j_generation.utils import (
@ -35,7 +36,7 @@ def get_extraction_chain(
      If not provided, there won't be any specific restriction on node labels.
    - allowed_rels (Optional[List[str]]): A list of relationship types that are allowed in the knowledge graph.
      If not provided, there won't be any specific restriction on relationship types.
    """  # noqa: E501
    prompt = ChatPromptTemplate.from_messages(
        [
            (
@ -64,11 +65,12 @@ always use the most complete identifier for that entity throughout the knowledge
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
## 5. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination.
""",  # noqa: E501
            ),
            (
                "human",
                "Use the given format to extract information from the "
                "following input: {input}",
            ),
            ("human", "Tip: Make sure to answer in the correct format"),
        ]
@ -94,7 +96,7 @@ def chain(
    Returns:
        str: A confirmation message indicating the completion of the graph construction.
    """  # noqa: E501
    # Extract graph data using OpenAI functions
    extract_chain = get_extraction_chain(allowed_nodes, allowed_relationships)
    data = extract_chain.run(text)

View File

@ -1,9 +1,12 @@
from typing import List, Optional

from langchain.graphs.graph_document import (
    Node as BaseNode,
)
from langchain.graphs.graph_document import (
    Relationship as BaseRelationship,
)
from langchain.pydantic_v1 import BaseModel, Field


class Property(BaseModel):

View File

@ -1,10 +1,11 @@
from pathlib import Path

from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.graphs import Neo4jGraph
from langchain.text_splitter import TokenTextSplitter
from langchain.vectorstores import Neo4jVector

txt_path = Path(__file__).parent / "dune.txt"

graph = Neo4jGraph()

View File

@ -1,5 +1,4 @@
from neo4j_parent.chain import chain

if __name__ == "__main__":
    original_query = "What is the plot of the Dune?"

View File

@ -1,8 +1,8 @@
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Neo4jVector

retrieval_query = """

View File

@ -1,15 +1,15 @@
from typing import List, Tuple

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.pydantic_v1 import BaseModel
from langchain.schema.messages import AIMessage, HumanMessage
from langchain.tools.render import format_tool_to_openai_function
from langchain.tools.tavily_search import TavilySearchResults
from langchain.utilities.tavily_search import TavilySearchAPIWrapper

# Fake Tool
search = TavilySearchAPIWrapper()
@ -18,17 +18,21 @@ tavily_tool = TavilySearchResults(api_wrapper=search)
tools = [tavily_tool]

llm = ChatOpenAI(temperature=0)

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a very powerful assistant, but bad at calculating lengths of words.",
        ),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)

llm_with_tools = llm.bind(functions=[format_tool_to_openai_function(t) for t in tools])
def _format_chat_history(chat_history: List[Tuple[str, str]]):
    buffer = []
    for human, ai in chat_history:
@ -37,16 +41,25 @@ def _format_chat_history(chat_history: List[Tuple[str, str]]):
    return buffer


agent = (
    {
        "input": lambda x: x["input"],
        "chat_history": lambda x: _format_chat_history(x["chat_history"]),
        "agent_scratchpad": lambda x: format_to_openai_functions(
            x["intermediate_steps"]
        ),
    }
    | prompt
    | llm_with_tools
    | OpenAIFunctionsAgentOutputParser()
)


class AgentInput(BaseModel):
    input: str
    chat_history: List[Tuple[str, str]]


agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True).with_types(
    input_type=AgentInput
)
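
A minimal invocation sketch for this executor (the question and empty history are hypothetical; chat_history is a list of (human, ai) string tuples, matching AgentInput):

result = agent_executor.invoke(
    {"input": "How many letters are in the word educa?", "chat_history": []}
)
print(result["output"])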

templates/poetry.lock (generated, 191 lines changed)
View File

@ -1,5 +1,144 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "colorama"
version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "docopt"
version = "0.6.2"
description = "Pythonic argument parser, that will make you smile"
optional = false
python-versions = "*"
files = [
{file = "docopt-0.6.2.tar.gz", hash = "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491"},
]
[[package]]
name = "exceptiongroup"
version = "1.1.3"
description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
files = [
{file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"},
{file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"},
]
[package.extras]
test = ["pytest (>=6)"]
[[package]]
name = "iniconfig"
version = "2.0.0"
description = "brain-dead simple config-ini parsing"
optional = false
python-versions = ">=3.7"
files = [
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
[[package]]
name = "packaging"
version = "23.2"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.7"
files = [
{file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
{file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
]
[[package]]
name = "pastel"
version = "0.2.1"
description = "Bring colors to your terminal."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364"},
{file = "pastel-0.2.1.tar.gz", hash = "sha256:e6581ac04e973cac858828c6202c1e1e81fee1dc7de7683f3e1ffe0bfd8a573d"},
]
[[package]]
name = "pluggy"
version = "1.3.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.8"
files = [
{file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"},
{file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"},
]
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "poethepoet"
version = "0.24.1"
description = "A task runner that works well with poetry."
optional = false
python-versions = ">=3.8"
files = [
{file = "poethepoet-0.24.1-py3-none-any.whl", hash = "sha256:3afa44b4fc7327df0dd912eda012604a072af2bb4d243fb0e41e8eca8dabf9ed"},
{file = "poethepoet-0.24.1.tar.gz", hash = "sha256:f5a386387c382f08890c273d13495938208a8ce91ab71536abf388c776c4f366"},
]
[package.dependencies]
pastel = ">=0.2.1,<0.3.0"
tomli = ">=1.2.2"
[package.extras]
poetry-plugin = ["poetry (>=1.0,<2.0)"]
[[package]]
name = "pytest"
version = "7.4.3"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"},
{file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"},
]
[package.dependencies]
colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<2.0"
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-watch"
version = "4.2.0"
description = "Local continuous test runner with pytest and watchdog."
optional = false
python-versions = "*"
files = [
{file = "pytest-watch-4.2.0.tar.gz", hash = "sha256:06136f03d5b361718b8d0d234042f7b2f203910d8568f63df2f866b547b3d4b9"},
]
[package.dependencies]
colorama = ">=0.3.3"
docopt = ">=0.4.0"
pytest = ">=2.6.4"
watchdog = ">=0.6.0"
[[package]]
name = "ruff"
version = "0.1.2"
@ -26,7 +165,57 @@ files = [
{file = "ruff-0.1.2.tar.gz", hash = "sha256:afd4785ae060ce6edcd52436d0c197628a918d6d09e3107a892a1bad6a4c6608"},
]
[[package]]
name = "tomli"
version = "2.0.1"
description = "A lil' TOML parser"
optional = false
python-versions = ">=3.7"
files = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
[[package]]
name = "watchdog"
version = "3.0.0"
description = "Filesystem events monitoring"
optional = false
python-versions = ">=3.7"
files = [
{file = "watchdog-3.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:336adfc6f5cc4e037d52db31194f7581ff744b67382eb6021c868322e32eef41"},
{file = "watchdog-3.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a70a8dcde91be523c35b2bf96196edc5730edb347e374c7de7cd20c43ed95397"},
{file = "watchdog-3.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:adfdeab2da79ea2f76f87eb42a3ab1966a5313e5a69a0213a3cc06ef692b0e96"},
{file = "watchdog-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2b57a1e730af3156d13b7fdddfc23dea6487fceca29fc75c5a868beed29177ae"},
{file = "watchdog-3.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7ade88d0d778b1b222adebcc0927428f883db07017618a5e684fd03b83342bd9"},
{file = "watchdog-3.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7e447d172af52ad204d19982739aa2346245cc5ba6f579d16dac4bfec226d2e7"},
{file = "watchdog-3.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9fac43a7466eb73e64a9940ac9ed6369baa39b3bf221ae23493a9ec4d0022674"},
{file = "watchdog-3.0.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8ae9cda41fa114e28faf86cb137d751a17ffd0316d1c34ccf2235e8a84365c7f"},
{file = "watchdog-3.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:25f70b4aa53bd743729c7475d7ec41093a580528b100e9a8c5b5efe8899592fc"},
{file = "watchdog-3.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4f94069eb16657d2c6faada4624c39464f65c05606af50bb7902e036e3219be3"},
{file = "watchdog-3.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7c5f84b5194c24dd573fa6472685b2a27cc5a17fe5f7b6fd40345378ca6812e3"},
{file = "watchdog-3.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3aa7f6a12e831ddfe78cdd4f8996af9cf334fd6346531b16cec61c3b3c0d8da0"},
{file = "watchdog-3.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:233b5817932685d39a7896b1090353fc8efc1ef99c9c054e46c8002561252fb8"},
{file = "watchdog-3.0.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:13bbbb462ee42ec3c5723e1205be8ced776f05b100e4737518c67c8325cf6100"},
{file = "watchdog-3.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8f3ceecd20d71067c7fd4c9e832d4e22584318983cabc013dbf3f70ea95de346"},
{file = "watchdog-3.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c9d8c8ec7efb887333cf71e328e39cffbf771d8f8f95d308ea4125bf5f90ba64"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0e06ab8858a76e1219e68c7573dfeba9dd1c0219476c5a44d5333b01d7e1743a"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:d00e6be486affb5781468457b21a6cbe848c33ef43f9ea4a73b4882e5f188a44"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:c07253088265c363d1ddf4b3cdb808d59a0468ecd017770ed716991620b8f77a"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:5113334cf8cf0ac8cd45e1f8309a603291b614191c9add34d33075727a967709"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:51f90f73b4697bac9c9a78394c3acbbd331ccd3655c11be1a15ae6fe289a8c83"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:ba07e92756c97e3aca0912b5cbc4e5ad802f4557212788e72a72a47ff376950d"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d429c2430c93b7903914e4db9a966c7f2b068dd2ebdd2fa9b9ce094c7d459f33"},
{file = "watchdog-3.0.0-py3-none-win32.whl", hash = "sha256:3ed7c71a9dccfe838c2f0b6314ed0d9b22e77d268c67e015450a29036a81f60f"},
{file = "watchdog-3.0.0-py3-none-win_amd64.whl", hash = "sha256:4c9956d27be0bb08fc5f30d9d0179a855436e655f046d288e2bcc11adfae893c"},
{file = "watchdog-3.0.0-py3-none-win_ia64.whl", hash = "sha256:5d9f3a10e02d7371cd929b5d8f11e87d4bad890212ed3901f9b4d68767bee759"},
{file = "watchdog-3.0.0.tar.gz", hash = "sha256:4d98a320595da7a7c5a18fc48cb633c2e73cda78f93cac2ef42d42bf609a33f9"},
]
[package.extras]
watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "7a86271e260f3ac3de6446960b301dc6c854f9d7f9544774d81eaa13be511838"
content-hash = "e00055a76b5e7a5dd6afd6c65073084a9a0f988a8d30e24be2606c048bd25686"

View File

@ -9,14 +9,19 @@ readme = "README.md"
python = "^3.10"

# dev, test, lint, typing
[tool.poetry.group.dev.dependencies]
poethepoet = "^0.24.1"
pytest-watch = "^4.2.0"

[tool.poetry.group.test.dependencies]
pytest = "^7.4.3"

[tool.poetry.group.lint.dependencies]
ruff = "^0.1"

[tool.poetry.group.typing.dependencies]
[tool.ruff]
select = [
@ -24,3 +29,14 @@ select = [
"F", # pyflakes
"I", # isort
]
[tool.poe.tasks]
test = "poetry run pytest"
watch = "poetry run ptw"
lint = "poetry run ruff ."
format = "poetry run ruff . --fix"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@ -1,28 +1,31 @@
# Load
from langchain.chat_models import ChatOllama
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import GPT4AllEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
data = loader.load()

# Split
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(data)

# Add to vectorDB
vectorstore = Chroma.from_documents(
    documents=all_splits,
    collection_name="rag-private",
    embedding=GPT4AllEmbeddings(),
)
retriever = vectorstore.as_retriever()

# Prompt
# Optionally, pull from the Hub
# from langchain import hub
# prompt = hub.pull("rlm/rag-prompt")

View File

@ -1,9 +1,8 @@
from operator import itemgetter

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Chroma

# Example for document loading (from url), splitting, and creating vectorstore

View File

@ -1,15 +1,21 @@
import os
from operator import itemgetter
from typing import List, Tuple

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, format_document
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import (
    RunnableBranch,
    RunnableLambda,
    RunnableMap,
    RunnablePassthrough,
)
from langchain.vectorstores import Pinecone
from pydantic import BaseModel

if os.environ.get("PINECONE_API_KEY", None) is None:
    raise Exception("Missing `PINECONE_API_KEY` environment variable.")
@ -44,7 +50,7 @@ _template = """Given the following conversation and a follow up question, rephra
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""  # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
# RAG answer synthesis prompt
@ -52,18 +58,25 @@ template = """Answer the question based only on the following context:
<context>
{context}
</context>"""
ANSWER_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{question}"),
    ]
)
# Conversational Retrieval Chain
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)
def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    buffer = []
    for human, ai in chat_history:
@ -71,6 +84,7 @@ def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
        buffer.append(AIMessage(content=ai))
    return buffer


# User input
class ChatHistory(BaseModel):
    chat_history: List[Tuple[str, str]]
@ -78,24 +92,28 @@ class ChatHistory(BaseModel):
_search_query = RunnableBranch(
    # If input includes chat_history, we condense it with the follow-up question
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),  # Condense follow-up question and chat into a standalone_question
        RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0)
        | StrOutputParser(),
    ),
    # Else, we have no chat history, so just pass through the question
    RunnableLambda(itemgetter("question")),
)

_inputs = RunnableMap(
    {
        "question": lambda x: x["question"],
        "chat_history": lambda x: _format_chat_history(x["chat_history"]),
        "context": _search_query | retriever | _combine_documents,
    }
).with_types(input_type=ChatHistory)

chain = _inputs | ANSWER_PROMPT | ChatOpenAI() | StrOutputParser()
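
A minimal invocation sketch (the question and history are hypothetical; history entries are (human, ai) tuples, matching ChatHistory):

answer = chain.invoke(
    {
        "question": "Where is it headquartered?",
        "chat_history": [("What is LangChain?", "A framework for LLM apps.")],
    }
)
print(answer)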

View File

@ -1,8 +1,9 @@
import os

from langchain.document_loaders import JSONLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.elasticsearch import ElasticsearchStore

ELASTIC_CLOUD_ID = os.getenv("ELASTIC_CLOUD_ID")
ELASTIC_USERNAME = os.getenv("ELASTIC_USERNAME", "elastic")

View File

@ -23,7 +23,9 @@ if __name__ == "__main__":
"question": follow_up_question,
"chat_history": [
"What is the nasa sales team?",
"The sales team of NASA consists of Laura Martinez, the Area Vice-President of North America, and Gary Johnson, the Area Vice-President of South America. (Sales Organization Overview)",
"The sales team of NASA consists of Laura Martinez, the Area "
"Vice-President of North America, and Gary Johnson, the Area "
"Vice-President of South America. (Sales Organization Overview)",
],
}
)

View File

@ -1,13 +1,15 @@
from operator import itemgetter
from typing import List, Tuple

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import format_document
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableMap, RunnablePassthrough
from langchain.vectorstores.elasticsearch import ElasticsearchStore

from .connection import es_connection_details
from .prompts import CONDENSE_QUESTION_PROMPT, DOCUMENT_PROMPT, LLM_CONTEXT_PROMPT

# Setup connecting to Elasticsearch
vectorstore = ElasticsearchStore(

View File

@ -6,7 +6,7 @@ condense_question_prompt_template = """Given the following conversation and a fo
Chat History:
{chat_history}
Follow Up Input: {question}
"""
""" # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(
condense_question_prompt_template
)
@ -23,7 +23,7 @@ If you don't know the answer, just say that you don't know, don't try to make up
{context}
----
Question: {question}
"""
""" # noqa: E501
LLM_CONTEXT_PROMPT = ChatPromptTemplate.from_template(llm_context_prompt_template)

View File

@ -1,8 +1,8 @@
import pinecone
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Pinecone

pinecone.init(api_key="...", environment="...")
all_documents = {
"doc1": "Climate change and economic impact.",
@ -14,7 +14,9 @@ all_documents = {
"doc7": "Climate change: The science and models.",
"doc8": "Global warming: A subset of climate change.",
"doc9": "How climate change affects daily weather.",
"doc10": "The history of climate change activism."
"doc10": "The history of climate change activism.",
}
Pinecone.from_texts(list(all_documents.values()), OpenAIEmbeddings(), index_name='rag-fusion')
Pinecone.from_texts(
list(all_documents.values()), OpenAIEmbeddings(), index_name="rag-fusion"
)

View File

@ -1,5 +1,4 @@
from rag_fusion.chain import chain

if __name__ == "__main__":
    original_query = "impact of climate change"

View File

@ -1,11 +1,11 @@
import pinecone
from langchain import hub
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.load import dumps, loads
from langchain.schema.output_parser import StrOutputParser
from langchain.vectorstores import Pinecone


def reciprocal_rank_fusion(results: list[list], k=60):
    fused_scores = {}
@ -15,19 +15,29 @@ def reciprocal_rank_fusion(results: list[list], k=60):
            doc_str = dumps(doc)
            if doc_str not in fused_scores:
                fused_scores[doc_str] = 0
            fused_scores[doc_str] += 1 / (rank + k)

    reranked_results = [
        (loads(doc), score)
        for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
    ]
    return reranked_results


pinecone.init(api_key="...", environment="...")
prompt = hub.pull("langchain-ai/rag-fusion-query-generation")

generate_queries = (
    prompt | ChatOpenAI(temperature=0) | StrOutputParser() | (lambda x: x.split("\n"))
)

vectorstore = Pinecone.from_existing_index("rag-fusion", OpenAIEmbeddings())
retriever = vectorstore.as_retriever()

chain = (
    {"original_query": lambda x: x}
    | generate_queries
    | retriever.map()
    | reciprocal_rank_fusion
)
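
For intuition, reciprocal rank fusion scores each document as the sum of 1 / (rank + k) over every result list it appears in, so with the default k=60 a document ranked first in two lists scores 2/60 ≈ 0.033 and beats one ranked first in a single list (1/60 ≈ 0.017). A toy check of that arithmetic, with plain strings standing in for serialized documents:

def rrf_scores(results, k=60):
    # Same scoring loop as above, minus the dumps/loads serialization.
    fused = {}
    for docs in results:
        for rank, doc in enumerate(docs):
            fused[doc] = fused.get(doc, 0) + 1 / (rank + k)
    return sorted(fused.items(), key=lambda x: x[1], reverse=True)

print(rrf_scores([["a", "b"], ["a", "c"]]))  # "a" wins with 2/60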

View File

@ -1,13 +1,12 @@
import os
from operator import itemgetter

import pinecone
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Pinecone

if os.environ.get("PINECONE_API_KEY", None) is None:
    raise Exception("Missing `PINECONE_API_KEY` environment variable.")

View File

@ -1,14 +1,13 @@
import os
from operator import itemgetter

import pinecone
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Pinecone

if os.environ.get("PINECONE_API_KEY", None) is None:
    raise Exception("Missing `PINECONE_API_KEY` environment variable.")
@ -38,7 +37,7 @@ PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX", "langchain-test")
vectorstore = Pinecone.from_existing_index(PINECONE_INDEX_NAME, OpenAIEmbeddings())

# Get k=10 docs
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

# Re-rank
compressor = CohereRerank()
@ -56,7 +55,9 @@ prompt = ChatPromptTemplate.from_template(template)
# RAG
model = ChatOpenAI()
chain = (
    RunnableParallel(
        {"context": compression_retriever, "question": RunnablePassthrough()}
    )
    | prompt
    | model
    | StrOutputParser()

View File

@ -1,12 +1,11 @@
import os
from operator import itemgetter

import pinecone
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Pinecone

if os.environ.get("PINECONE_API_KEY", None) is None:
    raise Exception("Missing `PINECONE_API_KEY` environment variable.")

View File

@ -1,33 +1,36 @@
# Load
import uuid

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema.document import Document
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.storage import InMemoryStore
from langchain.vectorstores import Chroma
from unstructured.partition.pdf import partition_pdf

# Path to docs
path = "docs"
raw_pdf_elements = partition_pdf(
    filename=path + "LLaMA2.pdf",
    # Unstructured first finds embedded image blocks
    extract_images_in_pdf=False,
    # Use layout model (YOLOX) to get bounding boxes (for tables) and find titles
    # Titles are any sub-section of the document
    infer_table_structure=True,
    # Post processing to aggregate text once we have the title
    chunking_strategy="by_title",
    # Chunking params to aggregate text blocks
    # Attempt to create a new chunk 3800 chars
    # Attempt to keep chunks > 2000 chars
    max_characters=4000,
    new_after_n_chars=3800,
    combine_text_under_n_chars=2000,
    image_output_dir_path=path,
)

# Categorize by type
tables = []
@ -40,26 +43,23 @@ for element in raw_pdf_elements:
# Summarize
prompt_text="""You are an assistant tasked with summarizing tables and text. \
prompt_text = """You are an assistant tasked with summarizing tables and text. \
Give a concise summary of the table or text. Table or text chunk: {element} """
prompt = ChatPromptTemplate.from_template(prompt_text)
model = ChatOpenAI(temperature=0,model="gpt-4")
summarize_chain = {"element": lambda x:x} | prompt | model | StrOutputParser()
prompt = ChatPromptTemplate.from_template(prompt_text)
model = ChatOpenAI(temperature=0, model="gpt-4")
summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
# Apply
table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})
# To save time / cost, only do text summaries if chunk sizes are large
# text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})
# We can just assign text_summaries to the raw texts
# We can just assign text_summaries to the raw texts
text_summaries = texts
# Use multi vector retriever
# The vectorstore to use to index the child chunks
vectorstore = Chroma(
collection_name="summaries",
embedding_function=OpenAIEmbeddings()
)
vectorstore = Chroma(collection_name="summaries", embedding_function=OpenAIEmbeddings())
# The storage layer for the parent documents
store = InMemoryStore()
@ -67,20 +67,26 @@ id_key = "doc_id"
# The retriever (empty to start)
retriever = MultiVectorRetriever(
vectorstore=vectorstore,
docstore=store,
vectorstore=vectorstore,
docstore=store,
id_key=id_key,
)
# Add texts
doc_ids = [str(uuid.uuid4()) for _ in texts]
summary_texts = [Document(page_content=s,metadata={id_key: doc_ids[i]}) for i, s in enumerate(text_summaries)]
summary_texts = [
Document(page_content=s, metadata={id_key: doc_ids[i]})
for i, s in enumerate(text_summaries)
]
retriever.vectorstore.add_documents(summary_texts)
retriever.docstore.mset(list(zip(doc_ids, texts)))
# Add tables
table_ids = [str(uuid.uuid4()) for _ in tables]
summary_tables = [Document(page_content=s,metadata={id_key: table_ids[i]}) for i, s in enumerate(table_summaries)]
summary_tables = [
Document(page_content=s, metadata={id_key: table_ids[i]})
for i, s in enumerate(table_summaries)
]
retriever.vectorstore.add_documents(summary_tables)
retriever.docstore.mset(list(zip(table_ids, tables)))
@ -90,16 +96,16 @@ retriever.docstore.mset(list(zip(table_ids, tables)))
template = """Answer the question based only on the following context, which can include text and tables:
{context}
Question: {question}
"""
""" # noqa: E501
prompt = ChatPromptTemplate.from_template(template)
# LLM
model = ChatOpenAI(temperature=0,model="gpt-4")
model = ChatOpenAI(temperature=0, model="gpt-4")
# RAG pipeline
chain = (
{"context": retriever, "question": RunnablePassthrough()}
| prompt
| model
{"context": retriever, "question": RunnablePassthrough()}
| prompt
| model
| StrOutputParser()
)
)
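A note on the loader above: with path = "docs", the concatenation path + "LLaMA2.pdf" only resolves if path carries a trailing separator, so the PDF is expected at something like docs/LLaMA2.pdf. Once indexed, the retriever returns the raw text or table chunk whose summary matched, and the chain answers over those originals. A usage sketch (import path and question are assumptions):

from rag_semi_structured.chain import chain  # hypothetical import path

# Tables from the paper are retrieved alongside text chunks.
print(chain.invoke("How many tokens was LLaMA 2 trained on?"))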

View File

@ -1,13 +1,12 @@
import os
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
from supabase.client import create_client
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores.supabase import SupabaseVectorStore
from supabase.client import create_client
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
@ -19,7 +18,7 @@ vectorstore = SupabaseVectorStore(
client=supabase,
embedding=embeddings,
table_name="documents",
query_name="match_documents"
query_name="match_documents",
)
retriever = vectorstore.as_retriever()
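A quick way to smoke-test the store before composing it into a chain is to query the retriever directly (the query text is a placeholder):

# Hits the match_documents function in Supabase and returns Documents.
docs = retriever.get_relevant_documents("example query")
for doc in docs:
    print(doc.page_content[:80])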

View File

@ -1,5 +1,4 @@
from rewrite_retrieve_read.chain import chain
if __name__ == "__main__":
chain.invoke("man that sam bankman fried trial was crazy! what is langchain?")

View File

@ -1,9 +1,8 @@
from operator import itemgetter
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.schema.runnable import RunnablePassthrough
from langchain.utilities import DuckDuckGoSearchAPIWrapper
template = """Answer the users question based only on the following context:

View File

@ -1,15 +1,12 @@
import os
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
from langchain.chains.query_constructor.base import AttributeInfo
from supabase.client import create_client
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores.supabase import SupabaseVectorStore
from supabase.client import create_client
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
@ -21,7 +18,7 @@ vectorstore = SupabaseVectorStore(
client=supabase,
embedding=embeddings,
table_name="documents",
query_name="match_documents"
query_name="match_documents",
)
# Adjust this based on the metadata you store in the `metadata` JSON column
@ -51,14 +48,7 @@ document_content_description = "Brief summary of a movie"
llm = OpenAI(temperature=0)
retriever = SelfQueryRetriever.from_llm(
llm,
vectorstore,
document_content_description,
metadata_field_info,
verbose=True
llm, vectorstore, document_content_description, metadata_field_info, verbose=True
)
chain = (
RunnableParallel({"query": RunnablePassthrough()})
| retriever
)
chain = RunnableParallel({"query": RunnablePassthrough()}) | retriever
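Since this is a SelfQueryRetriever, it first asks the LLM to translate the question into a structured query over the metadata fields before hitting Supabase; verbose=True prints that intermediate query. A sketch of exercising the retriever directly (the question is a placeholder chosen to trigger a metadata filter):

docs = retriever.get_relevant_documents(
    "highly rated science fiction movies from the 1990s"
)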

View File

@ -1,23 +1,26 @@
from pathlib import Path
from langchain.llms import Replicate
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.prompts import ChatPromptTemplate
from langchain.utilities import SQLDatabase
# make sure to set REPLICATE_API_TOKEN in your environment
# use llama-2-13b model in replicate
replicate_id = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"
replicate_id = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d" # noqa: E501
llm = Replicate(
model=replicate_id,
model_kwargs={"temperature": 0.01, "max_length": 500, "top_p": 1},
)
from pathlib import Path
from langchain.utilities import SQLDatabase
db_path = Path(__file__).parent / "nba_roster.db"
rel = db_path.relative_to(Path.cwd())
db_string = f"sqlite:///{rel}"
db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=0)
def get_schema(_):
return db.get_table_info()
@ -30,7 +33,7 @@ template_query = """Based on the table schema below, write a SQL query that woul
{schema}
Question: {question}
SQL Query:"""
SQL Query:""" # noqa: E501
prompt = ChatPromptTemplate.from_messages(
[
("system", "Given an input question, convert it to a SQL query. No pre-amble."),
@ -50,13 +53,14 @@ template_response = """Based on the table schema below, question, sql query, and
Question: {question}
SQL Query: {query}
SQL Response: {response}"""
SQL Response: {response}""" # noqa: E501
prompt_response = ChatPromptTemplate.from_messages(
[
(
"system",
"Given an input question and SQL response, convert it to a natural language answer. No pre-amble.",
"Given an input question and SQL response, convert it to a natural "
"language answer. No pre-amble.",
),
("human", template_response),
]

View File

@ -1,11 +1,15 @@
from langchain.llms import LlamaCpp
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
# Get LLM
import os
from pathlib import Path
import requests
from langchain.llms import LlamaCpp
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.utilities import SQLDatabase
# File name and URL
file_name = "mistral-7b-instruct-v0.1.Q4_K_M.gguf"
url = "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf"
@ -15,7 +19,7 @@ if not os.path.exists(file_name):
# Download the file
response = requests.get(url)
response.raise_for_status() # Raise an exception for HTTP errors
with open(file_name, 'wb') as f:
with open(file_name, "wb") as f:
f.write(response.content)
print(f"'{file_name}' has been downloaded.")
else:
@ -24,23 +28,27 @@ else:
# Add the LLM downloaded from HF
model_path = file_name
n_gpu_layers = 1 # Metal set to 1 is enough.
n_batch = 512 # Should be between 1 and n_ctx, consider the amount of RAM of your Apple Silicon Chip.
# Should be between 1 and n_ctx, consider the amount of RAM of your Apple Silicon Chip.
n_batch = 512
llm = LlamaCpp(
model_path=model_path,
n_gpu_layers=n_gpu_layers,
n_batch=n_batch,
n_ctx=2048,
f16_kv=True, # MUST be set to True, otherwise you will run into problems after a couple of calls
# f16_kv MUST be set to True,
# otherwise you will run into problems after a couple of calls
f16_kv=True,
verbose=True,
)
from pathlib import Path
from langchain.utilities import SQLDatabase
db_path = Path(__file__).parent / "nba_roster.db"
rel = db_path.relative_to(Path.cwd())
db_string = f"sqlite:///{rel}"
db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=0)
def get_schema(_):
return db.get_table_info()
@ -48,39 +56,43 @@ def get_schema(_):
def run_query(query):
return db.run(query)
# Prompt
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}
Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_messages([
("system", "Given an input question, convert it to a SQL query. No pre-amble."),
MessagesPlaceholder(variable_name="history"),
("human", template)
])
SQL Query:""" # noqa: E501
prompt = ChatPromptTemplate.from_messages(
[
("system", "Given an input question, convert it to a SQL query. No pre-amble."),
MessagesPlaceholder(variable_name="history"),
("human", template),
]
)
memory = ConversationBufferMemory(return_messages=True)
# Chain to query with memory
from langchain.schema.runnable import RunnableLambda
# Chain to query with memory
sql_chain = (
RunnablePassthrough.assign(
schema=get_schema,
history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"])
)| prompt
schema=get_schema,
history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"]),
)
| prompt
| llm.bind(stop=["\nSQLResult:"])
| StrOutputParser()
)
def save(input_output):
output = {"output": input_output.pop("output")}
memory.save_context(input_output, output)
return output['output']
return output["output"]
sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save
# Chain to answer
@ -89,18 +101,24 @@ template = """Based on the table schema below, question, sql query, and sql resp
Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_messages([
("system", "Given an input question and SQL response, convert it to a natural language answer. No pre-amble."),
("human", template)
])
SQL Response: {response}""" # noqa: E501
prompt_response = ChatPromptTemplate.from_messages(
[
(
"system",
"Given an input question and SQL response, convert it to a natural "
"language answer. No pre-amble.",
),
("human", template),
]
)
chain = (
RunnablePassthrough.assign(query=sql_response_memory)
RunnablePassthrough.assign(query=sql_response_memory)
| RunnablePassthrough.assign(
schema=get_schema,
response=lambda x: db.run(x["query"]),
)
| prompt_response
| prompt_response
| llm
)
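Because memory is module-level and save runs on every call, follow-up questions can lean on earlier turns. A usage sketch (the questions are placeholders):

# The first question seeds the conversation memory.
chain.invoke({"question": "How many players are on the roster?"})
# The follow-up is answered with the prior turn loaded back into the SQL prompt.
chain.invoke({"question": "Which of them is the tallest?"})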

View File

@ -1,58 +1,67 @@
from pathlib import Path
from langchain.chat_models import ChatOllama
from langchain.prompts import ChatPromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.utilities import SQLDatabase
# Add the LLM downloaded from Ollama
ollama_llm = "llama2:13b-chat"
llm = ChatOllama(model=ollama_llm)
from pathlib import Path
from langchain.utilities import SQLDatabase
db_path = Path(__file__).parent / "nba_roster.db"
rel = db_path.relative_to(Path.cwd())
db_string = f"sqlite:///{rel}"
db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=0)
def get_schema(_):
return db.get_table_info()
def run_query(query):
return db.run(query)
# Prompt
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}
Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_messages([
("system", "Given an input question, convert it to a SQL query. No pre-amble."),
MessagesPlaceholder(variable_name="history"),
("human", template)
])
SQL Query:""" # noqa: E501
prompt = ChatPromptTemplate.from_messages(
[
("system", "Given an input question, convert it to a SQL query. No pre-amble."),
MessagesPlaceholder(variable_name="history"),
("human", template),
]
)
memory = ConversationBufferMemory(return_messages=True)
# Chain to query with memory
from langchain.schema.runnable import RunnableLambda
# Chain to query with memory
sql_chain = (
RunnablePassthrough.assign(
schema=get_schema,
history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"])
)| prompt
schema=get_schema,
history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"]),
)
| prompt
| llm.bind(stop=["\nSQLResult:"])
| StrOutputParser()
)
def save(input_output):
output = {"output": input_output.pop("output")}
memory.save_context(input_output, output)
return output['output']
return output["output"]
sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save
# Chain to answer
@ -61,18 +70,24 @@ template = """Based on the table schema below, question, sql query, and sql resp
Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_messages([
("system", "Given an input question and SQL response, convert it to a natural language answer. No pre-amble."),
("human", template)
])
SQL Response: {response}""" # noqa: E501
prompt_response = ChatPromptTemplate.from_messages(
[
(
"system",
"Given an input question and SQL response, convert it to a natural "
"language answer. No pre-amble.",
),
("human", template),
]
)
chain = (
RunnablePassthrough.assign(query=sql_response_memory)
RunnablePassthrough.assign(query=sql_response_memory)
| RunnablePassthrough.assign(
schema=get_schema,
response=lambda x: db.run(x["query"]),
)
| prompt_response
| prompt_response
| llm
)
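When one of these templates is served from an app, it is typically mounted behind LangServe. A minimal wiring sketch (module path and route are assumptions):

from fastapi import FastAPI
from langserve import add_routes

from sql_ollama.chain import chain  # hypothetical import path

app = FastAPI()
# Exposes /sql-ollama/invoke, /sql-ollama/batch and /sql-ollama/stream.
add_routes(app, chain, path="/sql-ollama")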

View File

@ -1,5 +1,4 @@
from stepback_qa_prompting.chain import chain
if __name__ == "__main__":
chain.invoke({"question": "was chatgpt around while trump was president?"})

View File

@ -4,9 +4,9 @@ from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda
from langchain.utilities import DuckDuckGoSearchAPIWrapper
search = DuckDuckGoSearchAPIWrapper(max_results=4)
def retriever(query):
return search.run(query)
@ -15,11 +15,11 @@ def retriever(query):
examples = [
{
"input": "Could the members of The Police perform lawful arrests?",
"output": "what can the members of The Police do?"
"output": "what can the members of The Police do?",
},
{
"input": "Jan Sindels was born in what country?",
"output": "what is Jan Sindels personal history?"
"input": "Jan Sindels was born in what country?",
"output": "what is Jan Sindels personal history?",
},
]
# We now transform these to example messages
@ -34,13 +34,20 @@ few_shot_prompt = FewShotChatMessagePromptTemplate(
examples=examples,
)
prompt = ChatPromptTemplate.from_messages([
("system", """You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer. Here are a few examples:"""),
# Few shot examples
few_shot_prompt,
# New question
("user", "{question}"),
])
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are an expert at world knowledge. Your task is to step back "
"and paraphrase a question to a more generic step-back question, which "
"is easier to answer. Here are a few examples:",
),
# Few shot examples
few_shot_prompt,
# New question
("user", "{question}"),
]
)
question_gen = prompt | ChatOpenAI(temperature=0) | StrOutputParser()
@ -50,16 +57,19 @@ response_prompt_template = """You are an expert of world knowledge. I am going t
{step_back_context}
Original Question: {question}
Answer:"""
Answer:""" # noqa: E501
response_prompt = ChatPromptTemplate.from_template(response_prompt_template)
chain = {
# Retrieve context using the normal question
"normal_context": RunnableLambda(lambda x: x['question']) | retriever,
# Retrieve context using the step-back question
"step_back_context": question_gen | retriever,
# Pass on the question
"question": lambda x: x["question"]
} | response_prompt | ChatOpenAI(temperature=0) | StrOutputParser()
chain = (
{
# Retrieve context using the normal question
"normal_context": RunnableLambda(lambda x: x["question"]) | retriever,
# Retrieve context using the step-back question
"step_back_context": question_gen | retriever,
# Pass on the question
"question": lambda x: x["question"],
}
| response_prompt
| ChatOpenAI(temperature=0)
| StrOutputParser()
)

View File

@ -1,14 +1,19 @@
from langchain.chat_models import ChatAnthropic
from langchain.tools.render import render_text_description
from langchain.agents.format_scratchpad import format_xml
from langchain.agents import AgentExecutor
from langchain.retrievers.you import YouRetriever
from langchain.agents.agent_toolkits.conversational_retrieval.tool import create_retriever_tool
from langchain.pydantic_v1 import BaseModel
from xml_agent.prompts import conversational_prompt, parse_output
from langchain.schema import AIMessage, HumanMessage
from typing import List, Tuple
from langchain.agents import AgentExecutor
from langchain.agents.agent_toolkits.conversational_retrieval.tool import (
create_retriever_tool,
)
from langchain.agents.format_scratchpad import format_xml
from langchain.chat_models import ChatAnthropic
from langchain.pydantic_v1 import BaseModel
from langchain.retrievers.you import YouRetriever
from langchain.schema import AIMessage, HumanMessage
from langchain.tools.render import render_text_description
from xml_agent.prompts import conversational_prompt, parse_output
def _format_chat_history(chat_history: List[Tuple[str, str]]):
buffer = []
for human, ai in chat_history:
@ -21,7 +26,9 @@ model = ChatAnthropic(model="claude-2")
# Fake Tool
retriever = YouRetriever(k=5)
retriever_tool = create_retriever_tool(retriever, "search", "Use this to search for current events.")
retriever_tool = create_retriever_tool(
retriever, "search", "Use this to search for current events."
)
tools = [retriever_tool]
@ -31,18 +38,25 @@ prompt = conversational_prompt.partial(
)
llm_with_stop = model.bind(stop=["</tool_input>"])
agent = {
"question": lambda x: x["question"],
"agent_scratchpad": lambda x: format_xml(x['intermediate_steps']),
"chat_history": lambda x: _format_chat_history(x["chat_history"]),
} | prompt | llm_with_stop | parse_output
agent = (
{
"question": lambda x: x["question"],
"agent_scratchpad": lambda x: format_xml(x["intermediate_steps"]),
"chat_history": lambda x: _format_chat_history(x["chat_history"]),
}
| prompt
| llm_with_stop
| parse_output
)
class AgentInput(BaseModel):
question: str
chat_history: List[Tuple[str, str]]
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True).with_types(
input_type=AgentInput
)
agent_executor = AgentExecutor(
agent=agent, tools=tools, verbose=True, handle_parsing_errors=True
).with_types(input_type=AgentInput)
agent_executor = agent_executor | (lambda x: x["output"])
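The with_types call above pins the executor's input schema to AgentInput, so callers pass a question plus a (possibly empty) chat history; the trailing lambda unwraps the final answer to a plain string. A usage sketch (the question is a placeholder):

result = agent_executor.invoke(
    {"question": "what is new with langchain?", "chat_history": []}
)
print(result)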

View File

@ -27,14 +27,16 @@ Assistant: <tool>search</tool><tool_input>weather in SF</tool_input>
It is 64 degrees in SF
Begin!"""
Begin!""" # noqa: E501
conversational_prompt = ChatPromptTemplate.from_messages([
("system", template),
MessagesPlaceholder(variable_name="chat_history"),
("user", "{question}"),
("ai", "{agent_scratchpad}")
])
conversational_prompt = ChatPromptTemplate.from_messages(
[
("system", template),
MessagesPlaceholder(variable_name="chat_history"),
("user", "{question}"),
("ai", "{agent_scratchpad}"),
]
)
def parse_output(message):
@ -47,4 +49,4 @@ def parse_output(message):
_tool_input = _tool_input.split("</tool_input>")[0]
return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
else:
return AgentFinish(return_values={"output": text}, log=text)
return AgentFinish(return_values={"output": text}, log=text)
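For reference, a sketch of what parse_output receives and returns, assuming (as in the agent above) the message content holds the XML and the LLM stops at </tool_input>, so the closing tag may be missing, which the split above tolerates:

from langchain.schema import AIMessage

msg = AIMessage(content="<tool>search</tool><tool_input>weather in SF")
parse_output(msg)  # -> AgentAction(tool="search", tool_input="weather in SF", ...)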