mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-11 07:50:47 +00:00
Format Templates (#12396)
This commit is contained in:
@@ -1,24 +1,25 @@
|
||||
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_experimental.tools import PythonAstREPLTool
|
||||
import pandas as pd
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langsmith import Client
|
||||
from langchain.smith import RunEvalConfig, run_on_dataset
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain.tools.retriever import create_retriever_tool
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain.tools.retriever import create_retriever_tool
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain_experimental.tools import PythonAstREPLTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
MAIN_DIR = Path(__file__).parents[1]
|
||||
|
||||
pd.set_option('display.max_rows', 20)
|
||||
pd.set_option('display.max_columns', 20)
|
||||
pd.set_option("display.max_rows", 20)
|
||||
pd.set_option("display.max_columns", 20)
|
||||
|
||||
embedding_model = OpenAIEmbeddings()
|
||||
vectorstore = FAISS.load_local(MAIN_DIR / "titanic_data", embedding_model)
|
||||
retriever_tool = create_retriever_tool(vectorstore.as_retriever(), "person_name_search", "Search for a person by name")
|
||||
retriever_tool = create_retriever_tool(
|
||||
vectorstore.as_retriever(), "person_name_search", "Search for a person by name"
|
||||
)
|
||||
|
||||
|
||||
TEMPLATE = """You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
|
||||
@@ -41,8 +42,7 @@ For example:
|
||||
|
||||
<question>Who has id 320</question>
|
||||
<logic>Use `python_repl` since even though the question is about a person, you don't know their name so you can't include it.</logic>
|
||||
"""
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
class PythonInputs(BaseModel):
|
||||
@@ -52,15 +52,24 @@ class PythonInputs(BaseModel):
|
||||
df = pd.read_csv("titanic.csv")
|
||||
template = TEMPLATE.format(dhead=df.head().to_markdown())
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
("system", template),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
("human", "{input}")
|
||||
])
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", template),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
("human", "{input}"),
|
||||
]
|
||||
)
|
||||
|
||||
repl = PythonAstREPLTool(locals={"df": df}, name="python_repl",
|
||||
description="Runs code and returns the output of the final line",
|
||||
args_schema=PythonInputs)
|
||||
repl = PythonAstREPLTool(
|
||||
locals={"df": df},
|
||||
name="python_repl",
|
||||
description="Runs code and returns the output of the final line",
|
||||
args_schema=PythonInputs,
|
||||
)
|
||||
tools = [repl, retriever_tool]
|
||||
agent = OpenAIFunctionsAgent(llm=ChatOpenAI(temperature=0, model="gpt-4"), prompt=prompt, tools=tools)
|
||||
agent_executor = AgentExecutor(agent=agent, tools=tools, max_iterations=5, early_stopping_method="generate")
|
||||
agent = OpenAIFunctionsAgent(
|
||||
llm=ChatOpenAI(temperature=0, model="gpt-4"), prompt=prompt, tools=tools
|
||||
)
|
||||
agent_executor = AgentExecutor(
|
||||
agent=agent, tools=tools, max_iterations=5, early_stopping_method="generate"
|
||||
)
|
||||
|
Reference in New Issue
Block a user