Format Templates (#12396)

This commit is contained in:
Erick Friis
2023-10-26 19:44:30 -07:00
committed by GitHub
parent 25c98dbba9
commit 4b16601d33
59 changed files with 800 additions and 441 deletions

View File

@@ -1,24 +1,25 @@
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_experimental.tools import PythonAstREPLTool
import pandas as pd
from langchain.chat_models import ChatOpenAI
from langsmith import Client
from langchain.smith import RunEvalConfig, run_on_dataset
from pydantic import BaseModel, Field
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.tools.retriever import create_retriever_tool
from pathlib import Path
import pandas as pd
from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools.retriever import create_retriever_tool
from langchain.vectorstores import FAISS
from langchain_experimental.tools import PythonAstREPLTool
from pydantic import BaseModel, Field
MAIN_DIR = Path(__file__).parents[1]
pd.set_option('display.max_rows', 20)
pd.set_option('display.max_columns', 20)
pd.set_option("display.max_rows", 20)
pd.set_option("display.max_columns", 20)
embedding_model = OpenAIEmbeddings()
vectorstore = FAISS.load_local(MAIN_DIR / "titanic_data", embedding_model)
retriever_tool = create_retriever_tool(vectorstore.as_retriever(), "person_name_search", "Search for a person by name")
retriever_tool = create_retriever_tool(
vectorstore.as_retriever(), "person_name_search", "Search for a person by name"
)
TEMPLATE = """You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
@@ -41,8 +42,7 @@ For example:
<question>Who has id 320</question>
<logic>Use `python_repl` since even though the question is about a person, you don't know their name so you can't include it.</logic>
"""
""" # noqa: E501
class PythonInputs(BaseModel):
@@ -52,15 +52,24 @@ class PythonInputs(BaseModel):
df = pd.read_csv("titanic.csv")
template = TEMPLATE.format(dhead=df.head().to_markdown())
prompt = ChatPromptTemplate.from_messages([
("system", template),
MessagesPlaceholder(variable_name="agent_scratchpad"),
("human", "{input}")
])
prompt = ChatPromptTemplate.from_messages(
[
("system", template),
MessagesPlaceholder(variable_name="agent_scratchpad"),
("human", "{input}"),
]
)
repl = PythonAstREPLTool(locals={"df": df}, name="python_repl",
description="Runs code and returns the output of the final line",
args_schema=PythonInputs)
repl = PythonAstREPLTool(
locals={"df": df},
name="python_repl",
description="Runs code and returns the output of the final line",
args_schema=PythonInputs,
)
tools = [repl, retriever_tool]
agent = OpenAIFunctionsAgent(llm=ChatOpenAI(temperature=0, model="gpt-4"), prompt=prompt, tools=tools)
agent_executor = AgentExecutor(agent=agent, tools=tools, max_iterations=5, early_stopping_method="generate")
agent = OpenAIFunctionsAgent(
llm=ChatOpenAI(temperature=0, model="gpt-4"), prompt=prompt, tools=tools
)
agent_executor = AgentExecutor(
agent=agent, tools=tools, max_iterations=5, early_stopping_method="generate"
)