{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4b089493",
   "metadata": {},
   "source": [
    "# Multi-Agent Simulated Environment: Petting Zoo\n",
    "\n",
    "In this example, we show how to define multi-agent simulations with simulated environments. Like [our single-agent example with Gymnasium](https://python.langchain.com/en/latest/use_cases/agent_simulations/gymnasium.html), we create an agent-environment loop with an externally defined environment. The main difference is that the loop now involves multiple agents that take turns acting in a shared environment. We will use the [PettingZoo](https://pettingzoo.farama.org/) library, which is the multi-agent counterpart to [Gymnasium](https://gymnasium.farama.org/)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10091333",
   "metadata": {},
   "source": [
    "## Install `pettingzoo` and other dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0a3fde66",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install pettingzoo pygame rlcard langchain langchain-openai"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5fbe130c",
   "metadata": {},
   "source": [
    "## Import modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "42cd2e5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import inspect\n",
    "\n",
    "import tenacity\n",
    "from langchain.output_parsers import RegexParser\n",
    "from langchain_core.messages import (\n",
    "    HumanMessage,\n",
    "    SystemMessage,\n",
    ")\n",
    "from langchain_openai import ChatOpenAI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e222e811",
   "metadata": {},
   "source": [
    "## `GymnasiumAgent`\n",
    "Here we reproduce the same `GymnasiumAgent` defined in [our Gymnasium example](https://python.langchain.com/en/latest/use_cases/agent_simulations/gymnasium.html). If the model's reply still cannot be parsed into an integer action after two attempts, the agent falls back to a random action."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "72df0b59",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GymnasiumAgent:\n",
    "    @classmethod\n",
    "    def get_docs(cls, env):\n",
    "        return env.unwrapped.__doc__\n",
    "\n",
    "    def __init__(self, model, env):\n",
    "        self.model = model\n",
    "        self.env = env\n",
    "        self.docs = self.get_docs(env)\n",
    "\n",
    "        self.instructions = \"\"\"\n",
    "Your goal is to maximize your return, i.e. the sum of the rewards you receive.\n",
    "I will give you an observation, reward, termination flag, truncation flag, and the return so far, formatted as:\n",
    "\n",
    "Observation: <observation>\n",
    "Reward: <reward>\n",
    "Termination: <termination>\n",
    "Truncation: <truncation>\n",
    "Return: <sum_of_rewards>\n",
    "\n",
    "You will respond with an action, formatted as:\n",
    "\n",
    "Action: <action>\n",
    "\n",
    "where you replace <action> with your actual action.\n",
    "Do nothing else but return the action.\n",
    "\"\"\"\n",
    "        self.action_parser = RegexParser(\n",
    "            regex=r\"Action: (.*)\", output_keys=[\"action\"], default_output_key=\"action\"\n",
    "        )\n",
    "\n",
    "        self.message_history = []\n",
    "        self.ret = 0\n",
    "\n",
    "    def random_action(self):\n",
    "        action = self.env.action_space.sample()\n",
    "        return action\n",
    "\n",
    "    def reset(self):\n",
    "        # Start a fresh conversation and a fresh return for each episode.\n",
    "        self.message_history = [\n",
    "            SystemMessage(content=self.docs),\n",
    "            SystemMessage(content=self.instructions),\n",
    "        ]\n",
    "        self.ret = 0\n",
    "\n",
    "    def observe(self, obs, rew=0, term=False, trunc=False, info=None):\n",
    "        self.ret += rew\n",
    "\n",
    "        obs_message = f\"\"\"\n",
    "Observation: {obs}\n",
    "Reward: {rew}\n",
    "Termination: {term}\n",
    "Truncation: {trunc}\n",
    "Return: {self.ret}\n",
    "        \"\"\"\n",
    "        self.message_history.append(HumanMessage(content=obs_message))\n",
    "        return obs_message\n",
    "\n",
    "    def _act(self):\n",
    "        act_message = self.model.invoke(self.message_history)\n",
    "        self.message_history.append(act_message)\n",
    "        action = int(self.action_parser.parse(act_message.content)[\"action\"])\n",
    "        return action\n",
    "\n",
    "    def act(self):\n",
    "        try:\n",
    "            for attempt in tenacity.Retrying(\n",
    "                stop=tenacity.stop_after_attempt(2),\n",
    "                wait=tenacity.wait_none(),  # No waiting time between retries\n",
    "                retry=tenacity.retry_if_exception_type(ValueError),\n",
    "                before_sleep=lambda retry_state: print(\n",
    "                    f\"ValueError occurred: {retry_state.outcome.exception()}, retrying...\"\n",
    "                ),\n",
    "            ):\n",
    "                with attempt:\n",
    "                    action = self._act()\n",
    "        except tenacity.RetryError:\n",
    "            action = self.random_action()\n",
    "        return action"
   ]
  },
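  {
   "cell_type": "markdown",
   "id": "c7a1e3f0",
   "metadata": {},
   "source": [
    "The retry-then-fallback logic in `act` can be illustrated without calling a model. The sketch below is a hypothetical, dependency-free stand-in: `parse_action`, `act_with_fallback`, and the canned replies are illustrative names, not part of this notebook, LangChain, or `tenacity`. It applies the same `Action: (.*)` pattern and `int` conversion as `_act`, retries on `ValueError` up to `max_attempts` times (two, matching `stop_after_attempt(2)`), and then takes a random action.\n",
    "\n",
    "```python\n",
    "import random\n",
    "import re\n",
    "\n",
    "# Same pattern as the RegexParser in GymnasiumAgent.\n",
    "ACTION_RE = re.compile(r'Action: (.*)')\n",
    "\n",
    "\n",
    "def parse_action(text):\n",
    "    # Hypothetical helper: extract the action and convert it to int.\n",
    "    # Raises ValueError if the pattern is missing or the value is not an integer.\n",
    "    match = ACTION_RE.search(text)\n",
    "    if match is None:\n",
    "        raise ValueError(f'no action found in {text!r}')\n",
    "    return int(match.group(1))\n",
    "\n",
    "\n",
    "def act_with_fallback(replies, n_actions, max_attempts=2, rng=random):\n",
    "    # replies: an iterator of model outputs, standing in for model.invoke.\n",
    "    for _ in range(max_attempts):\n",
    "        try:\n",
    "            return parse_action(next(replies))\n",
    "        except ValueError:\n",
    "            continue  # retry, like tenacity.retry_if_exception_type(ValueError)\n",
    "    # Every attempt failed: take a random action, as GymnasiumAgent does.\n",
    "    return rng.randrange(n_actions)\n",
    "\n",
    "\n",
    "print(act_with_fallback(iter(['Action: 2']), n_actions=3))  # 2\n",
    "print(act_with_fallback(iter(['I pick paper', 'Action: 1']), n_actions=3))  # 1\n",
    "print(act_with_fallback(iter(['hmm', 'Action: rock']), n_actions=3))  # random: 0, 1, or 2\n",
    "```\n",
    "\n",
    "In the second call the first reply has no `Action:` line, so the retry is what succeeds; in the third call neither reply parses (`int('rock')` raises `ValueError`), so the random fallback is used."
   ]
  },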
  {
   "cell_type": "markdown",
   "id": "df51e302",
   "metadata": {},
   "source": [
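    "## Main loop\n",
    "\n",
    "`main` resets the environment and every agent, then steps through the agents in turn with `env.agent_iter()`. For each agent it reads the latest observation, reward, and flags with `env.last()` and passes them to that agent. A live agent is asked for an action; an agent that has terminated or been truncated must pass `None` to `env.step`, which is how PettingZoo removes it from the game."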
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0f07d7cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(agents, env):\n",
    "    env.reset()\n",
    "\n",
    "    for agent in agents.values():\n",
    "        agent.reset()\n",
    "\n",
    "    for agent_name in env.agent_iter():\n",
    "        observation, reward, termination, truncation, info = env.last()\n",
    "        obs_message = agents[agent_name].observe(\n",
    "            observation, reward, termination, truncation, info\n",
    "        )\n",
    "        print(obs_message)\n",
    "        if termination or truncation:\n",
    "            action = None\n",
    "        else:\n",
    "            action = agents[agent_name].act()\n",
    "        print(f\"Action: {action}\")\n",
    "        env.step(action)\n",
    "    env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4b0e921",
   "metadata": {},
   "source": [
    "## `PettingZooAgent`\n",
    "\n",
    "The `PettingZooAgent` extends the `GymnasiumAgent` to the multi-agent setting. The main differences are:\n",
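    "- `PettingZooAgent` takes a `name` argument to identify it among multiple agents\n",
    "- `get_docs` reads the docstring of the module that defines the environment, because PettingZoo documents its environments in module docstrings rather than in the environment class\n",
    "- `random_action` calls `env.action_space(self.name)`, because in PettingZoo each agent has its own action space, exposed through a method rather than an attribute"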
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f132c92a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PettingZooAgent(GymnasiumAgent):\n",
    "    @classmethod\n",
    "    def get_docs(cls, env):\n",
    "        return inspect.getmodule(env.unwrapped).__doc__\n",
    "\n",
    "    def __init__(self, name, model, env):\n",
    "        super().__init__(model, env)\n",
    "        self.name = name\n",
    "\n",
    "    def random_action(self):\n",
    "        action = self.env.action_space(self.name).sample()\n",
    "        return action"
   ]
  },
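  {
   "cell_type": "markdown",
   "id": "d2b94e6a",
   "metadata": {},
   "source": [
    "To see why the two `get_docs` implementations differ, the hypothetical sketch below builds a throwaway module (`toy_env_module`) and class (`ToyEnv`); neither is part of PettingZoo. It assumes, as `PettingZooAgent.get_docs` implies, that the environment documentation lives in the module docstring rather than the class docstring. `inspect.getmodule` looks up the object's `__module__` in `sys.modules`, so it returns the module in which the class was defined.\n",
    "\n",
    "```python\n",
    "import inspect\n",
    "import sys\n",
    "import types\n",
    "\n",
    "# Hypothetical module standing in for a PettingZoo environment module,\n",
    "# whose module docstring carries the environment documentation.\n",
    "toy_module = types.ModuleType('toy_env_module', 'Toy environment: rules, actions and rewards.')\n",
    "sys.modules['toy_env_module'] = toy_module\n",
    "\n",
    "\n",
    "class ToyEnv:\n",
    "    '''Short class docstring.'''\n",
    "\n",
    "\n",
    "# Make ToyEnv look as if it had been defined inside toy_env_module.\n",
    "ToyEnv.__module__ = 'toy_env_module'\n",
    "\n",
    "env = ToyEnv()\n",
    "print(env.__doc__)  # class docstring, what GymnasiumAgent.get_docs reads via env.unwrapped\n",
    "print(inspect.getmodule(env).__doc__)  # module docstring, what PettingZooAgent.get_docs reads\n",
    "# Short class docstring.\n",
    "# Toy environment: rules, actions and rewards.\n",
    "```"
   ]
  },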
  {
   "cell_type": "markdown",
   "id": "a27f8a5d",
   "metadata": {},
   "source": [
    "## Rock, Paper, Scissors\n",
    "We can now run a simulation of a multi-agent rock, paper, scissors game using the `PettingZooAgent`. With `max_cycles=3`, the game is truncated after three rounds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bd1256c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Observation: 3\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 1\n",
      "\n",
      "Observation: 3\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 1\n",
      "\n",
      "Observation: 1\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 2\n",
      "\n",
      "Observation: 1\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 1\n",
      "\n",
      "Observation: 1\n",
      "Reward: 1\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 1\n",
      "        \n",
      "Action: 0\n",
      "\n",
      "Observation: 2\n",
      "Reward: -1\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: -1\n",
      "        \n",
      "Action: 0\n",
      "\n",
      "Observation: 0\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: True\n",
      "Return: 1\n",
      "        \n",
      "Action: None\n",
      "\n",
      "Observation: 0\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: True\n",
      "Return: -1\n",
      "        \n",
      "Action: None\n"
     ]
    }
   ],
   "source": [
    "from pettingzoo.classic import rps_v2\n",
    "\n",
    "env = rps_v2.env(max_cycles=3, render_mode=\"human\")\n",
    "agents = {\n",
    "    name: PettingZooAgent(name=name, model=ChatOpenAI(temperature=1), env=env)\n",
    "    for name in env.possible_agents\n",
    "}\n",
    "main(agents, env)"
   ]
  },
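  {
   "cell_type": "markdown",
   "id": "e5f08c31",
   "metadata": {},
   "source": [
    "The rewards in the run above are consistent with the encoding 0 = rock, 1 = paper, 2 = scissors, and the initial observation 3 appears to mean that no move has been played yet. The sketch below assumes that encoding and replays the three rounds from the run (the first agent played 1, 2, 0 and the second played 1, 1, 0); `rps_reward` is a hypothetical helper, not part of PettingZoo.\n",
    "\n",
    "```python\n",
    "# Assumed encoding, consistent with the run above: 0 = rock, 1 = paper, 2 = scissors.\n",
    "MOVE_NAMES = ['rock', 'paper', 'scissors']\n",
    "\n",
    "\n",
    "def rps_reward(my_move, their_move):\n",
    "    # Hypothetical helper, not a PettingZoo API. Each move beats the move one\n",
    "    # index below it, modulo 3: paper > rock, scissors > paper, rock > scissors.\n",
    "    if my_move == their_move:\n",
    "        return 0\n",
    "    return 1 if (my_move - their_move) % 3 == 1 else -1\n",
    "\n",
    "\n",
    "# (first agent's action, second agent's action) in each of the three rounds above.\n",
    "rounds = [(1, 1), (2, 1), (0, 0)]\n",
    "returns = [0, 0]\n",
    "for a0, a1 in rounds:\n",
    "    returns[0] += rps_reward(a0, a1)\n",
    "    returns[1] += rps_reward(a1, a0)\n",
    "    print(MOVE_NAMES[a0], 'vs', MOVE_NAMES[a1])\n",
    "print(returns)  # [1, -1]\n",
    "```\n",
    "\n",
    "The final returns match the `Return` values printed when the game is truncated: 1 for the first agent and -1 for the second."
   ]
  },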
  {
   "cell_type": "markdown",
   "id": "fbcee258",
   "metadata": {},
   "source": [
    "## `ActionMaskAgent`\n",
    "\n",
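    "Some `PettingZoo` environments provide an `action_mask` to tell the agent which actions are valid. The `ActionMaskAgent` subclasses `PettingZooAgent` and uses the `action_mask` in two ways: before each action it instructs the model to choose an index whose mask entry is nonzero, and its random fallback samples only among the valid actions. It does not check the model's chosen action against the mask, so an illegal choice is passed to the environment unchanged."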
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bd33250a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ActionMaskAgent(PettingZooAgent):\n",
    "    def __init__(self, name, model, env):\n",
    "        super().__init__(name, model, env)\n",
    "        self.obs_buffer = collections.deque(maxlen=1)\n",
    "\n",
    "    def random_action(self):\n",
    "        obs = self.obs_buffer[-1]\n",
    "        action = self.env.action_space(self.name).sample(obs[\"action_mask\"])\n",
    "        return action\n",
    "\n",
    "    def reset(self):\n",
    "        # Reuse the parent's reset and drop any observation left over from a previous episode.\n",
    "        super().reset()\n",
    "        self.obs_buffer.clear()\n",
    "\n",
    "    def observe(self, obs, rew=0, term=False, trunc=False, info=None):\n",
    "        self.obs_buffer.append(obs)\n",
    "        return super().observe(obs, rew, term, trunc, info)\n",
    "\n",
    "    def _act(self):\n",
    "        valid_action_instruction = \"Generate a valid action given by the indices of the `action_mask` that are not 0, according to the action formatting rules.\"\n",
    "        self.message_history.append(HumanMessage(content=valid_action_instruction))\n",
    "        return super()._act()"
   ]
  },
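  {
   "cell_type": "markdown",
   "id": "f19d7b42",
   "metadata": {},
   "source": [
    "Because `_act` does not check the parsed action against the mask, one possible extension is to validate it and raise `ValueError` when it is illegal, so that the existing `tenacity` retry in `act` asks the model again and, failing that, falls back to a random legal action. The sketch below is not part of the original notebook: `valid_actions`, `check_action`, and `random_valid_action` are hypothetical helpers, and plain lists stand in for the NumPy masks returned by the environment.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "# Hypothetical helpers, not part of the notebook or of PettingZoo.\n",
    "\n",
    "\n",
    "def valid_actions(action_mask):\n",
    "    # Indices whose mask entry is nonzero are the legal actions.\n",
    "    return [i for i, allowed in enumerate(action_mask) if allowed]\n",
    "\n",
    "\n",
    "def check_action(action, action_mask):\n",
    "    # Raising ValueError inside _act would make the tenacity loop in act() retry,\n",
    "    # just as it does for replies that cannot be parsed.\n",
    "    if action not in valid_actions(action_mask):\n",
    "        raise ValueError(f'action {action} is not allowed by mask {list(action_mask)}')\n",
    "    return action\n",
    "\n",
    "\n",
    "def random_valid_action(action_mask, rng=random):\n",
    "    # Same idea as action_space(name).sample(mask): draw only among legal actions.\n",
    "    return rng.choice(valid_actions(action_mask))\n",
    "\n",
    "\n",
    "mask = [1, 1, 0, 0, 0]  # only actions 0 and 1 are legal\n",
    "print(valid_actions(mask))  # [0, 1]\n",
    "print(check_action(1, mask))  # 1\n",
    "try:\n",
    "    check_action(4, mask)\n",
    "except ValueError as err:\n",
    "    print(err)  # action 4 is not allowed by mask [1, 1, 0, 0, 0]\n",
    "print(random_valid_action(mask))  # 0 or 1\n",
    "```\n",
    "\n",
    "Inside `ActionMaskAgent._act`, this would amount to passing the parsed action through `check_action(action, self.obs_buffer[-1]['action_mask'])` before returning it. The runs in this notebook use the unmodified class."
   ]
  },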
  {
   "cell_type": "markdown",
   "id": "2e76d22c",
   "metadata": {},
   "source": [
    "## Tic-Tac-Toe\n",
    "Here is an example of a Tic-Tac-Toe game that uses the `ActionMaskAgent`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9e902cfd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Observation: {'observation': array([[[0, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]],\n",
      "\n",
      "       [[0, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]],\n",
      "\n",
      "       [[0, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]]], dtype=int8), 'action_mask': array([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 0\n",
      "     |     |     \n",
      "  X  |  -  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  -  |  -  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  -  |  -  |  -  \n",
      "     |     |     \n",
      "\n",
      "Observation: {'observation': array([[[0, 1],\n",
      "        [0, 0],\n",
      "        [0, 0]],\n",
      "\n",
      "       [[0, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]],\n",
      "\n",
      "       [[0, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]]], dtype=int8), 'action_mask': array([0, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 1\n",
      "     |     |     \n",
      "  X  |  -  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  O  |  -  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  -  |  -  |  -  \n",
      "     |     |     \n",
      "\n",
      "Observation: {'observation': array([[[1, 0],\n",
      "        [0, 1],\n",
      "        [0, 0]],\n",
      "\n",
      "       [[0, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]],\n",
      "\n",
      "       [[0, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 2\n",
      "     |     |     \n",
      "  X  |  -  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  O  |  -  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  X  |  -  |  -  \n",
      "     |     |     \n",
      "\n",
      "Observation: {'observation': array([[[0, 1],\n",
      "        [1, 0],\n",
      "        [0, 1]],\n",
      "\n",
      "       [[0, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]],\n",
      "\n",
      "       [[0, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 3\n",
      "     |     |     \n",
      "  X  |  O  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  O  |  -  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  X  |  -  |  -  \n",
      "     |     |     \n",
      "\n",
      "Observation: {'observation': array([[[1, 0],\n",
      "        [0, 1],\n",
      "        [1, 0]],\n",
      "\n",
      "       [[0, 1],\n",
      "        [0, 0],\n",
      "        [0, 0]],\n",
      "\n",
      "       [[0, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 4\n",
      "     |     |     \n",
      "  X  |  O  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  O  |  X  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  X  |  -  |  -  \n",
      "     |     |     \n",
      "\n",
      "Observation: {'observation': array([[[0, 1],\n",
      "        [1, 0],\n",
      "        [0, 1]],\n",
      "\n",
      "       [[1, 0],\n",
      "        [0, 1],\n",
      "        [0, 0]],\n",
      "\n",
      "       [[0, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 1, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 5\n",
      "     |     |     \n",
      "  X  |  O  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  O  |  X  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  X  |  O  |  -  \n",
      "     |     |     \n",
      "\n",
      "Observation: {'observation': array([[[1, 0],\n",
      "        [0, 1],\n",
      "        [1, 0]],\n",
      "\n",
      "       [[0, 1],\n",
      "        [1, 0],\n",
      "        [0, 1]],\n",
      "\n",
      "       [[0, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 6\n",
      "     |     |     \n",
      "  X  |  O  |  X  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  O  |  X  |  -  \n",
      "_____|_____|_____\n",
      "     |     |     \n",
      "  X  |  O  |  -  \n",
      "     |     |     \n",
      "\n",
      "Observation: {'observation': array([[[0, 1],\n",
      "        [1, 0],\n",
      "        [0, 1]],\n",
      "\n",
      "       [[1, 0],\n",
      "        [0, 1],\n",
      "        [1, 0]],\n",
      "\n",
      "       [[0, 1],\n",
      "        [0, 0],\n",
      "        [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int8)}\n",
      "Reward: -1\n",
      "Termination: True\n",
      "Truncation: False\n",
      "Return: -1\n",
      "        \n",
      "Action: None\n",
      "\n",
      "Observation: {'observation': array([[[1, 0],\n",
      "        [0, 1],\n",
      "        [1, 0]],\n",
      "\n",
      "       [[0, 1],\n",
      "        [1, 0],\n",
      "        [0, 1]],\n",
      "\n",
      "       [[1, 0],\n",
      "        [0, 0],\n",
      "        [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int8)}\n",
      "Reward: 1\n",
      "Termination: True\n",
      "Truncation: False\n",
      "Return: 1\n",
      "        \n",
      "Action: None\n"
     ]
    }
   ],
   "source": [
    "from pettingzoo.classic import tictactoe_v3\n",
    "\n",
    "env = tictactoe_v3.env(render_mode=\"human\")\n",
    "agents = {\n",
    "    name: ActionMaskAgent(name=name, model=ChatOpenAI(temperature=0.2), env=env)\n",
    "    for name in env.possible_agents\n",
    "}\n",
    "main(agents, env)"
   ]
  },
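  {
   "cell_type": "markdown",
   "id": "a6c3e058",
   "metadata": {},
   "source": [
    "The rendered boards above show that this environment numbers squares column by column: action 0 is the top-left square, 1 the square below it, 2 the bottom-left, 3 the top-middle, and so on. The sketch below assumes that numbering, which is inferred from this run rather than taken from the PettingZoo API, rebuilds the final board from the seven moves played, and checks for a winner. `action_to_cell`, `build_board`, and `winner` are hypothetical helpers, and X is assumed to move first, as it does here.\n",
    "\n",
    "```python\n",
    "def action_to_cell(action):\n",
    "    # Column-major numbering, inferred from the boards rendered above:\n",
    "    #   0 | 3 | 6\n",
    "    #   1 | 4 | 7\n",
    "    #   2 | 5 | 8\n",
    "    return action % 3, action // 3  # (row, column)\n",
    "\n",
    "\n",
    "def build_board(actions):\n",
    "    # X moves first, then the players alternate.\n",
    "    board = [['-'] * 3 for _ in range(3)]\n",
    "    for turn, action in enumerate(actions):\n",
    "        row, col = action_to_cell(action)\n",
    "        board[row][col] = 'X' if turn % 2 == 0 else 'O'\n",
    "    return board\n",
    "\n",
    "\n",
    "def winner(board):\n",
    "    lines = [[(r, c) for c in range(3)] for r in range(3)]  # rows\n",
    "    lines += [[(r, c) for r in range(3)] for c in range(3)]  # columns\n",
    "    lines.append([(i, i) for i in range(3)])  # main diagonal\n",
    "    lines.append([(i, 2 - i) for i in range(3)])  # anti-diagonal\n",
    "    for line in lines:\n",
    "        marks = {board[r][c] for r, c in line}\n",
    "        if len(marks) == 1 and '-' not in marks:\n",
    "            return marks.pop()\n",
    "    return None\n",
    "\n",
    "\n",
    "board = build_board([0, 1, 2, 3, 4, 5, 6])  # the moves played in the run above\n",
    "for row in board:\n",
    "    print(' | '.join(row))\n",
    "print('winner:', winner(board))\n",
    "# X | O | X\n",
    "# O | X | -\n",
    "# X | O | -\n",
    "# winner: X\n",
    "```\n",
    "\n",
    "In this run both agents simply chose the lowest legal index on every turn, which is why X completed the 2-4-6 diagonal."
   ]
  },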
  {
   "cell_type": "markdown",
   "id": "8728ac2a",
   "metadata": {},
   "source": [
    "## Texas Hold'em No Limit\n",
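    "Here is an example of a four-player No-Limit Texas Hold'em game that uses the `ActionMaskAgent`. Note that in this run one agent makes an illegal move, which ends the game early with that player losing."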
   ]
  },
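  {
   "cell_type": "markdown",
   "id": "b8e27d19",
   "metadata": {},
   "source": [
    "In this run one agent chooses action 4 while the mask allows only actions 0 and 1. One way to make the prompt more explicit, sketched here, is to spell out the legal actions by name on every turn instead of only referring to the raw `action_mask`. The action names are an assumption based on the PettingZoo documentation for `texas_holdem_no_limit_v6` (verify them against your installed version), and `describe_legal_actions` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "# Assumed action meanings for texas_holdem_no_limit_v6, based on the PettingZoo\n",
    "# documentation; verify them against your installed version.\n",
    "ACTION_NAMES = ['fold', 'check/call', 'raise half pot', 'raise full pot', 'all in']\n",
    "\n",
    "\n",
    "def describe_legal_actions(action_mask):\n",
    "    # Hypothetical helper: list the legal actions by index and name, for use in a prompt.\n",
    "    legal = [f'{i} ({ACTION_NAMES[i]})' for i, allowed in enumerate(action_mask) if allowed]\n",
    "    return 'Legal actions: ' + ', '.join(legal)\n",
    "\n",
    "\n",
    "# The mask at the step where the illegal move (action 4) was made.\n",
    "print(describe_legal_actions([1, 1, 0, 0, 0]))\n",
    "# Legal actions: 0 (fold), 1 (check/call)\n",
    "```\n",
    "\n",
    "An `ActionMaskAgent` subclass could append such a line to the `HumanMessage` it adds in `_act`, alongside or in place of the generic instruction about the `action_mask`."
   ]
  },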
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e350c62b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0.,\n",
      "       0., 0., 2.], dtype=float32), 'action_mask': array([1, 1, 0, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 1\n",
      "\n",
      "Observation: {'observation': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
      "       0., 0., 2.], dtype=float32), 'action_mask': array([1, 1, 0, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 1\n",
      "\n",
      "Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,\n",
      "       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 1., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 1\n",
      "\n",
      "Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 2., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 0\n",
      "\n",
      "Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,\n",
      "       0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 2., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 2\n",
      "\n",
      "Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0.,\n",
      "       0., 2., 6.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 2\n",
      "\n",
      "Observation: {'observation': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
      "       0., 2., 8.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 3\n",
      "\n",
      "Observation: {'observation': array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
      "        0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
      "        0.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,\n",
      "        1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
      "        6., 20.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 4\n",
      "\n",
      "Observation: {'observation': array([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   1.,\n",
      "         0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.,   8., 100.],\n",
      "      dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 4\n",
      "[WARNING]: Illegal move made, game terminating with current player losing. \n",
      "obs['action_mask'] contains a mask of all legal moves that can be chosen.\n",
      "\n",
      "Observation: {'observation': array([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   1.,\n",
      "         0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.,   8., 100.],\n",
      "      dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
      "Reward: -1.0\n",
      "Termination: True\n",
      "Truncation: True\n",
      "Return: -1.0\n",
      "        \n",
      "Action: None\n",
      "\n",
      "Observation: {'observation': array([  0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,\n",
      "         0.,   0.,   0.,   0.,   1.,   0.,   0.,   0.,  20., 100.],\n",
      "      dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: True\n",
      "Truncation: True\n",
      "Return: 0\n",
      "        \n",
      "Action: None\n",
      "\n",
      "Observation: {'observation': array([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,\n",
      "         1.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0., 100., 100.],\n",
      "      dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: True\n",
      "Truncation: True\n",
      "Return: 0\n",
      "        \n",
      "Action: None\n",
      "\n",
      "Observation: {'observation': array([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   1.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   2., 100.],\n",
      "      dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
      "Reward: 0\n",
      "Termination: True\n",
      "Truncation: True\n",
      "Return: 0\n",
      "        \n",
      "Action: None\n"
     ]
    }
   ],
   "source": [
    "from pettingzoo.classic import texas_holdem_no_limit_v6\n",
    "\n",
    "env = texas_holdem_no_limit_v6.env(num_players=4, render_mode=\"human\")\n",
    "agents = {\n",
    "    name: ActionMaskAgent(name=name, model=ChatOpenAI(temperature=0.2), env=env)\n",
    "    for name in env.possible_agents\n",
    "}\n",
    "main(agents, env)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}