Format Templates (#12396)

This commit is contained in:
parent 25c98dbba9
commit 4b16601d33
@@ -1,2 +1,8 @@
lint lint_diff:
poetry run ruff .
poetry run poe lint

test:
poetry run poe test

format:
poetry run poe format

@@ -1,7 +1,7 @@
from langchain.schema.runnable import ConfigurableField

from .chain import chain
from .retriever_agent import executor
from .chain import chain

final_chain = chain.configurable_alternatives(
ConfigurableField(id="chain"),

@@ -1,5 +1,5 @@
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatAnthropic
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

from .prompts import answer_prompt

@@ -1,6 +1,7 @@
from langchain.schema.agent import AgentAction, AgentFinish
import re

from langchain.schema.agent import AgentAction, AgentFinish

from .agent_scratchpad import _format_docs

@@ -14,18 +15,23 @@ def extract_between_tags(tag: str, string: str, strip: bool = True) -> str:
# Only return the first one
return ext_list[0]


def parse_output(outputs):
partial_completion = outputs["partial_completion"]
steps = outputs["intermediate_steps"]
search_query = extract_between_tags('search_query', partial_completion + '</search_query>')
search_query = extract_between_tags(
"search_query", partial_completion + "</search_query>"
)
if search_query is None:
docs = []
str_output = ""
for action, observation in steps:
docs.extend(observation)
str_output += action.log
str_output += '</search_query>' + _format_docs(observation)
str_output += "</search_query>" + _format_docs(observation)
str_output += partial_completion
return AgentFinish({"docs": docs, "output": str_output}, log=partial_completion)
else:
return AgentAction(tool="search", tool_input=search_query, log=partial_completion)
return AgentAction(
tool="search", tool_input=search_query, log=partial_completion
)
@@ -2,6 +2,6 @@ retrieval_prompt = """{retriever_description} Before beginning to research the u

After each call to the Search Engine Tool, reflect briefly inside <search_quality></search_quality> tags about whether you now have enough information to answer, or whether more information is needed. If you have all the relevant information, write it in <information></information> tags, WITHOUT actually answering the question. Otherwise, issue a new search.

Here is the user's question: <question>{query}</question> Remind yourself to make short queries in your scratchpad as you plan out your strategy."""
Here is the user's question: <question>{query}</question> Remind yourself to make short queries in your scratchpad as you plan out your strategy.""" # noqa: E501

answer_prompt = "Here is a user query: <query>{query}</query>. Here is some relevant information: <information>{information}</information>. Please answer the question using the relevant information."
answer_prompt = "Here is a user query: <query>{query}</query>. Here is some relevant information: <information>{information}</information>. Please answer the question using the relevant information." # noqa: E501

@@ -3,13 +3,14 @@ from langchain.tools import tool

# This is used to tell the model how to best use the retriever.

retriever_description = """You will be asked a question by a human user. You have access to the following tool to help answer the question. <tool_description> Search Engine Tool * The search engine will exclusively search over Wikipedia for pages similar to your query. It returns for each page its title and full page content. Use this tool if you want to get up-to-date and comprehensive information on a topic to help answer queries. Queries should be as atomic as possible -- they only need to address one part of the user's question. For example, if the user's query is "what is the color of a basketball?", your search query should be "basketball". Here's another example: if the user's question is "Who created the first neural network?", your first query should be "neural network". As you can see, these queries are quite short. Think keywords, not phrases. * At any time, you can make a call to the search engine using the following syntax: <search_query>query_word</search_query>. * You'll then get results back in <search_result> tags.</tool_description>"""
retriever_description = """You will be asked a question by a human user. You have access to the following tool to help answer the question. <tool_description> Search Engine Tool * The search engine will exclusively search over Wikipedia for pages similar to your query. It returns for each page its title and full page content. Use this tool if you want to get up-to-date and comprehensive information on a topic to help answer queries. Queries should be as atomic as possible -- they only need to address one part of the user's question. For example, if the user's query is "what is the color of a basketball?", your search query should be "basketball". Here's another example: if the user's question is "Who created the first neural network?", your first query should be "neural network". As you can see, these queries are quite short. Think keywords, not phrases. * At any time, you can make a call to the search engine using the following syntax: <search_query>query_word</search_query>. * You'll then get results back in <search_result> tags.</tool_description>""" # noqa: E501

retriever = WikipediaRetriever()

# This should be the same as the function name below
RETRIEVER_TOOL_NAME = "search"


@tool
def search(query):
"""Search with the retriever."""

@@ -1,13 +1,13 @@
from langchain.agents import AgentExecutor
from langchain.chat_models import ChatAnthropic
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough, RunnableMap
from langchain.schema.output_parser import StrOutputParser
from langchain.agents import AgentExecutor
from langchain.schema.runnable import RunnableMap, RunnablePassthrough

from .retriever import search, RETRIEVER_TOOL_NAME, retriever_description
from .prompts import retrieval_prompt
from .agent_scratchpad import format_agent_scratchpad
from .output_parser import parse_output
from .prompts import retrieval_prompt
from .retriever import retriever_description, search

prompt = ChatPromptTemplate.from_messages([
("user", retrieval_prompt),

@@ -1,6 +1,12 @@
from anthropic_iterative_search import final_chain

from anthropic_iterative_search import final_chain

if __name__ == "__main__":
query = "Which movie came out first: Oppenheimer, or Are You There God It's Me Margaret?"
print(final_chain.with_config(configurable={"chain": "retrieve"}).invoke({"query": query}))
query = (
"Which movie came out first: Oppenheimer, or "
"Are You There God It's Me Margaret?"
)
print(
final_chain.with_config(configurable={"chain": "retrieve"}).invoke(
{"query": query}
)
)
@@ -1,14 +1,12 @@
import os

import cassio

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Cassandra
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser

from langchain.schema.runnable import RunnablePassthrough
from langchain.vectorstores import Cassandra

use_cassandra = int(os.environ.get("USE_CASSANDRA_CLUSTER", "0"))
if use_cassandra:

@@ -1,13 +1,13 @@
import os

from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster


def get_cassandra_connection():
contact_points = [
cp.strip()
for cp in os.environ.get("CASSANDRA_CONTACT_POINTS", "").split(',')
for cp in os.environ.get("CASSANDRA_CONTACT_POINTS", "").split(",")
if cp.strip()
]
CASSANDRA_KEYSPACE = os.environ["CASSANDRA_KEYSPACE"]

@@ -22,6 +22,8 @@ def get_cassandra_connection():
else:
auth_provider = None

c_cluster = Cluster(contact_points if contact_points else None, auth_provider=auth_provider)
c_cluster = Cluster(
contact_points if contact_points else None, auth_provider=auth_provider
)
session = c_cluster.connect()
return (session, CASSANDRA_KEYSPACE)

@@ -1,14 +1,13 @@
import os

import cassio

from langchain.vectorstores import Cassandra
from langchain.embeddings import OpenAIEmbeddings

from langchain.vectorstores import Cassandra

use_cassandra = int(os.environ.get("USE_CASSANDRA_CLUSTER", "0"))
if use_cassandra:
from cassandra_entomology_rag.cassandra_cluster_init import get_cassandra_connection

session, keyspace = get_cassandra_connection()
cassio.init(
session=session,

@@ -22,7 +21,7 @@ else:
)


if __name__ == '__main__':
if __name__ == "__main__":
embeddings = OpenAIEmbeddings()
vector_store = Cassandra(
session=None,

@@ -32,16 +31,13 @@ if __name__ == '__main__':
)
#
lines = [
l.strip()
for l in open("sources.txt").readlines()
if l.strip()
if l[0] != "#"
line.strip()
for line in open("sources.txt").readlines()
if line.strip()
if line[0] != "#"
]
# deterministic IDs to prevent duplicates on multiple runs
ids = [
"_".join(l.split(" ")[:2]).lower().replace(":", "")
for l in lines
]
ids = ["_".join(line.split(" ")[:2]).lower().replace(":", "") for line in lines]
#
vector_store.add_texts(texts=lines, ids=ids)
print(f"Done ({len(lines)} lines inserted).")
@@ -1,13 +1,12 @@
import os

import cassio

import langchain
from langchain.schema import BaseMessage
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema.runnable import RunnableLambda
from langchain.cache import CassandraCache
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import BaseMessage
from langchain.schema.runnable import RunnableLambda

use_cassandra = int(os.environ.get("USE_CASSANDRA_CLUSTER", "0"))
if use_cassandra:

@@ -1,13 +1,13 @@
import os

from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster


def get_cassandra_connection():
contact_points = [
cp.strip()
for cp in os.environ.get("CASSANDRA_CONTACT_POINTS", "").split(',')
for cp in os.environ.get("CASSANDRA_CONTACT_POINTS", "").split(",")
if cp.strip()
]
CASSANDRA_KEYSPACE = os.environ["CASSANDRA_KEYSPACE"]

@@ -22,6 +22,8 @@ def get_cassandra_connection():
else:
auth_provider = None

c_cluster = Cluster(contact_points if contact_points else None, auth_provider=auth_provider)
c_cluster = Cluster(
contact_points if contact_points else None, auth_provider=auth_provider
)
session = c_cluster.connect()
return (session, CASSANDRA_KEYSPACE)
@@ -1,24 +1,25 @@
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_experimental.tools import PythonAstREPLTool
import pandas as pd
from langchain.chat_models import ChatOpenAI
from langsmith import Client
from langchain.smith import RunEvalConfig, run_on_dataset
from pydantic import BaseModel, Field
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.tools.retriever import create_retriever_tool
from pathlib import Path

import pandas as pd
from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools.retriever import create_retriever_tool
from langchain.vectorstores import FAISS
from langchain_experimental.tools import PythonAstREPLTool
from pydantic import BaseModel, Field

MAIN_DIR = Path(__file__).parents[1]

pd.set_option('display.max_rows', 20)
pd.set_option('display.max_columns', 20)
pd.set_option("display.max_rows", 20)
pd.set_option("display.max_columns", 20)

embedding_model = OpenAIEmbeddings()
vectorstore = FAISS.load_local(MAIN_DIR / "titanic_data", embedding_model)
retriever_tool = create_retriever_tool(vectorstore.as_retriever(), "person_name_search", "Search for a person by name")
retriever_tool = create_retriever_tool(
vectorstore.as_retriever(), "person_name_search", "Search for a person by name"
)


TEMPLATE = """You are working with a pandas dataframe in Python. The name of the dataframe is `df`.

@@ -41,8 +42,7 @@ For example:

<question>Who has id 320</question>
<logic>Use `python_repl` since even though the question is about a person, you don't know their name so you can't include it.</logic>
"""

""" # noqa: E501


class PythonInputs(BaseModel):

@@ -52,15 +52,24 @@ class PythonInputs(BaseModel):
df = pd.read_csv("titanic.csv")
template = TEMPLATE.format(dhead=df.head().to_markdown())

prompt = ChatPromptTemplate.from_messages([
("system", template),
MessagesPlaceholder(variable_name="agent_scratchpad"),
("human", "{input}")
])
prompt = ChatPromptTemplate.from_messages(
[
("system", template),
MessagesPlaceholder(variable_name="agent_scratchpad"),
("human", "{input}"),
]
)

repl = PythonAstREPLTool(locals={"df": df}, name="python_repl",
description="Runs code and returns the output of the final line",
args_schema=PythonInputs)
repl = PythonAstREPLTool(
locals={"df": df},
name="python_repl",
description="Runs code and returns the output of the final line",
args_schema=PythonInputs,
)
tools = [repl, retriever_tool]
agent = OpenAIFunctionsAgent(llm=ChatOpenAI(temperature=0, model="gpt-4"), prompt=prompt, tools=tools)
agent_executor = AgentExecutor(agent=agent, tools=tools, max_iterations=5, early_stopping_method="generate")
agent = OpenAIFunctionsAgent(
llm=ChatOpenAI(temperature=0, model="gpt-4"), prompt=prompt, tools=tools
)
agent_executor = AgentExecutor(
agent=agent, tools=tools, max_iterations=5, early_stopping_method="generate"
)
@@ -1,5 +1,4 @@
from langchain.document_loaders import CSVLoader
from langchain.tools.retriever import create_retriever_tool
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import FAISS

@@ -1,11 +1,12 @@
import os
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.json import SimpleJsonOutputParser
from elasticsearch import Elasticsearch
from pathlib import Path

from .prompts import DSL_PROMPT
from elasticsearch import Elasticsearch
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.json import SimpleJsonOutputParser

from .elastic_index_info import get_indices_infos
from .prompts import DSL_PROMPT

es_host = os.environ["ELASTIC_SEARCH_SERVER"]
es_password = os.environ["ELASTIC_PASSWORD"]

@@ -1,5 +1,6 @@
from typing import List


def _list_indices(database, include_indices=None, ignore_indices=None) -> List[str]:
all_indices = [
index["index"] for index in database.cat.indices(format="json")

@@ -16,6 +16,6 @@ Use the following format:

Question: Question here
ESQuery: Elasticsearch Query formatted as json
"""
""" # noqa: E501

DSL_PROMPT = PromptTemplate.from_template(DEFAULT_DSL_TEMPLATE + PROMPT_SUFFIX)

@@ -1,4 +1,5 @@
import os

from elasticsearch import Elasticsearch

es_host = os.environ["ELASTIC_SEARCH_SERVER"]

@@ -1,5 +1,4 @@
from elastic_query_generator.chain import chain


if __name__ == "__main__":
print(chain.invoke({"input": "how many customers named Carol"}))
@@ -1,40 +1,46 @@
from langchain.pydantic_v1 import BaseModel
import json
from typing import List, Optional

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser
import json


template = """A article will be passed to you. Extract from it all papers that are mentioned by this article.

Do not extract the name of the article itself. If no papers are mentioned that's fine - you don't need to extract any! Just return an empty list.

Do not make up or guess ANY extra information. Only extract what exactly is in the text."""
Do not make up or guess ANY extra information. Only extract what exactly is in the text.""" # noqa: E501

prompt = ChatPromptTemplate.from_messages([("system", template), ("human", "{input}")])

prompt = ChatPromptTemplate.from_messages([
("system", template),
("human", "{input}")
])

# Function output schema
class Paper(BaseModel):
"""Information about papers mentioned."""

title: str
author: Optional[str]


class Info(BaseModel):
"""Information to extract"""

papers: List[Paper]


# Function definition
model = ChatOpenAI()
function = [convert_pydantic_to_openai_function(Info)]
chain = prompt | model.bind(
functions=function, function_call={"name": "Info"}
) | (lambda x: json.loads(x.additional_kwargs['function_call']['arguments'])['papers'])
chain = (
prompt
| model.bind(functions=function, function_call={"name": "Info"})
| (
lambda x: json.loads(x.additional_kwargs["function_call"]["arguments"])[
"papers"
]
)
)

# chain = prompt | model.bind(
# functions=function, function_call={"name": "Info"}
@@ -1,14 +1,15 @@
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Chroma

from hyde.prompts import hyde_prompt

# Example for document loading (from url), splitting, and creating vectostore

'''
"""
# Load
from langchain.document_loaders import WebBaseLoader
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")

@@ -25,13 +26,13 @@ vectorstore = Chroma.from_documents(documents=all_splits,
embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
'''
"""

# Embed a single document as a test
vectorstore = Chroma.from_texts(
["harrison worked at kensho"],
collection_name="rag-chroma",
embedding=OpenAIEmbeddings()
embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()

@@ -48,11 +49,18 @@ model = ChatOpenAI()

# RAG chain
chain = (
RunnableParallel({
# Configure the input, pass it the prompt, pass that to the model, and then the result to the retriever
"context": {"input": RunnablePassthrough()} | hyde_prompt | model | StrOutputParser() | retriever,
"question": RunnablePassthrough()
})
RunnableParallel(
{
# Configure the input, pass it the prompt, pass that to the model,
# and then the result to the retriever
"context": {"input": RunnablePassthrough()}
| hyde_prompt
| model
| StrOutputParser()
| retriever,
"question": RunnablePassthrough(),
}
)
| prompt
| model
| StrOutputParser()

@@ -7,7 +7,7 @@ Question: {input}
Passage:"""
sci_fact_template = """Please write a scientific paper passage to support/refute the claim
Claim: {input}
Passage:"""
Passage:""" # noqa: E501
fiqa_template = """Please write a financial article passage to answer the question
Question: {input}
Passage:"""
@@ -1,13 +1,13 @@

from typing import List, Optional

from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain.chains.openai_functions import create_structured_output_chain
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain.chains.openai_functions import create_structured_output_chain

try:
from pydantic.v1.main import BaseModel, Field
except ImportError:

@@ -27,15 +27,18 @@ cypher_validation = CypherQueryCorrector(corrector_schema)
cypher_llm = ChatOpenAI(model_name="gpt-4", temperature=0.0)
qa_llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.0)


# Extract entities from text
class Entities(BaseModel):
"""Identifying information about entities."""

names: List[str] = Field(
...,
description="All the person, organization, or business entities that appear in the text",
description="All the person, organization, or business entities that "
"appear in the text",
)


prompt = ChatPromptTemplate.from_messages(
[
(

@@ -44,11 +47,13 @@ prompt = ChatPromptTemplate.from_messages(
),
(
"human",
"Use the given format to extract information from the following input: {question}",
"Use the given format to extract information from the following "
"input: {question}",
),
]
)


# Fulltext index query
def map_to_database(entities: Entities) -> Optional[str]:
result = ""

@@ -56,16 +61,16 @@ def map_to_database(entities: Entities) -> Optional[str]:
response = graph.query(
"CALL db.index.fulltext.queryNodes('entity', $entity + '*', {limit:1})"
" YIELD node,score RETURN node.name AS result",
{"entity":entity})
{"entity": entity},
)
try:
result += f"{entity} maps to {response[0]['result']} in database\n"
except IndexError:
pass
return result

entity_chain = create_structured_output_chain(
Entities, qa_llm, prompt
)

entity_chain = create_structured_output_chain(Entities, qa_llm, prompt)

# Generate Cypher statement based on natural language input
cypher_template = """Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question:

@@ -73,7 +78,7 @@ cypher_template = """Based on the Neo4j graph schema below, write a Cypher query
Entities in the question map to the following database values:
{entities_list}
Question: {question}
Cypher query:"""
Cypher query:""" # noqa: E501

cypher_prompt = ChatPromptTemplate.from_messages(
[

@@ -88,7 +93,7 @@ cypher_prompt = ChatPromptTemplate.from_messages(
cypher_response = (
RunnablePassthrough.assign(names=entity_chain)
| RunnablePassthrough.assign(
entities_list=lambda x: map_to_database(x['names']['function']),
entities_list=lambda x: map_to_database(x["names"]["function"]),
schema=lambda _: graph.get_schema,
)
| cypher_prompt

@@ -100,13 +105,14 @@ cypher_response = (
response_template = """Based on the the question, Cypher query, and Cypher response, write a natural language response:
Question: {question}
Cypher query: {query}
Cypher Response: {response}"""
Cypher Response: {response}""" # noqa: E501

response_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"Given an input question and Cypher response, convert it to a natural language answer. No pre-amble.",
"Given an input question and Cypher response, convert it to a natural"
" language answer. No pre-amble.",
),
("human", response_template),
]
@@ -1,9 +1,9 @@
from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema

# Connection to Neo4j
graph = Neo4jGraph()

@@ -24,7 +24,7 @@ cypher_template = """Based on the Neo4j graph schema below, write a Cypher query
{schema}

Question: {question}
Cypher query:"""
Cypher query:""" # noqa: E501

cypher_prompt = ChatPromptTemplate.from_messages(
[

@@ -49,13 +49,14 @@ cypher_response = (
response_template = """Based on the the question, Cypher query, and Cypher response, write a natural language response:
Question: {question}
Cypher query: {query}
Cypher Response: {response}"""
Cypher Response: {response}""" # noqa: E501

response_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"Given an input question and Cypher response, convert it to a natural language answer. No pre-amble.",
"Given an input question and Cypher response, convert it to a "
"natural language answer. No pre-amble.",
),
("human", response_template),
]

@@ -1,6 +1,5 @@
from neo4j_generation.chain import chain


if __name__ == "__main__":
text = "Harrison works at LangChain, which is located in San Francisco"
allowed_nodes = ["Person", "Organization", "Location"]
@@ -1,11 +1,12 @@
from typing import Optional, List
from typing import List, Optional

from langchain.chains.openai_functions import (
create_structured_output_chain,
)
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.graphs import Neo4jGraph
from langchain.graphs.graph_document import GraphDocument
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document

from neo4j_generation.utils import (

@@ -35,7 +36,7 @@ def get_extraction_chain(
If not provided, there won't be any specific restriction on node labels.
- allowed_rels (Optional[List[str]]): A list of relationship types that are allowed in the knowledge graph.
If not provided, there won't be any specific restriction on relationship types.
"""
""" # noqa: E501
prompt = ChatPromptTemplate.from_messages(
[
(

@@ -64,11 +65,12 @@ always use the most complete identifier for that entity throughout the knowledge
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
## 5. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination.
""",
""", # noqa: E501
),
(
"human",
"Use the given format to extract information from the following input: {input}",
"Use the given format to extract information from the "
"following input: {input}",
),
("human", "Tip: Make sure to answer in the correct format"),
]

@@ -94,7 +96,7 @@ def chain(

Returns:
str: A confirmation message indicating the completion of the graph construction.
"""
""" # noqa: E501
# Extract graph data using OpenAI functions
extract_chain = get_extraction_chain(allowed_nodes, allowed_relationships)
data = extract_chain.run(text)

@@ -1,9 +1,12 @@
from typing import List, Optional

from langchain.graphs.graph_document import (
Node as BaseNode,
)
from langchain.graphs.graph_document import (
Relationship as BaseRelationship,
)
from typing import List, Optional
from langchain.pydantic_v1 import Field, BaseModel
from langchain.pydantic_v1 import BaseModel, Field


class Property(BaseModel):
@@ -1,10 +1,11 @@
from langchain.graphs import Neo4jGraph
from langchain.vectorstores import Neo4jVector
from langchain.document_loaders import TextLoader
from langchain.text_splitter import TokenTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from pathlib import Path

from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.graphs import Neo4jGraph
from langchain.text_splitter import TokenTextSplitter
from langchain.vectorstores import Neo4jVector

txt_path = Path(__file__).parent / "dune.txt"

graph = Neo4jGraph()

@@ -1,5 +1,4 @@
from neo4j_parent.chain import chain

from neo4j_parent.chain import chain

if __name__ == "__main__":
original_query = "What is the plot of the Dune?"

@@ -1,8 +1,8 @@
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Neo4jVector

retrieval_query = """
@@ -1,15 +1,15 @@
from typing import List, Tuple
from langchain.schema.messages import HumanMessage, AIMessage
from langchain.chat_models import ChatOpenAI

from langchain.agents import AgentExecutor
from langchain.utilities.tavily_search import TavilySearchAPIWrapper
from langchain.tools.tavily_search import TavilySearchResults
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools.render import format_tool_to_openai_function
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.pydantic_v1 import BaseModel

from langchain.schema.messages import AIMessage, HumanMessage
from langchain.tools.render import format_tool_to_openai_function
from langchain.tools.tavily_search import TavilySearchResults
from langchain.utilities.tavily_search import TavilySearchAPIWrapper

# Fake Tool
search = TavilySearchAPIWrapper()

@@ -18,17 +18,21 @@ tavily_tool = TavilySearchResults(api_wrapper=search)
tools = [tavily_tool]

llm = ChatOpenAI(temperature=0)
prompt = ChatPromptTemplate.from_messages([
("system", "You are very powerful assistant, but bad at calculating lengths of words."),
MessagesPlaceholder(variable_name="chat_history"),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
])

llm_with_tools = llm.bind(
functions=[format_tool_to_openai_function(t) for t in tools]
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are very powerful assistant, but bad at calculating lengths of words.",
),
MessagesPlaceholder(variable_name="chat_history"),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)

llm_with_tools = llm.bind(functions=[format_tool_to_openai_function(t) for t in tools])


def _format_chat_history(chat_history: List[Tuple[str, str]]):
buffer = []
for human, ai in chat_history:

@@ -37,16 +41,25 @@ def _format_chat_history(chat_history: List[Tuple[str, str]]):
return buffer


agent = {
"input": lambda x: x["input"],
"chat_history": lambda x: _format_chat_history(x['chat_history']),
"agent_scratchpad": lambda x: format_to_openai_functions(x['intermediate_steps']),
} | prompt | llm_with_tools | OpenAIFunctionsAgentOutputParser()
agent = (
{
"input": lambda x: x["input"],
"chat_history": lambda x: _format_chat_history(x["chat_history"]),
"agent_scratchpad": lambda x: format_to_openai_functions(
x["intermediate_steps"]
),
}
| prompt
| llm_with_tools
| OpenAIFunctionsAgentOutputParser()
)


class AgentInput(BaseModel):
input: str
chat_history: List[Tuple[str, str]]


agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True).with_types(
input_type=AgentInput
)
templates/poetry.lock (generated, 191 lines changed)
@@ -1,5 +1,144 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.

[[package]]
name = "colorama"
version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]

[[package]]
name = "docopt"
version = "0.6.2"
description = "Pythonic argument parser, that will make you smile"
optional = false
python-versions = "*"
files = [
{file = "docopt-0.6.2.tar.gz", hash = "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491"},
]

[[package]]
name = "exceptiongroup"
version = "1.1.3"
description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
files = [
{file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"},
{file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"},
]

[package.extras]
test = ["pytest (>=6)"]

[[package]]
name = "iniconfig"
version = "2.0.0"
description = "brain-dead simple config-ini parsing"
optional = false
python-versions = ">=3.7"
files = [
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]

[[package]]
name = "packaging"
version = "23.2"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.7"
files = [
{file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
{file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
]

[[package]]
name = "pastel"
version = "0.2.1"
description = "Bring colors to your terminal."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364"},
{file = "pastel-0.2.1.tar.gz", hash = "sha256:e6581ac04e973cac858828c6202c1e1e81fee1dc7de7683f3e1ffe0bfd8a573d"},
]

[[package]]
name = "pluggy"
version = "1.3.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.8"
files = [
{file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"},
{file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"},
]

[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]

[[package]]
name = "poethepoet"
version = "0.24.1"
description = "A task runner that works well with poetry."
optional = false
python-versions = ">=3.8"
files = [
{file = "poethepoet-0.24.1-py3-none-any.whl", hash = "sha256:3afa44b4fc7327df0dd912eda012604a072af2bb4d243fb0e41e8eca8dabf9ed"},
{file = "poethepoet-0.24.1.tar.gz", hash = "sha256:f5a386387c382f08890c273d13495938208a8ce91ab71536abf388c776c4f366"},
]

[package.dependencies]
pastel = ">=0.2.1,<0.3.0"
tomli = ">=1.2.2"

[package.extras]
poetry-plugin = ["poetry (>=1.0,<2.0)"]

[[package]]
name = "pytest"
version = "7.4.3"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"},
{file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"},
]

[package.dependencies]
colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<2.0"
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}

[package.extras]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]

[[package]]
name = "pytest-watch"
version = "4.2.0"
description = "Local continuous test runner with pytest and watchdog."
optional = false
python-versions = "*"
files = [
{file = "pytest-watch-4.2.0.tar.gz", hash = "sha256:06136f03d5b361718b8d0d234042f7b2f203910d8568f63df2f866b547b3d4b9"},
]

[package.dependencies]
colorama = ">=0.3.3"
docopt = ">=0.4.0"
pytest = ">=2.6.4"
watchdog = ">=0.6.0"

[[package]]
name = "ruff"
version = "0.1.2"
@@ -26,7 +165,57 @@ files = [
{file = "ruff-0.1.2.tar.gz", hash = "sha256:afd4785ae060ce6edcd52436d0c197628a918d6d09e3107a892a1bad6a4c6608"},
]

[[package]]
name = "tomli"
version = "2.0.1"
description = "A lil' TOML parser"
optional = false
python-versions = ">=3.7"
files = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]

[[package]]
name = "watchdog"
version = "3.0.0"
description = "Filesystem events monitoring"
optional = false
python-versions = ">=3.7"
files = [
{file = "watchdog-3.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:336adfc6f5cc4e037d52db31194f7581ff744b67382eb6021c868322e32eef41"},
{file = "watchdog-3.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a70a8dcde91be523c35b2bf96196edc5730edb347e374c7de7cd20c43ed95397"},
{file = "watchdog-3.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:adfdeab2da79ea2f76f87eb42a3ab1966a5313e5a69a0213a3cc06ef692b0e96"},
{file = "watchdog-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2b57a1e730af3156d13b7fdddfc23dea6487fceca29fc75c5a868beed29177ae"},
{file = "watchdog-3.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7ade88d0d778b1b222adebcc0927428f883db07017618a5e684fd03b83342bd9"},
{file = "watchdog-3.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7e447d172af52ad204d19982739aa2346245cc5ba6f579d16dac4bfec226d2e7"},
{file = "watchdog-3.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9fac43a7466eb73e64a9940ac9ed6369baa39b3bf221ae23493a9ec4d0022674"},
{file = "watchdog-3.0.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8ae9cda41fa114e28faf86cb137d751a17ffd0316d1c34ccf2235e8a84365c7f"},
{file = "watchdog-3.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:25f70b4aa53bd743729c7475d7ec41093a580528b100e9a8c5b5efe8899592fc"},
{file = "watchdog-3.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4f94069eb16657d2c6faada4624c39464f65c05606af50bb7902e036e3219be3"},
{file = "watchdog-3.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7c5f84b5194c24dd573fa6472685b2a27cc5a17fe5f7b6fd40345378ca6812e3"},
{file = "watchdog-3.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3aa7f6a12e831ddfe78cdd4f8996af9cf334fd6346531b16cec61c3b3c0d8da0"},
{file = "watchdog-3.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:233b5817932685d39a7896b1090353fc8efc1ef99c9c054e46c8002561252fb8"},
{file = "watchdog-3.0.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:13bbbb462ee42ec3c5723e1205be8ced776f05b100e4737518c67c8325cf6100"},
{file = "watchdog-3.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8f3ceecd20d71067c7fd4c9e832d4e22584318983cabc013dbf3f70ea95de346"},
{file = "watchdog-3.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c9d8c8ec7efb887333cf71e328e39cffbf771d8f8f95d308ea4125bf5f90ba64"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0e06ab8858a76e1219e68c7573dfeba9dd1c0219476c5a44d5333b01d7e1743a"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:d00e6be486affb5781468457b21a6cbe848c33ef43f9ea4a73b4882e5f188a44"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:c07253088265c363d1ddf4b3cdb808d59a0468ecd017770ed716991620b8f77a"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:5113334cf8cf0ac8cd45e1f8309a603291b614191c9add34d33075727a967709"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:51f90f73b4697bac9c9a78394c3acbbd331ccd3655c11be1a15ae6fe289a8c83"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:ba07e92756c97e3aca0912b5cbc4e5ad802f4557212788e72a72a47ff376950d"},
{file = "watchdog-3.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d429c2430c93b7903914e4db9a966c7f2b068dd2ebdd2fa9b9ce094c7d459f33"},
{file = "watchdog-3.0.0-py3-none-win32.whl", hash = "sha256:3ed7c71a9dccfe838c2f0b6314ed0d9b22e77d268c67e015450a29036a81f60f"},
{file = "watchdog-3.0.0-py3-none-win_amd64.whl", hash = "sha256:4c9956d27be0bb08fc5f30d9d0179a855436e655f046d288e2bcc11adfae893c"},
{file = "watchdog-3.0.0-py3-none-win_ia64.whl", hash = "sha256:5d9f3a10e02d7371cd929b5d8f11e87d4bad890212ed3901f9b4d68767bee759"},
{file = "watchdog-3.0.0.tar.gz", hash = "sha256:4d98a320595da7a7c5a18fc48cb633c2e73cda78f93cac2ef42d42bf609a33f9"},
]

[package.extras]
watchmedo = ["PyYAML (>=3.10)"]

[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "7a86271e260f3ac3de6446960b301dc6c854f9d7f9544774d81eaa13be511838"
content-hash = "e00055a76b5e7a5dd6afd6c65073084a9a0f988a8d30e24be2606c048bd25686"
@@ -9,14 +9,19 @@ readme = "README.md"
python = "^3.10"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
# dev, test, lint, typing

[tool.poetry.group.dev.dependencies]
poethepoet = "^0.24.1"
pytest-watch = "^4.2.0"

[tool.poetry.group.test.dependencies]
pytest = "^7.4.3"

[tool.poetry.group.lint.dependencies]
ruff = "^0.1"

[tool.poetry.group.typing.dependencies]

[tool.ruff]
select = [

@@ -24,3 +29,14 @@ select = [
"F", # pyflakes
"I", # isort
]

[tool.poe.tasks]
test = "poetry run pytest"
watch = "poetry run ptw"
lint = "poetry run ruff ."
format = "poetry run ruff . --fix"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
@@ -1,28 +1,31 @@
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOllama
from langchain.prompts import ChatPromptTemplate
from langchain.embeddings import GPT4AllEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel

# Load
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import GPT4AllEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
data = loader.load()

# Split
from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(data)

# Add to vectorDB
vectorstore = Chroma.from_documents(documents=all_splits,
collection_name="rag-private",
embedding=GPT4AllEmbeddings(),
)
vectorstore = Chroma.from_documents(
documents=all_splits,
collection_name="rag-private",
embedding=GPT4AllEmbeddings(),
)
retriever = vectorstore.as_retriever()

# Prompt
# Prompt
# Optionally, pull from the Hub
# from langchain import hub
# prompt = hub.pull("rlm/rag-prompt")

@@ -1,9 +1,8 @@
from operator import itemgetter
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Chroma

# Example for document loading (from url), splitting, and creating vectostore
@@ -1,15 +1,21 @@
import os
from typing import Tuple, List
from pydantic import BaseModel
from operator import itemgetter
from langchain.vectorstores import Pinecone
from typing import List, Tuple

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import format_document, AIMessage, HumanMessage
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, format_document
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableBranch, RunnableLambda, RunnableMap
from langchain.schema.runnable import (
RunnableBranch,
RunnableLambda,
RunnableMap,
RunnablePassthrough,
)
from langchain.vectorstores import Pinecone
from pydantic import BaseModel

if os.environ.get("PINECONE_API_KEY", None) is None:
raise Exception("Missing `PINECONE_API_KEY` environment variable.")

@@ -44,7 +50,7 @@ _template = """Given the following conversation and a follow up question, rephra
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
Standalone question:""" # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

# RAG answer synthesis prompt

@@ -52,18 +58,25 @@ template = """Answer the question based only on the following context:
<context>
{context}
</context>"""
ANSWER_PROMPT = ChatPromptTemplate.from_messages([
("system",template),
MessagesPlaceholder(variable_name="chat_history"),
("user", "{question}")
])
ANSWER_PROMPT = ChatPromptTemplate.from_messages(
[
("system", template),
MessagesPlaceholder(variable_name="chat_history"),
("user", "{question}"),
]
)

# Conversational Retrieval Chain
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
def _combine_documents(docs, document_prompt = DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):


def _combine_documents(
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)


def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
buffer = []
for human, ai in chat_history:

@@ -71,6 +84,7 @@ def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
buffer.append(AIMessage(content=ai))
return buffer


# User input
class ChatHistory(BaseModel):
chat_history: List[Tuple[str, str]]

@@ -78,24 +92,28 @@ class ChatHistory(BaseModel):


_search_query = RunnableBranch(
# If input includes chat_history, we condense it with the follow-up question
(
RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
run_name="HasChatHistoryCheck"
), # Condense follow-up question and chat into a standalone_question
RunnablePassthrough.assign(
chat_history=lambda x: _format_chat_history(x['chat_history'])
) | CONDENSE_QUESTION_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser(),
),
# Else, we have no chat history, so just pass through the question
RunnableLambda(itemgetter("question"))
# If input includes chat_history, we condense it with the follow-up question
(
RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
run_name="HasChatHistoryCheck"
), # Condense follow-up question and chat into a standalone_question
RunnablePassthrough.assign(
chat_history=lambda x: _format_chat_history(x["chat_history"])
)
| CONDENSE_QUESTION_PROMPT
| ChatOpenAI(temperature=0)
| StrOutputParser(),
),
# Else, we have no chat history, so just pass through the question
RunnableLambda(itemgetter("question")),
)

)

_inputs = RunnableMap({
"question": lambda x: x["question"],
"chat_history": lambda x: _format_chat_history(x['chat_history']),
"context": _search_query | retriever | _combine_documents
}).with_types(input_type=ChatHistory)
_inputs = RunnableMap(
{
"question": lambda x: x["question"],
"chat_history": lambda x: _format_chat_history(x["chat_history"]),
"context": _search_query | retriever | _combine_documents,
}
).with_types(input_type=ChatHistory)

chain = _inputs | ANSWER_PROMPT | ChatOpenAI() | StrOutputParser()
@@ -1,8 +1,9 @@
import os

from langchain.document_loaders import JSONLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores.elasticsearch import ElasticsearchStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
import os
from langchain.vectorstores.elasticsearch import ElasticsearchStore

ELASTIC_CLOUD_ID = os.getenv("ELASTIC_CLOUD_ID")
ELASTIC_USERNAME = os.getenv("ELASTIC_USERNAME", "elastic")

@@ -23,7 +23,9 @@ if __name__ == "__main__":
"question": follow_up_question,
"chat_history": [
"What is the nasa sales team?",
"The sales team of NASA consists of Laura Martinez, the Area Vice-President of North America, and Gary Johnson, the Area Vice-President of South America. (Sales Organization Overview)",
"The sales team of NASA consists of Laura Martinez, the Area "
"Vice-President of North America, and Gary Johnson, the Area "
"Vice-President of South America. (Sales Organization Overview)",
],
}
)
@ -1,13 +1,15 @@
from langchain.chat_models import ChatOpenAI
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableMap
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores.elasticsearch import ElasticsearchStore
from langchain.schema import format_document
from typing import Tuple, List
from operator import itemgetter
from .prompts import CONDENSE_QUESTION_PROMPT, LLM_CONTEXT_PROMPT, DOCUMENT_PROMPT
from typing import List, Tuple

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import format_document
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableMap, RunnablePassthrough
from langchain.vectorstores.elasticsearch import ElasticsearchStore

from .connection import es_connection_details
from .prompts import CONDENSE_QUESTION_PROMPT, DOCUMENT_PROMPT, LLM_CONTEXT_PROMPT

# Setup connecting to Elasticsearch
vectorstore = ElasticsearchStore(
@ -6,7 +6,7 @@ condense_question_prompt_template = """Given the following conversation and a fo
Chat History:
{chat_history}
Follow Up Input: {question}
"""
""" # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(
    condense_question_prompt_template
)
@ -23,7 +23,7 @@ If you don't know the answer, just say that you don't know, don't try to make up
{context}
----
Question: {question}
"""
""" # noqa: E501

LLM_CONTEXT_PROMPT = ChatPromptTemplate.from_template(llm_context_prompt_template)

@ -1,8 +1,8 @@
import pinecone
from langchain.vectorstores import Pinecone
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Pinecone

pinecone.init(api_key="...",environment="...")
pinecone.init(api_key="...", environment="...")

all_documents = {
    "doc1": "Climate change and economic impact.",
@ -14,7 +14,9 @@ all_documents = {
    "doc7": "Climate change: The science and models.",
    "doc8": "Global warming: A subset of climate change.",
    "doc9": "How climate change affects daily weather.",
    "doc10": "The history of climate change activism."
    "doc10": "The history of climate change activism.",
}

Pinecone.from_texts(list(all_documents.values()), OpenAIEmbeddings(), index_name='rag-fusion')
Pinecone.from_texts(
    list(all_documents.values()), OpenAIEmbeddings(), index_name="rag-fusion"
)
@ -1,5 +1,4 @@
from rag_fusion.chain import chain

from rag_fusion.chain import chain

if __name__ == "__main__":
    original_query = "impact of climate change"
@ -1,11 +1,11 @@
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain import hub
import pinecone
from langchain.vectorstores import Pinecone
from langchain import hub
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.load import dumps, loads
from langchain.schema.output_parser import StrOutputParser
from langchain.vectorstores import Pinecone


def reciprocal_rank_fusion(results: list[list], k=60):
    fused_scores = {}
@ -15,19 +15,29 @@ def reciprocal_rank_fusion(results: list[list], k=60):
            doc_str = dumps(doc)
            if doc_str not in fused_scores:
                fused_scores[doc_str] = 0
            previous_score = fused_scores[doc_str]
            fused_scores[doc_str] += 1 / (rank + k)

    reranked_results = [(loads(doc), score) for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)]
    return reranked_results

    reranked_results = [
        (loads(doc), score)
        for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
    ]
    return reranked_results


pinecone.init(api_key="...", environment="...")

prompt = hub.pull('langchain-ai/rag-fusion-query-generation')
prompt = hub.pull("langchain-ai/rag-fusion-query-generation")

generate_queries = prompt | ChatOpenAI(temperature=0) | StrOutputParser() | (lambda x: x.split("\n"))
generate_queries = (
    prompt | ChatOpenAI(temperature=0) | StrOutputParser() | (lambda x: x.split("\n"))
)

vectorstore = Pinecone.from_existing_index("rag-fusion", OpenAIEmbeddings())
retriever = vectorstore.as_retriever()

chain = {"original_query": lambda x: x} | generate_queries | retriever.map() | reciprocal_rank_fusion
chain = (
    {"original_query": lambda x: x}
    | generate_queries
    | retriever.map()
    | reciprocal_rank_fusion
)
@ -1,13 +1,12 @@
import os
import pinecone
from operator import itemgetter
from langchain.vectorstores import Pinecone
from langchain.prompts import ChatPromptTemplate
import os

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Pinecone

if os.environ.get("PINECONE_API_KEY", None) is None:
    raise Exception("Missing `PINECONE_API_KEY` environment variable.")
@ -1,14 +1,13 @@
import os
import pinecone
from operator import itemgetter
from langchain.vectorstores import Pinecone
from langchain.prompts import ChatPromptTemplate
import os

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Pinecone

if os.environ.get("PINECONE_API_KEY", None) is None:
    raise Exception("Missing `PINECONE_API_KEY` environment variable.")
@ -38,7 +37,7 @@ PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX", "langchain-test")
vectorstore = Pinecone.from_existing_index(PINECONE_INDEX_NAME, OpenAIEmbeddings())

# Get k=10 docs
retriever = vectorstore.as_retriever(search_kwargs={"k":10})
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

# Re-rank
compressor = CohereRerank()
@ -56,7 +55,9 @@ prompt = ChatPromptTemplate.from_template(template)
# RAG
model = ChatOpenAI()
chain = (
    RunnableParallel({"context": compression_retriever, "question": RunnablePassthrough()})
    RunnableParallel(
        {"context": compression_retriever, "question": RunnablePassthrough()}
    )
    | prompt
    | model
    | StrOutputParser()
@ -1,12 +1,11 @@
import os
import pinecone
from operator import itemgetter
from langchain.vectorstores import Pinecone
from langchain.prompts import ChatPromptTemplate
import os

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Pinecone

if os.environ.get("PINECONE_API_KEY", None) is None:
    raise Exception("Missing `PINECONE_API_KEY` environment variable.")
@ -1,33 +1,36 @@
# Load
import uuid

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore
from unstructured.partition.pdf import partition_pdf
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema.document import Document
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.storage import InMemoryStore
from langchain.vectorstores import Chroma
from unstructured.partition.pdf import partition_pdf

# Path to docs
path = "docs"
raw_pdf_elements = partition_pdf(filename=path+"LLaMA2.pdf",
                                 # Unstructured first finds embedded image blocks
                                 extract_images_in_pdf=False,
                                 # Use layout model (YOLOX) to get bounding boxes (for tables) and find titles
                                 # Titles are any sub-section of the document
                                 infer_table_structure=True,
                                 # Post processing to aggregate text once we have the title
                                 chunking_strategy="by_title",
                                 # Chunking params to aggregate text blocks
                                 # Attempt to create a new chunk 3800 chars
                                 # Attempt to keep chunks > 2000 chars
                                 max_characters=4000,
                                 new_after_n_chars=3800,
                                 combine_text_under_n_chars=2000,
                                 image_output_dir_path=path)
raw_pdf_elements = partition_pdf(
    filename=path + "LLaMA2.pdf",
    # Unstructured first finds embedded image blocks
    extract_images_in_pdf=False,
    # Use layout model (YOLOX) to get bounding boxes (for tables) and find titles
    # Titles are any sub-section of the document
    infer_table_structure=True,
    # Post processing to aggregate text once we have the title
    chunking_strategy="by_title",
    # Chunking params to aggregate text blocks
    # Attempt to create a new chunk 3800 chars
    # Attempt to keep chunks > 2000 chars
    max_characters=4000,
    new_after_n_chars=3800,
    combine_text_under_n_chars=2000,
    image_output_dir_path=path,
)

# Categorize by type
tables = []
@ -40,26 +43,23 @@ for element in raw_pdf_elements:

# Summarize

prompt_text="""You are an assistant tasked with summarizing tables and text. \
prompt_text = """You are an assistant tasked with summarizing tables and text. \
Give a concise summary of the table or text. Table or text chunk: {element} """
prompt = ChatPromptTemplate.from_template(prompt_text)
model = ChatOpenAI(temperature=0,model="gpt-4")
summarize_chain = {"element": lambda x:x} | prompt | model | StrOutputParser()
prompt = ChatPromptTemplate.from_template(prompt_text)
model = ChatOpenAI(temperature=0, model="gpt-4")
summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

# Apply
table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})
# To save time / cost, only do text summaries if chunk sizes are large
# text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})
# We can just assign text_summaries to the raw texts
# We can just assign text_summaries to the raw texts
text_summaries = texts

# Use multi vector retriever

# The vectorstore to use to index the child chunks
vectorstore = Chroma(
    collection_name="summaries",
    embedding_function=OpenAIEmbeddings()
)
vectorstore = Chroma(collection_name="summaries", embedding_function=OpenAIEmbeddings())

# The storage layer for the parent documents
store = InMemoryStore()
@ -67,20 +67,26 @@ id_key = "doc_id"

# The retriever (empty to start)
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)

# Add texts
doc_ids = [str(uuid.uuid4()) for _ in texts]
summary_texts = [Document(page_content=s,metadata={id_key: doc_ids[i]}) for i, s in enumerate(text_summaries)]
summary_texts = [
    Document(page_content=s, metadata={id_key: doc_ids[i]})
    for i, s in enumerate(text_summaries)
]
retriever.vectorstore.add_documents(summary_texts)
retriever.docstore.mset(list(zip(doc_ids, texts)))

# Add tables
table_ids = [str(uuid.uuid4()) for _ in tables]
summary_tables = [Document(page_content=s,metadata={id_key: table_ids[i]}) for i, s in enumerate(table_summaries)]
summary_tables = [
    Document(page_content=s, metadata={id_key: table_ids[i]})
    for i, s in enumerate(table_summaries)
]
retriever.vectorstore.add_documents(summary_tables)
retriever.docstore.mset(list(zip(table_ids, tables)))

@ -90,16 +96,16 @@ retriever.docstore.mset(list(zip(table_ids, tables)))
template = """Answer the question based only on the following context, which can include text and tables:
{context}
Question: {question}
"""
""" # noqa: E501
prompt = ChatPromptTemplate.from_template(template)

# LLM
model = ChatOpenAI(temperature=0,model="gpt-4")
model = ChatOpenAI(temperature=0, model="gpt-4")

# RAG pipeline
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)
)
@ -1,13 +1,12 @@
import os
from langchain.prompts import ChatPromptTemplate

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel

from supabase.client import create_client
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores.supabase import SupabaseVectorStore
from supabase.client import create_client

supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
@ -19,7 +18,7 @@ vectorstore = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents"
    query_name="match_documents",
)

retriever = vectorstore.as_retriever()
@ -1,5 +1,4 @@
from rewrite_retrieve_read.chain import chain


if __name__ == "__main__":
    chain.invoke("man that sam bankman fried trial was crazy! what is langchain?")
@ -1,9 +1,8 @@
from operator import itemgetter

from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.schema.runnable import RunnablePassthrough
from langchain.utilities import DuckDuckGoSearchAPIWrapper

template = """Answer the users question based only on the following context:
@ -1,15 +1,12 @@
import os

from langchain.chains.query_constructor.base import AttributeInfo
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
from langchain.chains.query_constructor.base import AttributeInfo

from supabase.client import create_client
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores.supabase import SupabaseVectorStore
from supabase.client import create_client

supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
@ -21,7 +18,7 @@ vectorstore = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents"
    query_name="match_documents",
)

# Adjust this based on the metadata you store in the `metadata` JSON column
@ -51,14 +48,7 @@ document_content_description = "Brief summary of a movie"
llm = OpenAI(temperature=0)

retriever = SelfQueryRetriever.from_llm(
    llm,
    vectorstore,
    document_content_description,
    metadata_field_info,
    verbose=True
    llm, vectorstore, document_content_description, metadata_field_info, verbose=True
)

chain = (
    RunnableParallel({"query": RunnablePassthrough()})
    | retriever
)
chain = RunnableParallel({"query": RunnablePassthrough()}) | retriever
@ -1,23 +1,26 @@
from pathlib import Path

from langchain.llms import Replicate
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.prompts import ChatPromptTemplate
from langchain.utilities import SQLDatabase

# make sure to set REPLICATE_API_TOKEN in your environment
# use llama-2-13b model in replicate
replicate_id = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"
replicate_id = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d" # noqa: E501
llm = Replicate(
    model=replicate_id,
    model_kwargs={"temperature": 0.01, "max_length": 500, "top_p": 1},
)

from pathlib import Path
from langchain.utilities import SQLDatabase

db_path = Path(__file__).parent / "nba_roster.db"
rel = db_path.relative_to(Path.cwd())
db_string = f"sqlite:///{rel}"
db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=0)


def get_schema(_):
    return db.get_table_info()

@ -30,7 +33,7 @@ template_query = """Based on the table schema below, write a SQL query that woul
{schema}

Question: {question}
SQL Query:"""
SQL Query:""" # noqa: E501
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Given an input question, convert it to a SQL query. No pre-amble."),
@ -50,13 +53,14 @@ template_response = """Based on the table schema below, question, sql query, and

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
SQL Response: {response}""" # noqa: E501

prompt_response = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Given an input question and SQL response, convert it to a natural language answer. No pre-amble.",
            "Given an input question and SQL response, convert it to a natural "
            "language answer. No pre-amble.",
        ),
        ("human", template_response),
    ]
@ -1,11 +1,15 @@
from langchain.llms import LlamaCpp
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

# Get LLM
import os
from pathlib import Path

import requests
from langchain.llms import LlamaCpp
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.utilities import SQLDatabase

# File name and URL
file_name = "mistral-7b-instruct-v0.1.Q4_K_M.gguf"
url = "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf"
@ -15,7 +19,7 @@ if not os.path.exists(file_name):
    # Download the file
    response = requests.get(url)
    response.raise_for_status() # Raise an exception for HTTP errors
    with open(file_name, 'wb') as f:
    with open(file_name, "wb") as f:
        f.write(response.content)
    print(f"'{file_name}' has been downloaded.")
else:
@ -24,23 +28,27 @@ else:
# Add the LLM downloaded from HF
model_path = file_name
n_gpu_layers = 1 # Metal set to 1 is enough.
n_batch = 512 # Should be between 1 and n_ctx, consider the amount of RAM of your Apple Silicon Chip.

# Should be between 1 and n_ctx, consider the amount of RAM of your Apple Silicon Chip.
n_batch = 512

llm = LlamaCpp(
    model_path=model_path,
    n_gpu_layers=n_gpu_layers,
    n_batch=n_batch,
    n_ctx=2048,
    f16_kv=True, # MUST set to True, otherwise you will run into problem after a couple of calls
    # f16_kv MUST set to True
    # otherwise you will run into problem after a couple of calls
    f16_kv=True,
    verbose=True,
)

from pathlib import Path
from langchain.utilities import SQLDatabase
db_path = Path(__file__).parent / "nba_roster.db"
rel = db_path.relative_to(Path.cwd())
db_string = f"sqlite:///{rel}"
db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=0)


def get_schema(_):
    return db.get_table_info()

@ -48,39 +56,43 @@ def get_schema(_):
def run_query(query):
    return db.run(query)


# Prompt
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_messages([
    ("system", "Given an input question, convert it to a SQL query. No pre-amble."),
    MessagesPlaceholder(variable_name="history"),
    ("human", template)
])
SQL Query:""" # noqa: E501
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Given an input question, convert it to a SQL query. No pre-amble."),
        MessagesPlaceholder(variable_name="history"),
        ("human", template),
    ]
)

memory = ConversationBufferMemory(return_messages=True)

# Chain to query with memory
from langchain.schema.runnable import RunnableLambda
# Chain to query with memory

sql_chain = (
    RunnablePassthrough.assign(
        schema=get_schema,
        history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"])
    )| prompt
        schema=get_schema,
        history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"]),
    )
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)


def save(input_output):
    output = {"output": input_output.pop("output")}
    memory.save_context(input_output, output)
    return output['output']

    return output["output"]


sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save

# Chain to answer
@ -89,18 +101,24 @@ template = """Based on the table schema below, question, sql query, and sql resp

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_messages([
    ("system", "Given an input question and SQL response, convert it to a natural language answer. No pre-amble."),
    ("human", template)
])
SQL Response: {response}""" # noqa: E501
prompt_response = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Given an input question and SQL response, convert it to a natural "
            "language answer. No pre-amble.",
        ),
        ("human", template),
    ]
)

chain = (
    RunnablePassthrough.assign(query=sql_response_memory)
    RunnablePassthrough.assign(query=sql_response_memory)
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: db.run(x["query"]),
    )
    | prompt_response
    | prompt_response
    | llm
)
@ -1,58 +1,67 @@
from pathlib import Path

from langchain.chat_models import ChatOllama
from langchain.prompts import ChatPromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.utilities import SQLDatabase

# Add the LLM downloaded from Ollama
ollama_llm = "llama2:13b-chat"
llm = ChatOllama(model=ollama_llm)

from pathlib import Path
from langchain.utilities import SQLDatabase

db_path = Path(__file__).parent / "nba_roster.db"
rel = db_path.relative_to(Path.cwd())
db_string = f"sqlite:///{rel}"
db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=0)


def get_schema(_):
    return db.get_table_info()


def run_query(query):
    return db.run(query)


# Prompt
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_messages([
    ("system", "Given an input question, convert it to a SQL query. No pre-amble."),
    MessagesPlaceholder(variable_name="history"),
    ("human", template)
])
SQL Query:""" # noqa: E501
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Given an input question, convert it to a SQL query. No pre-amble."),
        MessagesPlaceholder(variable_name="history"),
        ("human", template),
    ]
)

memory = ConversationBufferMemory(return_messages=True)

# Chain to query with memory
from langchain.schema.runnable import RunnableLambda
# Chain to query with memory

sql_chain = (
    RunnablePassthrough.assign(
        schema=get_schema,
        history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"])
    )| prompt
        schema=get_schema,
        history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"]),
    )
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)


def save(input_output):
    output = {"output": input_output.pop("output")}
    memory.save_context(input_output, output)
    return output['output']

    return output["output"]


sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save

# Chain to answer
@ -61,18 +70,24 @@ template = """Based on the table schema below, question, sql query, and sql resp

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_messages([
    ("system", "Given an input question and SQL response, convert it to a natural language answer. No pre-amble."),
    ("human", template)
])
SQL Response: {response}""" # noqa: E501
prompt_response = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Given an input question and SQL response, convert it to a natural "
            "language answer. No pre-amble.",
        ),
        ("human", template),
    ]
)

chain = (
    RunnablePassthrough.assign(query=sql_response_memory)
    RunnablePassthrough.assign(query=sql_response_memory)
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: db.run(x["query"]),
    )
    | prompt_response
    | prompt_response
    | llm
)
@ -1,5 +1,4 @@
from stepback_qa_prompting.chain import chain


if __name__ == "__main__":
    chain.invoke({"question": "was chatgpt around while trump was president?"})
@ -4,9 +4,9 @@ from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda
from langchain.utilities import DuckDuckGoSearchAPIWrapper


search = DuckDuckGoSearchAPIWrapper(max_results=4)


def retriever(query):
    return search.run(query)

@ -15,11 +15,11 @@ def retriever(query):
examples = [
    {
        "input": "Could the members of The Police perform lawful arrests?",
        "output": "what can the members of The Police do?"
        "output": "what can the members of The Police do?",
    },
    {
        "input": "Jan Sindel’s was born in what country?",
        "output": "what is Jan Sindel’s personal history?"
        "input": "Jan Sindel’s was born in what country?",
        "output": "what is Jan Sindel’s personal history?",
    },
]
# We now transform these to example messages
@ -34,13 +34,20 @@ few_shot_prompt = FewShotChatMessagePromptTemplate(
    examples=examples,
)

prompt = ChatPromptTemplate.from_messages([
    ("system", """You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer. Here are a few examples:"""),
    # Few shot examples
    few_shot_prompt,
    # New question
    ("user", "{question}"),
])
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an expert at world knowledge. Your task is to step back "
            "and paraphrase a question to a more generic step-back question, which "
            "is easier to answer. Here are a few examples:",
        ),
        # Few shot examples
        few_shot_prompt,
        # New question
        ("user", "{question}"),
    ]
)

question_gen = prompt | ChatOpenAI(temperature=0) | StrOutputParser()

@ -50,16 +57,19 @@ response_prompt_template = """You are an expert of world knowledge. I am going t
{step_back_context}

Original Question: {question}
Answer:"""
Answer:""" # noqa: E501
response_prompt = ChatPromptTemplate.from_template(response_prompt_template)

chain = {
    # Retrieve context using the normal question
    "normal_context": RunnableLambda(lambda x: x['question']) | retriever,
    # Retrieve context using the step-back question
    "step_back_context": question_gen | retriever,
    # Pass on the question
    "question": lambda x: x["question"]
} | response_prompt | ChatOpenAI(temperature=0) | StrOutputParser()


chain = (
    {
        # Retrieve context using the normal question
        "normal_context": RunnableLambda(lambda x: x["question"]) | retriever,
        # Retrieve context using the step-back question
        "step_back_context": question_gen | retriever,
        # Pass on the question
        "question": lambda x: x["question"],
    }
    | response_prompt
    | ChatOpenAI(temperature=0)
    | StrOutputParser()
)
@ -1,14 +1,19 @@
from langchain.chat_models import ChatAnthropic
from langchain.tools.render import render_text_description
from langchain.agents.format_scratchpad import format_xml
from langchain.agents import AgentExecutor
from langchain.retrievers.you import YouRetriever
from langchain.agents.agent_toolkits.conversational_retrieval.tool import create_retriever_tool
from langchain.pydantic_v1 import BaseModel
from xml_agent.prompts import conversational_prompt, parse_output
from langchain.schema import AIMessage, HumanMessage
from typing import List, Tuple

from langchain.agents import AgentExecutor
from langchain.agents.agent_toolkits.conversational_retrieval.tool import (
    create_retriever_tool,
)
from langchain.agents.format_scratchpad import format_xml
from langchain.chat_models import ChatAnthropic
from langchain.pydantic_v1 import BaseModel
from langchain.retrievers.you import YouRetriever
from langchain.schema import AIMessage, HumanMessage
from langchain.tools.render import render_text_description

from xml_agent.prompts import conversational_prompt, parse_output


def _format_chat_history(chat_history: List[Tuple[str, str]]):
    buffer = []
    for human, ai in chat_history:
@ -21,7 +26,9 @@ model = ChatAnthropic(model="claude-2")

# Fake Tool
retriever = YouRetriever(k=5)
retriever_tool = create_retriever_tool(retriever, "search", "Use this to search for current events.")
retriever_tool = create_retriever_tool(
    retriever, "search", "Use this to search for current events."
)

tools = [retriever_tool]

@ -31,18 +38,25 @@ prompt = conversational_prompt.partial(
)
llm_with_stop = model.bind(stop=["</tool_input>"])

agent = {
    "question": lambda x: x["question"],
    "agent_scratchpad": lambda x: format_xml(x['intermediate_steps']),
    "chat_history": lambda x: _format_chat_history(x["chat_history"]),
} | prompt | llm_with_stop | parse_output
agent = (
    {
        "question": lambda x: x["question"],
        "agent_scratchpad": lambda x: format_xml(x["intermediate_steps"]),
        "chat_history": lambda x: _format_chat_history(x["chat_history"]),
    }
    | prompt
    | llm_with_stop
    | parse_output
)


class AgentInput(BaseModel):
    question: str
    chat_history: List[Tuple[str, str]]

agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True).with_types(
    input_type=AgentInput
)

agent_executor = AgentExecutor(
    agent=agent, tools=tools, verbose=True, handle_parsing_errors=True
).with_types(input_type=AgentInput)

agent_executor = agent_executor | (lambda x: x["output"])
@ -27,14 +27,16 @@ Assistant: <tool>search</tool><tool_input>weather in SF</tool_input>
It is 64 degress in SF


Begin!"""
Begin!""" # noqa: E501

conversational_prompt = ChatPromptTemplate.from_messages([
    ("system", template),
    MessagesPlaceholder(variable_name="chat_history"),
    ("user", "{question}"),
    ("ai", "{agent_scratchpad}")
])
conversational_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{question}"),
        ("ai", "{agent_scratchpad}"),
    ]
)


def parse_output(message):
@ -47,4 +49,4 @@ def parse_output(message):
        _tool_input = _tool_input.split("</tool_input>")[0]
        return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
    else:
        return AgentFinish(return_values={"output": text}, log=text)
        return AgentFinish(return_values={"output": text}, log=text)