mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-05 13:06:03 +00:00
Start cookbook and move stuff from use cases (#11636)
This commit is contained in:
245
cookbook/gymnasium_agent_simulation.ipynb
Normal file
245
cookbook/gymnasium_agent_simulation.ipynb
Normal file
@@ -0,0 +1,245 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4b089493",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Simulated Environment: Gymnasium\n",
|
||||
"\n",
|
||||
"For many applications of LLM agents, the environment is real (internet, database, REPL, etc). However, we can also define agents to interact in simulated environments like text-based games. This is an example of how to create a simple agent-environment interaction loop with [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) (formerly [OpenAI Gym](https://github.com/openai/gym))."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "f36427cf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install gymnasium"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "f9bd38b4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gymnasium as gym\n",
|
||||
"import inspect\n",
|
||||
"import tenacity\n",
|
||||
"\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.schema import (\n",
|
||||
" AIMessage,\n",
|
||||
" HumanMessage,\n",
|
||||
" SystemMessage,\n",
|
||||
" BaseMessage,\n",
|
||||
")\n",
|
||||
"from langchain.output_parsers import RegexParser"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e222e811",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Define the agent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "870c24bc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class GymnasiumAgent:\n",
|
||||
" @classmethod\n",
|
||||
" def get_docs(cls, env):\n",
|
||||
" return env.unwrapped.__doc__\n",
|
||||
"\n",
|
||||
" def __init__(self, model, env):\n",
|
||||
" self.model = model\n",
|
||||
" self.env = env\n",
|
||||
" self.docs = self.get_docs(env)\n",
|
||||
"\n",
|
||||
" self.instructions = \"\"\"\n",
|
||||
"Your goal is to maximize your return, i.e. the sum of the rewards you receive.\n",
|
||||
"I will give you an observation, reward, terminiation flag, truncation flag, and the return so far, formatted as:\n",
|
||||
"\n",
|
||||
"Observation: <observation>\n",
|
||||
"Reward: <reward>\n",
|
||||
"Termination: <termination>\n",
|
||||
"Truncation: <truncation>\n",
|
||||
"Return: <sum_of_rewards>\n",
|
||||
"\n",
|
||||
"You will respond with an action, formatted as:\n",
|
||||
"\n",
|
||||
"Action: <action>\n",
|
||||
"\n",
|
||||
"where you replace <action> with your actual action.\n",
|
||||
"Do nothing else but return the action.\n",
|
||||
"\"\"\"\n",
|
||||
" self.action_parser = RegexParser(\n",
|
||||
" regex=r\"Action: (.*)\", output_keys=[\"action\"], default_output_key=\"action\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" self.message_history = []\n",
|
||||
" self.ret = 0\n",
|
||||
"\n",
|
||||
" def random_action(self):\n",
|
||||
" action = self.env.action_space.sample()\n",
|
||||
" return action\n",
|
||||
"\n",
|
||||
" def reset(self):\n",
|
||||
" self.message_history = [\n",
|
||||
" SystemMessage(content=self.docs),\n",
|
||||
" SystemMessage(content=self.instructions),\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" def observe(self, obs, rew=0, term=False, trunc=False, info=None):\n",
|
||||
" self.ret += rew\n",
|
||||
"\n",
|
||||
" obs_message = f\"\"\"\n",
|
||||
"Observation: {obs}\n",
|
||||
"Reward: {rew}\n",
|
||||
"Termination: {term}\n",
|
||||
"Truncation: {trunc}\n",
|
||||
"Return: {self.ret}\n",
|
||||
" \"\"\"\n",
|
||||
" self.message_history.append(HumanMessage(content=obs_message))\n",
|
||||
" return obs_message\n",
|
||||
"\n",
|
||||
" def _act(self):\n",
|
||||
" act_message = self.model(self.message_history)\n",
|
||||
" self.message_history.append(act_message)\n",
|
||||
" action = int(self.action_parser.parse(act_message.content)[\"action\"])\n",
|
||||
" return action\n",
|
||||
"\n",
|
||||
" def act(self):\n",
|
||||
" try:\n",
|
||||
" for attempt in tenacity.Retrying(\n",
|
||||
" stop=tenacity.stop_after_attempt(2),\n",
|
||||
" wait=tenacity.wait_none(), # No waiting time between retries\n",
|
||||
" retry=tenacity.retry_if_exception_type(ValueError),\n",
|
||||
" before_sleep=lambda retry_state: print(\n",
|
||||
" f\"ValueError occurred: {retry_state.outcome.exception()}, retrying...\"\n",
|
||||
" ),\n",
|
||||
" ):\n",
|
||||
" with attempt:\n",
|
||||
" action = self._act()\n",
|
||||
" except tenacity.RetryError as e:\n",
|
||||
" action = self.random_action()\n",
|
||||
" return action"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2e76d22c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Initialize the simulated environment and agent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "9e902cfd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = gym.make(\"Blackjack-v1\")\n",
|
||||
"agent = GymnasiumAgent(model=ChatOpenAI(temperature=0.2), env=env)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e2c12b15",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Main loop"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "ad361210",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Observation: (15, 4, 0)\n",
|
||||
"Reward: 0\n",
|
||||
"Termination: False\n",
|
||||
"Truncation: False\n",
|
||||
"Return: 0\n",
|
||||
" \n",
|
||||
"Action: 1\n",
|
||||
"\n",
|
||||
"Observation: (25, 4, 0)\n",
|
||||
"Reward: -1.0\n",
|
||||
"Termination: True\n",
|
||||
"Truncation: False\n",
|
||||
"Return: -1.0\n",
|
||||
" \n",
|
||||
"break True False\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"observation, info = env.reset()\n",
|
||||
"agent.reset()\n",
|
||||
"\n",
|
||||
"obs_message = agent.observe(observation)\n",
|
||||
"print(obs_message)\n",
|
||||
"\n",
|
||||
"while True:\n",
|
||||
" action = agent.act()\n",
|
||||
" observation, reward, termination, truncation, info = env.step(action)\n",
|
||||
" obs_message = agent.observe(observation, reward, termination, truncation, info)\n",
|
||||
" print(f\"Action: {action}\")\n",
|
||||
" print(obs_message)\n",
|
||||
"\n",
|
||||
" if termination or truncation:\n",
|
||||
" print(\"break\", termination, truncation)\n",
|
||||
" break\n",
|
||||
"env.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "58a13e9c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
Reference in New Issue
Block a user